diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 947baa9ba7..39388ad791 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -1850,13 +1850,15 @@ public final class net.corda.core.crypto.CryptoUtils extends java.lang.Object public static final boolean verify(java.security.PublicKey, byte[], byte[]) ## public interface net.corda.core.crypto.DigestAlgorithm + @NotNull + public abstract byte[] componentDigest(byte[]) @NotNull public abstract byte[] digest(byte[]) @NotNull public abstract String getAlgorithm() public abstract int getDigestLength() @NotNull - public abstract byte[] preImageResistantDigest(byte[]) + public abstract byte[] nonceDigest(byte[]) ## @CordaSerializable public class net.corda.core.crypto.DigitalSignature extends net.corda.core.utilities.OpaqueBytes @@ -5935,6 +5937,13 @@ public @interface net.corda.core.serialization.CordaSerializationTransformRename public @interface net.corda.core.serialization.CordaSerializationTransformRenames public abstract net.corda.core.serialization.CordaSerializationTransformRename[] value() ## +public interface net.corda.core.serialization.CustomSerializationScheme + @NotNull + public abstract T deserialize(net.corda.core.utilities.ByteSequence, Class, net.corda.core.serialization.SerializationSchemeContext) + public abstract int getSchemeId() + @NotNull + public abstract net.corda.core.utilities.ByteSequence serialize(T, net.corda.core.serialization.SerializationSchemeContext) +## public @interface net.corda.core.serialization.DeprecatedConstructorForDeserialization public abstract int version() ## @@ -6076,6 +6085,13 @@ public static final class net.corda.core.serialization.SerializationFactory$Comp @NotNull public final net.corda.core.serialization.SerializationFactory getDefaultFactory() ## +@DoNotImplement +public interface net.corda.core.serialization.SerializationSchemeContext + @NotNull + public abstract ClassLoader getDeserializationClassLoader() + @NotNull + public abstract net.corda.core.serialization.ClassWhitelist getWhitelist() +## public interface net.corda.core.serialization.SerializationToken @NotNull public abstract Object fromToken(net.corda.core.serialization.SerializeAsTokenContext) diff --git a/.ci/dev/compatibility/JenkinsfileJDK11Azul b/.ci/dev/compatibility/JenkinsfileJDK11Azul index b0e63b45d8..23e9e4bf95 100644 --- a/.ci/dev/compatibility/JenkinsfileJDK11Azul +++ b/.ci/dev/compatibility/JenkinsfileJDK11Azul @@ -70,6 +70,7 @@ pipeline { stage('Compile') { steps { dir(sameAgentFolder) { + authenticateGradleWrapper() sh script: [ './gradlew', COMMON_GRADLE_PARAMS, diff --git a/.ci/dev/forward-merge/Jenkinsfile b/.ci/dev/forward-merge/Jenkinsfile index df905a8262..169cdd258d 100644 --- a/.ci/dev/forward-merge/Jenkinsfile +++ b/.ci/dev/forward-merge/Jenkinsfile @@ -13,13 +13,13 @@ * the branch name of origin branch, it should match the current branch * and it acts as a fail-safe inside {@code forwardMerger} pipeline */ -String originBranch = 'release/os/4.7' +String originBranch = 'release/os/4.8' /** * the branch name of target branch, it should be the branch with the next version * after the one in current branch. */ -String targetBranch = 'release/os/4.8' +String targetBranch = 'release/os/4.9' /** * Forward merge any changes between #originBranch and #targetBranch diff --git a/.ci/dev/mswin/Jenkinsfile b/.ci/dev/mswin/Jenkinsfile index 714fadf4fb..e1f70bb4dd 100644 --- a/.ci/dev/mswin/Jenkinsfile +++ b/.ci/dev/mswin/Jenkinsfile @@ -56,6 +56,7 @@ pipeline { stage('Unit Tests') { agent { label 'mswin' } steps { + authenticateGradleWrapper() bat "./gradlew --no-daemon " + "--stacktrace " + "-Pcompilation.warningsAsErrors=false " + diff --git a/.ci/dev/nightly-regression/Jenkinsfile b/.ci/dev/nightly-regression/Jenkinsfile index 92eae917af..06266dcb23 100644 --- a/.ci/dev/nightly-regression/Jenkinsfile +++ b/.ci/dev/nightly-regression/Jenkinsfile @@ -50,6 +50,7 @@ pipeline { stages { stage('Compile') { steps { + authenticateGradleWrapper() sh script: [ './gradlew', COMMON_GRADLE_PARAMS, diff --git a/.ci/dev/open-j9/Jenkinsfile b/.ci/dev/open-j9/Jenkinsfile index 65deab2390..808dac8f08 100644 --- a/.ci/dev/open-j9/Jenkinsfile +++ b/.ci/dev/open-j9/Jenkinsfile @@ -3,6 +3,7 @@ * Jenkins pipeline to build Corda OS release branches and tags. * PLEASE NOTE: we DO want to run a build for each commit!!! */ +@Library('corda-shared-build-pipeline-steps') /** * Sense environment @@ -47,6 +48,7 @@ pipeline { stages { stage('Unit Tests') { steps { + authenticateGradleWrapper() sh "./gradlew clean --continue test --info -Ptests.failFast=true" } } diff --git a/.ci/dev/pr-code-checks/Jenkinsfile b/.ci/dev/pr-code-checks/Jenkinsfile index 89e8bd1206..456dc80994 100644 --- a/.ci/dev/pr-code-checks/Jenkinsfile +++ b/.ci/dev/pr-code-checks/Jenkinsfile @@ -30,6 +30,7 @@ pipeline { stages { stage('Detekt check') { steps { + authenticateGradleWrapper() sh "./gradlew --no-daemon --parallel --build-cache clean detekt" } } @@ -54,6 +55,7 @@ pipeline { GRADLE_USER_HOME = "/host_tmp/gradle" } steps { + authenticateGradleWrapper() sh 'mkdir -p ${GRADLE_USER_HOME}' snykDeltaScan(env.SNYK_API_TOKEN, env.C4_OS_SNYK_ORG_ID) } diff --git a/.ci/dev/publish-api-docs/Jenkinsfile b/.ci/dev/publish-api-docs/Jenkinsfile index b45aa95e95..2bdda095be 100644 --- a/.ci/dev/publish-api-docs/Jenkinsfile +++ b/.ci/dev/publish-api-docs/Jenkinsfile @@ -33,6 +33,7 @@ pipeline { stage('Publish Archived API Docs to Artifactory') { when { tag pattern: /^docs-release-os-V(\d+\.\d+)(\.\d+){0,1}(-GA){0,1}(-\d{4}-\d\d-\d\d-\d{4}){0,1}$/, comparator: 'REGEXP' } steps { + authenticateGradleWrapper() sh "./gradlew :clean :docs:artifactoryPublish -DpublishApiDocs" } } diff --git a/.ci/dev/publish-branch/Jenkinsfile.preview b/.ci/dev/publish-branch/Jenkinsfile.preview index c1cc300089..b795edec93 100644 --- a/.ci/dev/publish-branch/Jenkinsfile.preview +++ b/.ci/dev/publish-branch/Jenkinsfile.preview @@ -29,6 +29,7 @@ pipeline { stages { stage('Publish to Artifactory') { steps { + authenticateGradleWrapper() rtServer ( id: 'R3-Artifactory', url: 'https://software.r3.com/artifactory', diff --git a/.ci/dev/regression/Jenkinsfile b/.ci/dev/regression/Jenkinsfile index 20210fb4bd..b565d1bb5e 100644 --- a/.ci/dev/regression/Jenkinsfile +++ b/.ci/dev/regression/Jenkinsfile @@ -70,6 +70,7 @@ pipeline { stages { stage('Compile') { steps { + authenticateGradleWrapper() sh script: [ './gradlew', COMMON_GRADLE_PARAMS, @@ -168,6 +169,7 @@ pipeline { } stage('Recompile') { steps { + authenticateGradleWrapper() sh script: [ './gradlew', COMMON_GRADLE_PARAMS, @@ -390,23 +392,23 @@ pipeline { } } success { - script { - sendSlackNotifications("good", "BUILD PASSED", false, "#corda-corda4-open-source-build-notifications") + script { + sendSlackNotifications("good", "BUILD PASSED", false, "#corda-corda4-open-source-build-notifications") if (isReleaseTag || isReleaseCandidate || isReleaseBranch) { snykSecurityScan.generateHtmlElements() } - } + } } unstable { - script { - sendSlackNotifications("warning", "BUILD UNSTABLE", false, "#corda-corda4-open-source-build-notifications") + script { + sendSlackNotifications("warning", "BUILD UNSTABLE", false, "#corda-corda4-open-source-build-notifications") if (isReleaseTag || isReleaseCandidate || isReleaseBranch) { snykSecurityScan.generateHtmlElements() } if (isReleaseTag || isReleaseCandidate || isReleaseBranch) { snykSecurityScan.generateHtmlElements() } - } + } } failure { script { diff --git a/.github/workflows/jira_create_issue.yml b/.github/workflows/jira_create_issue.yml index fe9f5eb8de..66a3bbdc37 100644 --- a/.github/workflows/jira_create_issue.yml +++ b/.github/workflows/jira_create_issue.yml @@ -16,6 +16,7 @@ jobs: with: jiraBaseUrl: https://r3-cev.atlassian.net project: CORDA + squad: Corda issuetype: Bug summary: ${{ github.event.issue.title }} labels: community @@ -33,4 +34,4 @@ jobs: issue-number: ${{ github.event.issue.number }} body: | Automatically created Jira issue: ${{ steps.create.outputs.issue }} - reaction-type: '+1' \ No newline at end of file + reaction-type: '+1' diff --git a/Jenkinsfile b/Jenkinsfile index 20e051a742..f2b5686859 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -58,6 +58,7 @@ pipeline { stages { stage('Compile') { steps { + authenticateGradleWrapper() sh script: [ './gradlew', COMMON_GRADLE_PARAMS, @@ -100,6 +101,7 @@ pipeline { } stage('Recompile') { steps { + authenticateGradleWrapper() sh script: [ './gradlew', COMMON_GRADLE_PARAMS, diff --git a/build.gradle b/build.gradle index e0642a270a..a0a35a4039 100644 --- a/build.gradle +++ b/build.gradle @@ -101,7 +101,7 @@ buildscript { ext.hibernate_version = '5.4.32.Final' ext.h2_version = '1.4.199' // Update docs if renamed or removed. ext.rxjava_version = '1.3.8' - ext.dokka_version = '0.9.17' + ext.dokka_version = '0.10.1' ext.eddsa_version = '0.3.0' ext.dependency_checker_version = '5.2.0' ext.commons_collections_version = '4.3' diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index d90befe6ae..d008a351dc 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -1,5 +1,6 @@ package net.corda.client.rpc +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.ReconnectingCordaRPCOps import net.corda.client.rpc.internal.SerializationEnvironmentHelper @@ -52,7 +53,7 @@ class CordaRPCConnection private constructor( sslConfiguration: ClientRpcSslOptions? = null, classLoader: ClassLoader? = null ): CordaRPCConnection { - val observersPool: ExecutorService = Executors.newCachedThreadPool() + val observersPool: ExecutorService = Executors.newCachedThreadPool(DefaultThreadFactory("RPCObserver")) return CordaRPCConnection(null, observersPool, ReconnectingCordaRPCOps( addresses, username, diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt index 2e4f2c529b..ea5a54cef2 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt @@ -17,7 +17,6 @@ import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.rpcConnectorTcpTransport -import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.rpcConnectorTcpTransportsFromList import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.rpcInternalClientTcpTransport import net.corda.nodeapi.internal.RoundRobinConnectionPolicy import net.corda.nodeapi.internal.config.SslConfiguration @@ -61,8 +60,12 @@ class RPCClient( sslConfiguration: ClientRpcSslOptions? = null, configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT, serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT - ) : this(rpcConnectorTcpTransport(haAddressPool.first(), sslConfiguration), - configuration, serializationContext, rpcConnectorTcpTransportsFromList(haAddressPool, sslConfiguration)) + ) : this( + rpcConnectorTcpTransport(haAddressPool.first(), sslConfiguration), + configuration, + serializationContext, + haAddressPool.map { rpcConnectorTcpTransport(it, sslConfiguration) } + ) companion object { private val log = contextLogger() diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt index 71964d961e..005ac70fd3 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt @@ -1,5 +1,6 @@ package net.corda.client.rpc.internal +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.client.rpc.ConnectionFailureException import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.CordaRPCClientConfiguration @@ -99,7 +100,8 @@ class ReconnectingCordaRPCOps private constructor( ErrorInterceptingHandler(reconnectingRPCConnection)) as CordaRPCOps } } - private val retryFlowsPool = Executors.newScheduledThreadPool(1) + private val retryFlowsPool = Executors.newScheduledThreadPool(1, DefaultThreadFactory("FlowRetry")) + /** * This function runs a flow and retries until it completes successfully. * diff --git a/common/logging/src/main/kotlin/net/corda/common/logging/Constants.kt b/common/logging/src/main/kotlin/net/corda/common/logging/Constants.kt index 5051b97b41..aa5d86d5d7 100644 --- a/common/logging/src/main/kotlin/net/corda/common/logging/Constants.kt +++ b/common/logging/src/main/kotlin/net/corda/common/logging/Constants.kt @@ -9,4 +9,4 @@ package net.corda.common.logging * (originally added to source control for ease of use) */ -internal const val CURRENT_MAJOR_RELEASE = "4.7-SNAPSHOT" \ No newline at end of file +internal const val CURRENT_MAJOR_RELEASE = "4.8-SNAPSHOT" \ No newline at end of file diff --git a/constants.properties b/constants.properties index 9d7956fe1f..119bf179b3 100644 --- a/constants.properties +++ b/constants.properties @@ -2,7 +2,7 @@ # because some versions here need to be matched by app authors in # their own projects. So don't get fancy with syntax! -cordaVersion=4.7 +cordaVersion=4.8 versionSuffix=SNAPSHOT gradlePluginsVersion=5.0.12 kotlinVersion=1.2.71 @@ -11,7 +11,7 @@ java8MinUpdateVersion=171 # When incrementing platformVersion make sure to update # # net.corda.core.internal.CordaUtilsKt.PLATFORM_VERSION as well. # # ***************************************************************# -platformVersion=9 +platformVersion=10 guavaVersion=28.0-jre # Quasar version to use with Java 8: quasarVersion=0.7.15_r3 @@ -21,7 +21,7 @@ jdkClassifier11=jdk11 dockerJavaVersion=3.2.5 proguardVersion=6.1.1 bouncycastleVersion=1.68 -classgraphVersion=4.8.90 +classgraphVersion=4.8.135 disruptorVersion=3.4.2 typesafeConfigVersion=1.3.4 jsr305Version=3.0.2 diff --git a/core-deterministic/build.gradle b/core-deterministic/build.gradle index 71ba8d4efa..48dac3afd0 100644 --- a/core-deterministic/build.gradle +++ b/core-deterministic/build.gradle @@ -23,7 +23,10 @@ def javaHome = System.getProperty('java.home') def jarBaseName = "corda-${project.name}".toString() configurations { - deterministicLibraries.extendsFrom api + deterministicLibraries { + canBeConsumed = false + extendsFrom api + } deterministicArtifacts.extendsFrom deterministicLibraries } @@ -59,7 +62,7 @@ def originalJar = coreJarTask.map { it.outputs.files.singleFile } def patchCore = tasks.register('patchCore', Zip) { dependsOn coreJarTask - destinationDirectory = file("$buildDir/source-libs") + destinationDirectory = layout.buildDirectory.dir('source-libs') metadataCharset 'UTF-8' archiveClassifier = 'transient' archiveExtension = 'jar' @@ -169,7 +172,7 @@ def determinise = tasks.register('determinise', ProGuardTask) { def checkDeterminism = tasks.register('checkDeterminism', ProGuardTask) def metafix = tasks.register('metafix', MetaFixerTask) { - outputDir file("$buildDir/libs") + outputDir = layout.buildDirectory.dir('libs') jars determinise suffix "" diff --git a/core-deterministic/src/main/kotlin/net/corda/core/serialization/SerializationFactory.kt b/core-deterministic/src/main/kotlin/net/corda/core/serialization/SerializationFactory.kt index 69ddb8887b..2d6d8d3b09 100644 --- a/core-deterministic/src/main/kotlin/net/corda/core/serialization/SerializationFactory.kt +++ b/core-deterministic/src/main/kotlin/net/corda/core/serialization/SerializationFactory.kt @@ -55,12 +55,16 @@ abstract class SerializationFactory { * Change the current context inside the block to that supplied. */ fun withCurrentContext(context: SerializationContext?, block: () -> T): T { - val priorContext = _currentContext - if (context != null) _currentContext = context - try { - return block() - } finally { - if (context != null) _currentContext = priorContext + return if (context == null) { + block() + } else { + val priorContext = _currentContext + _currentContext = context + try { + block() + } finally { + _currentContext = priorContext + } } } diff --git a/core-deterministic/testing/data/build.gradle b/core-deterministic/testing/data/build.gradle index ab3acd9249..0141dc3c61 100644 --- a/core-deterministic/testing/data/build.gradle +++ b/core-deterministic/testing/data/build.gradle @@ -3,7 +3,9 @@ plugins { } configurations { - testData + testData { + canBeResolved = false + } } dependencies { diff --git a/core-deterministic/testing/verifier/build.gradle b/core-deterministic/testing/verifier/build.gradle index 774592c6a4..334191cb9f 100644 --- a/core-deterministic/testing/verifier/build.gradle +++ b/core-deterministic/testing/verifier/build.gradle @@ -9,7 +9,12 @@ apply from: "${rootProject.projectDir}/deterministic.gradle" description 'Test utilities for deterministic contract verification' configurations { - deterministicArtifacts + deterministicArtifacts { + canBeResolved = false + } + + // Compile against the deterministic artifacts to ensure that we use only the deterministic API subset. + compileOnly.extendsFrom deterministicArtifacts runtimeArtifacts.extendsFrom api } @@ -20,8 +25,6 @@ dependencies { runtimeArtifacts project(':serialization') runtimeArtifacts project(':core') - // Compile against the deterministic artifacts to ensure that we use only the deterministic API subset. - compileOnly configurations.deterministicArtifacts api "junit:junit:$junit_version" runtimeOnly "org.junit.vintage:junit-vintage-engine:$junit_vintage_version" } diff --git a/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/TransactionVerificationRequest.kt b/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/TransactionVerificationRequest.kt index c259f2791c..3c9fde9c06 100644 --- a/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/TransactionVerificationRequest.kt +++ b/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/TransactionVerificationRequest.kt @@ -13,7 +13,7 @@ import net.corda.core.transactions.WireTransaction @Suppress("MemberVisibilityCanBePrivate") //TODO the use of deprecated toLedgerTransaction need to be revisited as resolveContractAttachment requires attachments of the transactions which created input states... -//TODO ...to check contract version non downgrade rule, curretly dummy Attachment if not fund is used which sets contract version to '1' +//TODO ...to check contract version non downgrade rule, currently dummy Attachment if not fund is used which sets contract version to '1' @CordaSerializable class TransactionVerificationRequest(val wtxToVerify: SerializedBytes, val dependencies: Array>, diff --git a/core-tests/src/test/kotlin/net/corda/coretests/crypto/PartialMerkleTreeTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/crypto/PartialMerkleTreeTest.kt index 654d1cd684..bb6c457ec2 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/crypto/PartialMerkleTreeTest.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/crypto/PartialMerkleTreeTest.kt @@ -5,8 +5,11 @@ import com.nhaarman.mockito_kotlin.mock import com.nhaarman.mockito_kotlin.whenever import net.corda.core.contracts.* import net.corda.core.crypto.* +import net.corda.core.crypto.internal.DigestAlgorithmFactory import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party +import net.corda.core.internal.BLAKE2s256DigestAlgorithm +import net.corda.core.internal.SHA256BLAKE2s256DigestAlgorithm import net.corda.core.node.NotaryInfo import net.corda.core.node.services.IdentityService import net.corda.core.serialization.deserialize @@ -49,12 +52,19 @@ class PartialMerkleTreeTest(private var digestService: DigestService) { val MINI_CORP get() = miniCorp.party val MINI_CORP_PUBKEY get() = miniCorp.publicKey + init { + DigestAlgorithmFactory.registerClass(BLAKE2s256DigestAlgorithm::class.java.name) + DigestAlgorithmFactory.registerClass(SHA256BLAKE2s256DigestAlgorithm::class.java.name) + } + @JvmStatic @Parameterized.Parameters fun data(): Collection = listOf( DigestService.sha2_256, DigestService.sha2_384, - DigestService.sha2_512 + DigestService.sha2_512, + DigestService("BLAKE_TEST"), + DigestService("SHA256-BLAKE2S256-TEST") ) } diff --git a/core-tests/src/test/kotlin/net/corda/coretests/crypto/PartialMerkleTreeWithNamedHashTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/crypto/PartialMerkleTreeWithNamedHashTest.kt index 524dec4bb0..77bb40f6b9 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/crypto/PartialMerkleTreeWithNamedHashTest.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/crypto/PartialMerkleTreeWithNamedHashTest.kt @@ -140,7 +140,7 @@ class PartialMerkleTreeWithNamedHashTest { fun `building Merkle tree one node`() { val node = 'a'.serialize().sha2_384() val mt = MerkleTree.getMerkleTree(listOf(node), DigestService.sha2_384) - assertEquals(node, mt.hash) + assertNotEquals(node, mt.hash) } @Test(timeout=300_000) diff --git a/core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt new file mode 100644 index 0000000000..11a7e9ed85 --- /dev/null +++ b/core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt @@ -0,0 +1,29 @@ +package net.corda.coretests.crypto.internal + +import net.corda.coretesting.internal.DEV_ROOT_CA +import net.corda.testing.core.createCRL +import org.assertj.core.api.Assertions.assertThatIllegalArgumentException +import org.junit.Test + +class ProviderMapTest { + // https://github.com/corda/corda/pull/3997 + @Test(timeout = 300_000) + fun `verify CRL algorithms`() { + val crl = createCRL( + issuer = DEV_ROOT_CA, + revokedCerts = emptyList(), + signatureAlgorithm = "SHA256withECDSA" + ) + // This should pass. + crl.verify(DEV_ROOT_CA.keyPair.public) + + // Try changing the algorithm to EC will fail. + assertThatIllegalArgumentException().isThrownBy { + createCRL( + issuer = DEV_ROOT_CA, + revokedCerts = emptyList(), + signatureAlgorithm = "EC" + ) + }.withMessage("Unknown signature type requested: EC") + } +} diff --git a/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderSerializationTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderSerializationTests.kt index 4ca58d6b46..63f5461e46 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderSerializationTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderSerializationTests.kt @@ -55,7 +55,8 @@ class AttachmentsClassLoaderSerializationTests { arrayOf(isolatedId, att1, att2).map { storage.openAttachment(it)!! }, testNetworkParameters(), SecureHash.zeroHash, - { attachmentTrustCalculator.calculate(it) }, attachmentsClassLoaderCache = null) { classLoader -> + { attachmentTrustCalculator.calculate(it) }, attachmentsClassLoaderCache = null) { serializationContext -> + val classLoader = serializationContext.deserializationClassLoader val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, classLoader) val contract = contractClass.getDeclaredConstructor().newInstance() as Contract assertEquals("helloworld", contract.declaredField("magicString").value) diff --git a/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderTests.kt index 986b6052ef..e0f3778418 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderTests.kt @@ -24,8 +24,9 @@ import net.corda.core.node.NetworkParameters import net.corda.core.node.services.AttachmentId import net.corda.core.serialization.internal.AttachmentsClassLoader import net.corda.core.serialization.internal.AttachmentsClassLoaderCacheImpl -import net.corda.testing.common.internal.testNetworkParameters +import net.corda.core.transactions.LedgerTransaction import net.corda.node.services.attachments.NodeAttachmentTrustCalculator +import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.contracts.DummyContract import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME @@ -74,7 +75,7 @@ class AttachmentsClassLoaderTests { val BOB = TestIdentity(BOB_NAME, 80).party val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20) val DUMMY_NOTARY get() = dummyNotary.party - val PROGRAM_ID: String = "net.corda.testing.contracts.MyDummyContract" + const val PROGRAM_ID = "net.corda.testing.contracts.MyDummyContract" } @Rule @@ -89,7 +90,7 @@ class AttachmentsClassLoaderTests { private lateinit var internalStorage: InternalMockAttachmentStorage private lateinit var attachmentTrustCalculator: AttachmentTrustCalculator private val networkParameters = testNetworkParameters() - private val cacheFactory = TestingNamedCacheFactory() + private val cacheFactory = TestingNamedCacheFactory(1) private fun createClassloader( attachment: AttachmentId, @@ -541,6 +542,50 @@ class AttachmentsClassLoaderTests { } } + @Test(timeout=300_000) + fun `class loader not closed after cache starts evicting`() { + tempFolder.root.toPath().let { path -> + val transactions = mutableListOf() + val iterations = 10 + + val baseOutState = TransactionState(DummyContract.SingleOwnerState(0, ALICE), PROGRAM_ID, DUMMY_NOTARY, constraint = AlwaysAcceptAttachmentConstraint) + val inputs = emptyList>() + val outputs = listOf(baseOutState, baseOutState.copy(notary = ALICE), baseOutState.copy(notary = BOB)) + val commands = emptyList>() + val content = createContractString(PROGRAM_ID) + val timeWindow: TimeWindow? = null + val attachmentsClassLoaderCache = AttachmentsClassLoaderCacheImpl(cacheFactory) + val contractJarPath = ContractJarTestUtils.makeTestContractJar(path, PROGRAM_ID, content = content) + val attachments = createAttachments(contractJarPath) + + for(i in 1 .. iterations) { + val id = SecureHash.randomSHA256() + val privacySalt = PrivacySalt() + val transaction = createLedgerTransaction( + inputs, + outputs, + commands, + attachments, + id, + null, + timeWindow, + privacySalt, + testNetworkParameters(), + emptyList(), + isAttachmentTrusted = { true }, + attachmentsClassLoaderCache = attachmentsClassLoaderCache + ) + transactions.add(transaction) + System.gc() + Thread.sleep(1) + } + + transactions.forEach { + it.verify() + } + } + } + private fun createContractString(contractName: String, versionSeed: Int = 0): String { val pkgs = contractName.split(".") val className = pkgs.last() @@ -563,7 +608,7 @@ class AttachmentsClassLoaderTests { } """.trimIndent() - System.out.println(output) + println(output) return output } @@ -571,6 +616,7 @@ class AttachmentsClassLoaderTests { val attachment = object : AbstractAttachment({contractJarPath.inputStream().readBytes()}, uploader = "app") { @Suppress("OverridingDeprecatedMember") + @Deprecated("Use signerKeys. There is no requirement that attachment signers are Corda parties.") override val signers: List = emptyList() override val signerKeys: List = emptyList() override val size: Int = 1234 @@ -581,6 +627,7 @@ class AttachmentsClassLoaderTests { return listOf( object : AbstractAttachment({ISOLATED_CONTRACTS_JAR_PATH.openStream().readBytes()}, uploader = "app") { @Suppress("OverridingDeprecatedMember") + @Deprecated("Use signerKeys. There is no requirement that attachment signers are Corda parties.") override val signers: List = emptyList() override val signerKeys: List = emptyList() override val size: Int = 1234 @@ -589,6 +636,7 @@ class AttachmentsClassLoaderTests { object : AbstractAttachment({fakeAttachment("importantDoc.pdf", "I am a pdf!").inputStream().readBytes() }, uploader = "app") { @Suppress("OverridingDeprecatedMember") + @Deprecated("Use signerKeys. There is no requirement that attachment signers are Corda parties.") override val signers: List = emptyList() override val signerKeys: List = emptyList() override val size: Int = 1234 diff --git a/core-tests/src/test/kotlin/net/corda/coretests/transactions/MerkleTreeAgilityTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/transactions/MerkleTreeAgilityTest.kt new file mode 100644 index 0000000000..bb56e8cb71 --- /dev/null +++ b/core-tests/src/test/kotlin/net/corda/coretests/transactions/MerkleTreeAgilityTest.kt @@ -0,0 +1,186 @@ +package net.corda.coretests.transactions + +import net.corda.core.contracts.ComponentGroupEnum.COMMANDS_GROUP +import net.corda.core.contracts.ComponentGroupEnum.INPUTS_GROUP +import net.corda.core.contracts.ComponentGroupEnum.NOTARY_GROUP +import net.corda.core.contracts.ComponentGroupEnum.OUTPUTS_GROUP +import net.corda.core.contracts.ComponentGroupEnum.SIGNERS_GROUP +import net.corda.core.contracts.ComponentGroupEnum.TIMEWINDOW_GROUP +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.PrivacySalt +import net.corda.core.contracts.TimeWindow +import net.corda.core.contracts.TransactionState +import net.corda.core.crypto.DigestService +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.generateKeyPair +import net.corda.core.crypto.internal.DigestAlgorithmFactory +import net.corda.core.internal.SHA256BLAKE2s256DigestAlgorithm +import net.corda.core.internal.accessAvailableComponentHashes +import net.corda.core.internal.accessAvailableComponentNonces +import net.corda.core.internal.accessGroupMerkleRoots +import net.corda.core.serialization.serialize +import net.corda.core.transactions.ComponentGroup +import net.corda.core.transactions.WireTransaction + +import net.corda.testing.contracts.DummyContract +import net.corda.testing.contracts.DummyState +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.core.TestIdentity +import net.corda.testing.core.dummyCommand +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import java.time.Instant +import kotlin.test.assertEquals + +class MerkleTreeAgilityTest { + private companion object { + val DUMMY_KEY_1 = generateKeyPair() + val DUMMY_KEY_2 = generateKeyPair() + val BOB = TestIdentity(BOB_NAME, 80).party + val DUMMY_NOTARY = TestIdentity(DUMMY_NOTARY_NAME, 20).party + } + + @Rule + @JvmField + val testSerialization = SerializationEnvironmentRule() + + private val dummyOutState = TransactionState(DummyState(0), DummyContract.PROGRAM_ID, DUMMY_NOTARY) + private val stateRef1 = StateRef(SecureHash.randomSHA256(), 0) + private val stateRef2 = StateRef(SecureHash.randomSHA256(), 1) + private val stateRef3 = StateRef(SecureHash.randomSHA256(), 2) + private val stateRef4 = StateRef(SecureHash.randomSHA256(), 3) + + private val singleInput = listOf(stateRef1) // 1 elements. + private val threeInputs = listOf(stateRef1, stateRef2, stateRef3) // 3 elements. + private val fourInputs = listOf(stateRef1, stateRef2, stateRef3, stateRef4) // 4 elements. + + private val outputs = listOf(dummyOutState, dummyOutState.copy(notary = BOB)) // 2 elements. + private val commands = listOf(dummyCommand(DUMMY_KEY_1.public, DUMMY_KEY_2.public)) // 1 element. + private val notary = DUMMY_NOTARY + private val timeWindow = TimeWindow.fromOnly(Instant.now()) + private val privacySalt: PrivacySalt = PrivacySalt() + + private val singleInputsGroup by lazy { ComponentGroup(INPUTS_GROUP.ordinal, singleInput.map { it.serialize() }) } + private val threeInputsGroup by lazy { ComponentGroup(INPUTS_GROUP.ordinal, threeInputs.map { it.serialize() }) } + private val fourInputsGroup by lazy { ComponentGroup(INPUTS_GROUP.ordinal, fourInputs.map { it.serialize() }) } + private val outputGroup by lazy { ComponentGroup(OUTPUTS_GROUP.ordinal, outputs.map { it.serialize() }) } + private val commandGroup by lazy { ComponentGroup(COMMANDS_GROUP.ordinal, commands.map { it.value.serialize() }) } + private val notaryGroup by lazy { ComponentGroup(NOTARY_GROUP.ordinal, listOf(notary.serialize())) } + private val timeWindowGroup by lazy { ComponentGroup(TIMEWINDOW_GROUP.ordinal, listOf(timeWindow.serialize())) } + private val signersGroup by lazy { ComponentGroup(SIGNERS_GROUP.ordinal, commands.map { it.signers.serialize() }) } + + private val componentGroupsSingle by lazy { + listOf(singleInputsGroup, outputGroup, commandGroup, notaryGroup, timeWindowGroup, signersGroup) + } + + private val componentGroupsFourInputs by lazy { + listOf(fourInputsGroup, outputGroup, commandGroup, notaryGroup, timeWindowGroup, signersGroup) + } + + private val componentGroupsThreeInputs by lazy { + listOf(threeInputsGroup, outputGroup, commandGroup, notaryGroup, timeWindowGroup, signersGroup) + } + + private val defaultDigestService = DigestService.sha2_256 + private val customDigestService = DigestService("SHA256-BLAKE2S256-TEST") + + @Before + fun before() { + DigestAlgorithmFactory.registerClass(SHA256BLAKE2s256DigestAlgorithm::class.java.name) + } + + @Test(timeout = 300_000) + fun `component nonces are correct for custom preimage resistant hash algo`() { + val wireTransaction = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt, digestService = customDigestService) + val expected = componentGroupsFourInputs.associate { + it.groupIndex to it.components.mapIndexed { componentIndexInGroup, _ -> + customDigestService.computeNonce(privacySalt, it.groupIndex, componentIndexInGroup) + } + } + + assertEquals(expected, wireTransaction.accessAvailableComponentNonces()) + } + + @Test(timeout = 300_000) + fun `component nonces are correct for default SHA256 hash algo`() { + val wireTransaction = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt) + val expected = componentGroupsFourInputs.associate { + it.groupIndex to it.components.mapIndexed { componentIndexInGroup, componentBytes -> + defaultDigestService.componentHash(componentBytes, privacySalt, it.groupIndex, componentIndexInGroup) + } + } + + assertEquals(expected, wireTransaction.accessAvailableComponentNonces()) + } + + @Test(timeout = 300_000) + fun `custom algorithm transaction pads leaf in single component component group`() { + val wtx = WireTransaction(componentGroups = componentGroupsSingle, privacySalt = privacySalt, digestService = customDigestService) + + val inputsTreeLeaves: List = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!! + val expected = customDigestService.hash(inputsTreeLeaves[0].bytes + customDigestService.zeroHash.bytes) + + assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!) + } + + @Test(timeout = 300_000) + fun `default algorithm transaction does not pad leaf in single component component group`() { + val wtx = WireTransaction(componentGroups = componentGroupsSingle, privacySalt = privacySalt, digestService = defaultDigestService) + + val inputsTreeLeaves: List = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!! + val expected = inputsTreeLeaves[0] + + assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!) + } + + @Test(timeout = 300_000) + fun `custom algorithm transaction has expected root for four components component group tree`() { + val wtx = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt, digestService = customDigestService) + + val inputsTreeLeaves: List = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!! + val h1 = customDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes) + val h2 = customDigestService.hash(inputsTreeLeaves[2].bytes + inputsTreeLeaves[3].bytes) + val expected = customDigestService.hash(h1.bytes + h2.bytes) + + assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!) + } + + @Test(timeout = 300_000) + fun `default algorithm transaction has expected root for four components component group tree`() { + val wtx = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt, digestService = defaultDigestService) + + val inputsTreeLeaves: List = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!! + val h1 = defaultDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes) + val h2 = defaultDigestService.hash(inputsTreeLeaves[2].bytes + inputsTreeLeaves[3].bytes) + val expected = defaultDigestService.hash(h1.bytes + h2.bytes) + + assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!) + } + + @Test(timeout = 300_000) + fun `custom algorithm transaction has expected root for three components component group tree`() { + val wtx = WireTransaction(componentGroups = componentGroupsThreeInputs, privacySalt = privacySalt, digestService = customDigestService) + + val inputsTreeLeaves: List = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!! + val h1 = customDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes) + val h2 = customDigestService.hash(inputsTreeLeaves[2].bytes + customDigestService.zeroHash.bytes) + val expected = customDigestService.hash(h1.bytes + h2.bytes) + + assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!) + } + + @Test(timeout = 300_000) + fun `default algorithm transaction has expected root for three components component group tree`() { + val wtx = WireTransaction(componentGroups = componentGroupsThreeInputs, privacySalt = privacySalt, digestService = defaultDigestService) + + val inputsTreeLeaves: List = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!! + val h1 = defaultDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes) + val h2 = defaultDigestService.hash(inputsTreeLeaves[2].bytes + defaultDigestService.zeroHash.bytes) + val expected = defaultDigestService.hash(h1.bytes + h2.bytes) + + assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!) + } +} \ No newline at end of file diff --git a/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionBuilderTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionBuilderTest.kt index 882466a059..0f3c35300b 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionBuilderTest.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionBuilderTest.kt @@ -3,7 +3,16 @@ package net.corda.coretests.transactions import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.mock import com.nhaarman.mockito_kotlin.whenever -import net.corda.core.contracts.* +import net.corda.core.contracts.Command +import net.corda.core.contracts.ContractAttachment +import net.corda.core.contracts.HashAttachmentConstraint +import net.corda.core.contracts.PrivacySalt +import net.corda.core.contracts.SignatureAttachmentConstraint +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TimeWindow +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.TransactionVerificationException import net.corda.core.cordapp.CordappProvider import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.DigestService @@ -20,11 +29,16 @@ import net.corda.core.node.services.IdentityService import net.corda.core.node.services.NetworkParametersService import net.corda.core.serialization.serialize import net.corda.core.transactions.TransactionBuilder +import net.corda.coretesting.internal.rigorousMock import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyState -import net.corda.testing.core.* -import net.corda.coretesting.internal.rigorousMock +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.core.DummyCommandData +import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.core.TestIdentity import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Assert.assertFalse @@ -35,6 +49,7 @@ import org.junit.Rule import org.junit.Test import java.security.PublicKey import java.time.Instant +import kotlin.test.assertFailsWith class TransactionBuilderTest { @Rule @@ -299,4 +314,22 @@ class TransactionBuilderTest { HashAgility.init() } } + + @Test(timeout=300_000) + fun `toWireTransaction fails if no scheme is registered with schemeId`() { + val outputState = TransactionState( + data = DummyState(), + contract = DummyContract.PROGRAM_ID, + notary = notary, + constraint = HashAttachmentConstraint(contractAttachmentId) + ) + val builder = TransactionBuilder() + .addOutputState(outputState) + .addCommand(DummyCommandData, notary.owningKey) + + val schemeId = 7 + assertFailsWith("Could not find custom serialization scheme with SchemeId = $schemeId.") { + builder.toWireTransaction(services, schemeId) + } + } } diff --git a/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionTests.kt index 47d4171c1d..62254d6b4e 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionTests.kt @@ -21,6 +21,7 @@ import net.corda.testing.internal.createWireTransaction import net.corda.testing.internal.fakeAttachment import net.corda.coretesting.internal.rigorousMock import net.corda.testing.internal.TestingNamedCacheFactory +import org.assertj.core.api.Assertions.fail import org.junit.Before import org.junit.Rule import org.junit.Test @@ -36,6 +37,7 @@ import kotlin.test.assertNotEquals @RunWith(Parameterized::class) class TransactionTests(private val digestService : DigestService) { private companion object { + const val ISOLATED_JAR = "isolated-4.0.jar" val DUMMY_KEY_1 = generateKeyPair() val DUMMY_KEY_2 = generateKeyPair() val DUMMY_CASH_ISSUER_KEY = entropyToKeyPair(BigInteger.valueOf(10)) @@ -200,15 +202,15 @@ class TransactionTests(private val digestService : DigestService) { val outputs = listOf(outState) val commands = emptyList>() - val attachments = listOf(object : AbstractAttachment({ - AttachmentsClassLoaderTests::class.java.getResource("isolated-4.0.jar").openStream().readBytes() + val attachments = listOf(ContractAttachment(object : AbstractAttachment({ + (AttachmentsClassLoaderTests::class.java.getResource(ISOLATED_JAR) ?: fail("Missing $ISOLATED_JAR")).openStream().readBytes() }, TESTDSL_UPLOADER) { @Suppress("OverridingDeprecatedMember") override val signers: List = emptyList() override val signerKeys: List = emptyList() override val size: Int = 1234 override val id: SecureHash = SecureHash.zeroHash - }) + }, DummyContract.PROGRAM_ID)) val id = digestService.randomHash() val timeWindow: TimeWindow? = null val privacySalt = PrivacySalt(digestService.digestLength) diff --git a/core/build.gradle b/core/build.gradle index d024d164c0..4ed50e21a3 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -12,6 +12,10 @@ description 'Corda core' // required by DJVM and Avian JVM (for running inside the SGX enclave) which only supports Java 8. targetCompatibility = VERSION_1_8 +sourceSets { + obfuscator +} + configurations { integrationTestCompile.extendsFrom testCompile integrationTestRuntimeOnly.extendsFrom testRuntimeOnly @@ -22,6 +26,9 @@ configurations { dependencies { + obfuscatorImplementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlin_version" + + testImplementation sourceSets.obfuscator.output testImplementation "org.junit.jupiter:junit-jupiter-api:${junit_jupiter_version}" testImplementation "junit:junit:$junit_version" testRuntimeOnly "org.junit.vintage:junit-vintage-engine:${junit_vintage_version}" @@ -110,7 +117,16 @@ configurations { } -test{ +processTestResources { + inputs.files(jar) + into("zip") { + from(jar) { + rename { "core.jar" } + } + } +} + +test { maxParallelForks = (System.env.CORDA_CORE_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_CORE_TESTING_FORKS".toInteger() } @@ -163,3 +179,10 @@ scanApi { publish { name jar.baseName } + +tasks.register("writeTestResources", JavaExec) { + classpath sourceSets.obfuscator.output + classpath sourceSets.obfuscator.runtimeClasspath + main 'net.corda.core.internal.utilities.TestResourceWriter' + args new File(sourceSets.test.resources.srcDirs.first(), "zip").toString() +} diff --git a/core/src/main/kotlin/net/corda/core/contracts/TransactionVerificationException.kt b/core/src/main/kotlin/net/corda/core/contracts/TransactionVerificationException.kt index e62feb845f..66eca7cd14 100644 --- a/core/src/main/kotlin/net/corda/core/contracts/TransactionVerificationException.kt +++ b/core/src/main/kotlin/net/corda/core/contracts/TransactionVerificationException.kt @@ -343,7 +343,11 @@ abstract class TransactionVerificationException(val txId: SecureHash, message: S "You will need to manually install the CorDapp to whitelist it for use.") @KeepForDJVM - class UnsupportedHashTypeException(txId: SecureHash) : TransactionVerificationException(txId, "The transaction Id is defined by an unsupported hash type", null); + class UnsupportedHashTypeException(txId: SecureHash) : TransactionVerificationException(txId, "The transaction Id is defined by an unsupported hash type", null) + + @KeepForDJVM + class AttachmentTooBigException(txId: SecureHash) : TransactionVerificationException( + txId, "The transaction attachments are too large and exceed both max transaction size and the maximum allowed compression ratio", null) /* If you add a new class extending [TransactionVerificationException], please add a test in `TransactionVerificationExceptionSerializationTests` diff --git a/core/src/main/kotlin/net/corda/core/crypto/DigestAlgorithm.kt b/core/src/main/kotlin/net/corda/core/crypto/DigestAlgorithm.kt index ef6b6971b7..13610eba5e 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/DigestAlgorithm.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/DigestAlgorithm.kt @@ -25,8 +25,16 @@ interface DigestAlgorithm { fun digest(bytes: ByteArray): ByteArray /** - * Computes the digest of the [ByteArray] which is resistant to pre-image attacks. + * Computes the digest of the [ByteArray] which is resistant to pre-image attacks. Only used to calculate the hash of the leaves of the + * ComponentGroup Merkle tree, starting from its serialized components. * Default implementation provides double hashing, but can it be changed to single hashing or something else for better performance. */ - fun preImageResistantDigest(bytes: ByteArray): ByteArray = digest(digest(bytes)) + fun componentDigest(bytes: ByteArray): ByteArray = digest(digest(bytes)) + + /** + * Computes the digest of the [ByteArray] which is resistant to pre-image attacks. Only used to calculate the nonces for the leaves of + * the ComponentGroup Merkle tree. + * Default implementation provides double hashing, but can it be changed to single hashing or something else for better performance. + */ + fun nonceDigest(bytes: ByteArray): ByteArray = digest(digest(bytes)) } diff --git a/core/src/main/kotlin/net/corda/core/crypto/DigestService.kt b/core/src/main/kotlin/net/corda/core/crypto/DigestService.kt index 2f88e9d959..e163993443 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/DigestService.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/DigestService.kt @@ -75,23 +75,20 @@ data class DigestService(val hashAlgorithm: String) { val zeroHash: SecureHash get() = SecureHash.zeroHashFor(hashAlgorithm) -// val privacySalt: PrivacySalt -// get() = PrivacySalt.createFor(hashAlgorithm) - /** * Compute the hash of each serialised component so as to be used as Merkle tree leaf. The resultant output (leaf) is * calculated using the service's hash algorithm, thus HASH(HASH(nonce || serializedComponent)) for SHA2-256 and other - * algorithms loaded via JCA [MessageDigest], or DigestAlgorithm.preImageResistantDigest(nonce || serializedComponent) + * algorithms loaded via JCA [MessageDigest], or DigestAlgorithm.componentDigest(nonce || serializedComponent) * otherwise, where nonce is computed from [computeNonce]. */ fun componentHash(opaqueBytes: OpaqueBytes, privacySalt: PrivacySalt, componentGroupIndex: Int, internalIndex: Int): SecureHash = componentHash(computeNonce(privacySalt, componentGroupIndex, internalIndex), opaqueBytes) /** Return the HASH(HASH(nonce || serializedComponent)) for SHA2-256 and other algorithms loaded via JCA [MessageDigest], - * otherwise it's defined by DigestAlgorithm.preImageResistantDigest(nonce || serializedComponent). */ + * otherwise it's defined by DigestAlgorithm.componentDigest(nonce || serializedComponent). */ fun componentHash(nonce: SecureHash, opaqueBytes: OpaqueBytes): SecureHash { val data = nonce.bytes + opaqueBytes.bytes - return SecureHash.preImageResistantHashAs(hashAlgorithm, data) + return SecureHash.componentHashAs(hashAlgorithm, data) } /** @@ -109,11 +106,11 @@ data class DigestService(val hashAlgorithm: String) { * @param groupIndex the fixed index (ordinal) of this component group. * @param internalIndex the internal index of this object in its corresponding components list. * @return HASH(HASH(privacySalt || groupIndex || internalIndex)) for SHA2-256 and other algorithms loaded via JCA [MessageDigest], - * otherwise it's defined by DigestAlgorithm.preImageResistantDigest(privacySalt || groupIndex || internalIndex). + * otherwise it's defined by DigestAlgorithm.nonceDigest(privacySalt || groupIndex || internalIndex). */ fun computeNonce(privacySalt: PrivacySalt, groupIndex: Int, internalIndex: Int) : SecureHash { val data = (privacySalt.bytes + ByteBuffer.allocate(NONCE_SIZE).putInt(groupIndex).putInt(internalIndex).array()) - return SecureHash.preImageResistantHashAs(hashAlgorithm, data) + return SecureHash.nonceHashAs(hashAlgorithm, data) } } diff --git a/core/src/main/kotlin/net/corda/core/crypto/MerkleTree.kt b/core/src/main/kotlin/net/corda/core/crypto/MerkleTree.kt index f3dda598b4..692d4e1345 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/MerkleTree.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/MerkleTree.kt @@ -37,19 +37,19 @@ sealed class MerkleTree { require(algorithms.size == 1) { "Cannot build Merkle tree with multiple hash algorithms: $algorithms" } - val leaves = padWithZeros(allLeavesHashes).map { Leaf(it) } + val leaves = padWithZeros(allLeavesHashes, nodeDigestService.hashAlgorithm == SecureHash.SHA2_256).map { Leaf(it) } return buildMerkleTree(leaves, nodeDigestService) } // If number of leaves in the tree is not a power of 2, we need to pad it with zero hashes. - private fun padWithZeros(allLeavesHashes: List): List { + private fun padWithZeros(allLeavesHashes: List, singleLeafWithoutPadding: Boolean): List { var n = allLeavesHashes.size - if (isPow2(n)) return allLeavesHashes + if (isPow2(n) && (n > 1 || singleLeafWithoutPadding)) return allLeavesHashes val paddedHashes = ArrayList(allLeavesHashes) val zeroHash = SecureHash.zeroHashFor(paddedHashes[0].algorithm) - while (!isPow2(n++)) { + do { paddedHashes.add(zeroHash) - } + } while (!isPow2(++n)) return paddedHashes } diff --git a/core/src/main/kotlin/net/corda/core/crypto/SecureHash.kt b/core/src/main/kotlin/net/corda/core/crypto/SecureHash.kt index abee5a22f5..683d2610cc 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/SecureHash.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/SecureHash.kt @@ -216,13 +216,31 @@ sealed class SecureHash(bytes: ByteArray) : OpaqueBytes(bytes) { * @param bytes The [ByteArray] to hash. */ @JvmStatic - fun preImageResistantHashAs(algorithm: String, bytes: ByteArray): SecureHash { + fun componentHashAs(algorithm: String, bytes: ByteArray): SecureHash { return if (algorithm == SHA2_256) { sha256Twice(bytes) } else { val digest = digestFor(algorithm).get() - val firstHash = digest.preImageResistantDigest(bytes) - HASH(algorithm, digest.digest(firstHash)) + val hash = digest.componentDigest(bytes) + HASH(algorithm, hash) + } + } + + /** + * Computes the digest of the [ByteArray] which is resistant to pre-image attacks. + * It computes the hash of the hash for SHA2-256 and other algorithms loaded via JCA [MessageDigest]. + * For custom algorithms the strategy can be modified via [DigestAlgorithm]. + * @param algorithm The [MessageDigest] algorithm to use. + * @param bytes The [ByteArray] to hash. + */ + @JvmStatic + fun nonceHashAs(algorithm: String, bytes: ByteArray): SecureHash { + return if (algorithm == SHA2_256) { + sha256Twice(bytes) + } else { + val digest = digestFor(algorithm).get() + val hash = digest.nonceDigest(bytes) + HASH(algorithm, hash) } } diff --git a/core/src/main/kotlin/net/corda/core/crypto/internal/DigestAlgorithmFactory.kt b/core/src/main/kotlin/net/corda/core/crypto/internal/DigestAlgorithmFactory.kt index 532b95f4c1..892506aa76 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/internal/DigestAlgorithmFactory.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/internal/DigestAlgorithmFactory.kt @@ -28,9 +28,7 @@ sealed class DigestAlgorithmFactory { } private class CustomAlgorithmFactory(className: String) : DigestAlgorithmFactory() { - val constructor: Constructor = javaClass - .classLoader - .loadClass(className) + val constructor: Constructor = Class.forName(className, false, javaClass.classLoader) .asSubclass(DigestAlgorithm::class.java) .getConstructor() override val algorithm: String = constructor.newInstance().algorithm diff --git a/core/src/main/kotlin/net/corda/core/internal/ClassGraphUtils.kt b/core/src/main/kotlin/net/corda/core/internal/ClassGraphUtils.kt index 2d75ff1c82..fb420e373a 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ClassGraphUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ClassGraphUtils.kt @@ -11,9 +11,9 @@ import kotlin.concurrent.withLock private val pooledScanMutex = ReentrantLock() /** - * Use this rather than the built in implementation of [scan] on [ClassGraph]. The built in implementation of [scan] creates - * a thread pool every time resulting in too many threads. This one uses a mutex to restrict concurrency. + * Use this rather than the built-in implementation of [ClassGraph.scan]. The built-in implementation creates + * a thread pool every time, resulting in too many threads. This one uses a mutex to restrict concurrency. */ fun ClassGraph.pooledScan(): ScanResult { - return pooledScanMutex.withLock { this@pooledScan.scan() } + return pooledScanMutex.withLock(::scan) } diff --git a/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt b/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt index 32ae2608d8..5ead87ca59 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt @@ -23,7 +23,7 @@ import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory.a fun createInstancesOfClassesImplementing(classloader: ClassLoader, clazz: Class, classVersionRange: IntRange? = null): Set { return getNamesOfClassesImplementing(classloader, clazz, classVersionRange) - .map { classloader.loadClass(it).asSubclass(clazz) } + .map { Class.forName(it, false, classloader).asSubclass(clazz) } .mapTo(LinkedHashSet()) { it.kotlin.objectOrNewInstance() } } diff --git a/core/src/main/kotlin/net/corda/core/internal/ConstraintsUtils.kt b/core/src/main/kotlin/net/corda/core/internal/ConstraintsUtils.kt index c05ae94680..2cdb80fad1 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ConstraintsUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ConstraintsUtils.kt @@ -16,6 +16,7 @@ typealias Version = Int * Attention: this value affects consensus, so it requires a minimum platform version bump in order to be changed. */ const val MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT = 20 +private const val DJVM_SANDBOX_PREFIX = "sandbox." private val log = loggerFor() @@ -29,10 +30,14 @@ val Attachment.contractVersion: Version get() = if (this is ContractAttachment) val ContractState.requiredContractClassName: String? get() { val annotation = javaClass.getAnnotation(BelongsToContract::class.java) if (annotation != null) { - return annotation.value.java.typeName + return annotation.value.java.typeName.removePrefix(DJVM_SANDBOX_PREFIX) } val enclosingClass = javaClass.enclosingClass ?: return null - return if (Contract::class.java.isAssignableFrom(enclosingClass)) enclosingClass.typeName else null + return if (Contract::class.java.isAssignableFrom(enclosingClass)) { + enclosingClass.typeName.removePrefix(DJVM_SANDBOX_PREFIX) + } else { + null + } } /** diff --git a/core/src/main/kotlin/net/corda/core/internal/CordaUtils.kt b/core/src/main/kotlin/net/corda/core/internal/CordaUtils.kt index cda6b328e2..e0f41d2fdd 100644 --- a/core/src/main/kotlin/net/corda/core/internal/CordaUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/CordaUtils.kt @@ -28,7 +28,7 @@ import java.util.jar.JarInputStream // *Internal* Corda-specific utilities. -const val PLATFORM_VERSION = 9 +const val PLATFORM_VERSION = 10 fun ServicesForResolution.ensureMinimumPlatformVersion(requiredMinPlatformVersion: Int, feature: String) { checkMinimumPlatformVersion(networkParameters.minimumPlatformVersion, requiredMinPlatformVersion, feature) diff --git a/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt index 51a37bbc7a..cb7844b734 100644 --- a/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt @@ -56,7 +56,9 @@ import java.security.cert.TrustAnchor import java.security.cert.X509Certificate import java.time.Duration import java.time.temporal.Temporal -import java.util.* +import java.util.Collections +import java.util.PrimitiveIterator +import java.util.Spliterator import java.util.Spliterator.DISTINCT import java.util.Spliterator.IMMUTABLE import java.util.Spliterator.NONNULL @@ -64,6 +66,7 @@ import java.util.Spliterator.ORDERED import java.util.Spliterator.SIZED import java.util.Spliterator.SORTED import java.util.Spliterator.SUBSIZED +import java.util.Spliterators import java.util.concurrent.ExecutorService import java.util.concurrent.TimeUnit import java.util.stream.Collectors diff --git a/core/src/main/kotlin/net/corda/core/internal/PlatformVersionSwitches.kt b/core/src/main/kotlin/net/corda/core/internal/PlatformVersionSwitches.kt index 3628c05526..c6d93f272f 100644 --- a/core/src/main/kotlin/net/corda/core/internal/PlatformVersionSwitches.kt +++ b/core/src/main/kotlin/net/corda/core/internal/PlatformVersionSwitches.kt @@ -16,6 +16,6 @@ object PlatformVersionSwitches { const val LIMIT_KEYS_IN_SIGNATURE_CONSTRAINTS = 5 const val BATCH_DOWNLOAD_COUNTERPARTY_BACKCHAIN = 6 const val ENABLE_P2P_COMPRESSION = 7 - const val CERTIFICATE_ROTATION = 9 const val RESTRICTED_DATABASE_OPERATIONS = 7 + const val CERTIFICATE_ROTATION = 9 } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt b/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt index 976ef575f6..7bdfec76be 100644 --- a/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt @@ -54,7 +54,7 @@ fun combinedHash(components: Iterable, digestService: DigestService) components.forEach { stream.write(it.bytes) } - return digestService.hash(stream.toByteArray()); + return digestService.hash(stream.toByteArray()) } /** @@ -114,14 +114,14 @@ fun deserialiseCommands( componentGroups: List, forceDeserialize: Boolean = false, factory: SerializationFactory = SerializationFactory.defaultFactory, - @Suppress("UNUSED_PARAMETER") context: SerializationContext = factory.defaultContext, + context: SerializationContext = factory.defaultContext, digestService: DigestService = DigestService.sha2_256 ): List> { // TODO: we could avoid deserialising unrelated signers. // However, current approach ensures the transaction is not malformed // and it will throw if any of the signers objects is not List of public keys). - val signersList: List> = uncheckedCast(deserialiseComponentGroup(componentGroups, List::class, ComponentGroupEnum.SIGNERS_GROUP, forceDeserialize)) - val commandDataList: List = deserialiseComponentGroup(componentGroups, CommandData::class, ComponentGroupEnum.COMMANDS_GROUP, forceDeserialize) + val signersList: List> = uncheckedCast(deserialiseComponentGroup(componentGroups, List::class, ComponentGroupEnum.SIGNERS_GROUP, forceDeserialize, factory, context)) + val commandDataList: List = deserialiseComponentGroup(componentGroups, CommandData::class, ComponentGroupEnum.COMMANDS_GROUP, forceDeserialize, factory, context) val group = componentGroups.firstOrNull { it.groupIndex == ComponentGroupEnum.COMMANDS_GROUP.ordinal } return if (group is FilteredComponentGroup) { check(commandDataList.size <= signersList.size) { @@ -154,7 +154,9 @@ fun createComponentGroups(inputs: List, timeWindow: TimeWindow?, references: List, networkParametersHash: SecureHash?): List { - val serialize = { value: Any, _: Int -> value.serialize() } + val serializationFactory = SerializationFactory.defaultFactory + val serializationContext = serializationFactory.defaultContext + val serialize = { value: Any, _: Int -> value.serialize(serializationFactory, serializationContext) } val componentGroupMap: MutableList = mutableListOf() if (inputs.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.INPUTS_GROUP.ordinal, inputs.lazyMapped(serialize))) if (references.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.REFERENCES_GROUP.ordinal, references.lazyMapped(serialize))) @@ -177,7 +179,11 @@ fun createComponentGroups(inputs: List, */ @KeepForDJVM data class SerializedStateAndRef(val serializedState: SerializedBytes>, val ref: StateRef) { - fun toStateAndRef(): StateAndRef = StateAndRef(serializedState.deserialize(), ref) + fun toStateAndRef(factory: SerializationFactory, context: SerializationContext) = StateAndRef(serializedState.deserialize(factory, context), ref) + fun toStateAndRef(): StateAndRef { + val factory = SerializationFactory.defaultFactory + return toStateAndRef(factory, factory.defaultContext) + } } /** Check that network parameters hash on this transaction is the current hash for the network. */ diff --git a/core/src/main/kotlin/net/corda/core/internal/TransactionVerifierServiceInternal.kt b/core/src/main/kotlin/net/corda/core/internal/TransactionVerifierServiceInternal.kt index e7ca576618..0171d71e91 100644 --- a/core/src/main/kotlin/net/corda/core/internal/TransactionVerifierServiceInternal.kt +++ b/core/src/main/kotlin/net/corda/core/internal/TransactionVerifierServiceInternal.kt @@ -3,14 +3,40 @@ package net.corda.core.internal import net.corda.core.DeleteForDJVM import net.corda.core.KeepForDJVM import net.corda.core.concurrent.CordaFuture -import net.corda.core.contracts.* +import net.corda.core.contracts.Attachment +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractAttachment +import net.corda.core.contracts.ContractClassName +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.HashAttachmentConstraint +import net.corda.core.contracts.SignatureAttachmentConstraint +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.TransactionVerificationException +import net.corda.core.contracts.TransactionVerificationException.ConflictingAttachmentsRejection +import net.corda.core.contracts.TransactionVerificationException.ConstraintPropagationRejection +import net.corda.core.contracts.TransactionVerificationException.ContractConstraintRejection +import net.corda.core.contracts.TransactionVerificationException.ContractCreationError +import net.corda.core.contracts.TransactionVerificationException.ContractRejection +import net.corda.core.contracts.TransactionVerificationException.Direction +import net.corda.core.contracts.TransactionVerificationException.DuplicateAttachmentsRejection +import net.corda.core.contracts.TransactionVerificationException.InvalidConstraintRejection +import net.corda.core.contracts.TransactionVerificationException.MissingAttachmentRejection +import net.corda.core.contracts.TransactionVerificationException.NotaryChangeInWrongTransactionType import net.corda.core.contracts.TransactionVerificationException.TransactionContractConflictException +import net.corda.core.contracts.TransactionVerificationException.TransactionDuplicateEncumbranceException +import net.corda.core.contracts.TransactionVerificationException.TransactionMissingEncumbranceException +import net.corda.core.contracts.TransactionVerificationException.TransactionNonMatchingEncumbranceException +import net.corda.core.contracts.TransactionVerificationException.TransactionNotaryMismatchEncumbranceException +import net.corda.core.contracts.TransactionVerificationException.TransactionRequiredContractUnspecifiedException import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.SecureHash import net.corda.core.internal.rules.StateContractValidationEnforcementRule import net.corda.core.transactions.LedgerTransaction -import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.loggerFor import java.util.function.Function +import java.util.function.Supplier @DeleteForDJVM interface TransactionVerifierServiceInternal { @@ -22,16 +48,54 @@ interface TransactionVerifierServiceInternal { */ fun LedgerTransaction.prepareVerify(attachments: List) = internalPrepareVerify(attachments) +interface Verifier { + + /** + * Placeholder function for the verification logic. + */ + fun verify() +} + +// This class allows us unit-test transaction verification more easily. +abstract class AbstractVerifier( + protected val ltx: LedgerTransaction, + protected val transactionClassLoader: ClassLoader +) : Verifier { + protected abstract val transaction: Supplier + + protected companion object { + @JvmField + val logger = loggerFor() + } + + /** + * Check that the transaction is internally consistent, and then check that it is + * contract-valid by running verify() for each input and output state contract. + * If any contract fails to verify, the whole transaction is considered to be invalid. + * + * Note: Reference states are not verified. + */ + final override fun verify() { + try { + TransactionVerifier(transactionClassLoader).apply(transaction) + } catch (e: TransactionVerificationException) { + logger.error("Error validating transaction ${ltx.id}.", e.cause) + throw e + } + } +} + /** * Because we create a separate [LedgerTransaction] onto which we need to perform verification, it becomes important we don't verify the * wrong object instance. This class helps avoid that. */ -abstract class Verifier(val ltx: LedgerTransaction, protected val transactionClassLoader: ClassLoader) { - private val inputStates: List> = ltx.inputs.map { it.state } - private val allStates: List> = inputStates + ltx.references.map { it.state } + ltx.outputs +@KeepForDJVM +private class Validator(private val ltx: LedgerTransaction, private val transactionClassLoader: ClassLoader) { + private val inputStates: List> = ltx.inputs.map(StateAndRef::state) + private val allStates: List> = inputStates + ltx.references.map(StateAndRef::state) + ltx.outputs - companion object { - val logger = contextLogger() + private companion object { + private val logger = loggerFor() } /** @@ -39,9 +103,9 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla * * It is a critical piece of the security of the platform. * - * @throws TransactionVerificationException + * @throws net.corda.core.contracts.TransactionVerificationException */ - fun verify() { + fun validate() { // checkNoNotaryChange and checkEncumbrancesValid are called here, and not in the c'tor, as they need access to the "outputs" // list, the contents of which need to be deserialized under the correct classloader. checkNoNotaryChange() @@ -68,8 +132,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla // 4. Check that the [TransactionState] objects are correctly formed. validateStatesAgainstContract() - // 5. Final step is to run the contract code. After the first 4 steps we are now sure that we are running the correct code. - verifyContracts() + // 5. Final step will be to run the contract code. } private fun checkTransactionWithTimeWindowIsNotarised() { @@ -81,11 +144,12 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla * It makes sure there is one and only one. * This is an important piece of the security of transactions. */ + @Suppress("ThrowsCount") private fun getUniqueContractAttachmentsByContract(): Map { - val contractClasses = allStates.map { it.contract }.toSet() + val contractClasses = allStates.mapTo(LinkedHashSet(), TransactionState<*>::contract) // Check that there are no duplicate attachments added. - if (ltx.attachments.size != ltx.attachments.toSet().size) throw TransactionVerificationException.DuplicateAttachmentsRejection(ltx.id, ltx.attachments.groupBy { it }.filterValues { it.size > 1 }.keys.first()) + if (ltx.attachments.size != ltx.attachments.toSet().size) throw DuplicateAttachmentsRejection(ltx.id, ltx.attachments.groupBy { it }.filterValues { it.size > 1 }.keys.first()) // For each attachment this finds all the relevant state contracts that it provides. // And then maps them to the attachment. @@ -103,12 +167,12 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla .groupBy { it.first } // Group by contract. .filter { (_, attachments) -> attachments.size > 1 } // And only keep contracts that are in multiple attachments. It's guaranteed that attachments were unique by a previous check. .keys.firstOrNull() // keep the first one - if any - to throw a meaningful exception. - if (contractWithMultipleAttachments != null) throw TransactionVerificationException.ConflictingAttachmentsRejection(ltx.id, contractWithMultipleAttachments) + if (contractWithMultipleAttachments != null) throw ConflictingAttachmentsRejection(ltx.id, contractWithMultipleAttachments) val result = contractAttachmentsPerContract.toMap() // Check that there is an attachment for each contract. - if (result.keys != contractClasses) throw TransactionVerificationException.MissingAttachmentRejection(ltx.id, contractClasses.minus(result.keys).first()) + if (result.keys != contractClasses) throw MissingAttachmentRejection(ltx.id, contractClasses.minus(result.keys).first()) return result } @@ -124,7 +188,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla if (ltx.notary != null && (ltx.inputs.isNotEmpty() || ltx.references.isNotEmpty())) { ltx.outputs.forEach { if (it.notary != ltx.notary) { - throw TransactionVerificationException.NotaryChangeInWrongTransactionType(ltx.id, ltx.notary, it.notary) + throw NotaryChangeInWrongTransactionType(ltx.id, ltx.notary, it.notary) } } } @@ -156,10 +220,10 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla it.ref.txhash == ref.txhash && it.ref.index == state.encumbrance } if (!encumbranceStateExists) { - throw TransactionVerificationException.TransactionMissingEncumbranceException( + throw TransactionMissingEncumbranceException( ltx.id, state.encumbrance!!, - TransactionVerificationException.Direction.INPUT + Direction.INPUT ) } } @@ -185,6 +249,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla // b -> c and c -> b // c -> a b -> a // and form a full cycle, meaning that the bi-directionality property is satisfied. + @Suppress("ThrowsCount") private fun checkBidirectionalOutputEncumbrances(statesAndEncumbrance: List>) { // [Set] of "from" (encumbered states). val encumberedSet = mutableSetOf() @@ -194,15 +259,15 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla statesAndEncumbrance.forEach { (statePosition, encumbrance) -> // Check it does not refer to itself. if (statePosition == encumbrance || encumbrance >= ltx.outputs.size) { - throw TransactionVerificationException.TransactionMissingEncumbranceException( + throw TransactionMissingEncumbranceException( ltx.id, encumbrance, - TransactionVerificationException.Direction.OUTPUT + Direction.OUTPUT ) } else { encumberedSet.add(statePosition) // Guaranteed to have unique elements. if (!encumbranceSet.add(encumbrance)) { - throw TransactionVerificationException.TransactionDuplicateEncumbranceException(ltx.id, encumbrance) + throw TransactionDuplicateEncumbranceException(ltx.id, encumbrance) } } } @@ -211,7 +276,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla val symmetricDifference = (encumberedSet union encumbranceSet).subtract(encumberedSet intersect encumbranceSet) if (symmetricDifference.isNotEmpty()) { // At least one encumbered state is not in the [encumbranceSet] and vice versa. - throw TransactionVerificationException.TransactionNonMatchingEncumbranceException(ltx.id, symmetricDifference) + throw TransactionNonMatchingEncumbranceException(ltx.id, symmetricDifference) } } @@ -235,7 +300,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla if (indicesAlreadyChecked.add(index)) { val encumbranceIndex = ltx.outputs[index].encumbrance!! if (ltx.outputs[index].notary != ltx.outputs[encumbranceIndex].notary) { - throw TransactionVerificationException.TransactionNotaryMismatchEncumbranceException( + throw TransactionNotaryMismatchEncumbranceException( ltx.id, index, encumbranceIndex, @@ -263,7 +328,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla val shouldEnforce = StateContractValidationEnforcementRule.shouldEnforce(state.data) val requiredContractClassName = state.data.requiredContractClassName - ?: if (shouldEnforce) throw TransactionVerificationException.TransactionRequiredContractUnspecifiedException(ltx.id, state) else return + ?: if (shouldEnforce) throw TransactionRequiredContractUnspecifiedException(ltx.id, state) else return if (state.contract != requiredContractClassName) if (shouldEnforce) { @@ -281,6 +346,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla * - Constraints should be one of the valid supported ones. * - Constraints should propagate correctly if not marked otherwise (in that case it is the responsibility of the contract to ensure that the output states are created properly). */ + @Suppress("NestedBlockDepth") private fun verifyConstraintsValidity(contractAttachmentsByContract: Map) { // First check that the constraints are valid. @@ -310,7 +376,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla outputConstraints.forEach { outputConstraint -> inputConstraints.forEach { inputConstraint -> if (!(outputConstraint.canBeTransitionedFrom(inputConstraint, contractAttachment))) { - throw TransactionVerificationException.ConstraintPropagationRejection( + throw ConstraintPropagationRejection( ltx.id, contractClassName, inputConstraint, @@ -331,7 +397,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla @Suppress("NestedBlockDepth", "MagicNumber") private fun verifyConstraints(contractAttachmentsByContract: Map) { // For each contract/constraint pair check that the relevant attachment is valid. - allStates.map { it.contract to it.constraint }.toSet().forEach { (contract, constraint) -> + allStates.mapTo(LinkedHashSet()) { it.contract to it.constraint }.forEach { (contract, constraint) -> if (constraint is SignatureAttachmentConstraint) { /** * Support for signature constraints has been added on @@ -346,9 +412,9 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla "Signature constraints" ) val constraintKey = constraint.key - if (ltx.networkParameters?.minimumPlatformVersion ?: 1 >= PlatformVersionSwitches.LIMIT_KEYS_IN_SIGNATURE_CONSTRAINTS) { + if ((ltx.networkParameters?.minimumPlatformVersion ?: 1) >= PlatformVersionSwitches.LIMIT_KEYS_IN_SIGNATURE_CONSTRAINTS) { if (constraintKey is CompositeKey && constraintKey.leafKeys.size > MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT) { - throw TransactionVerificationException.InvalidConstraintRejection(ltx.id, contract, + throw InvalidConstraintRejection(ltx.id, contract, "Signature constraint contains composite key with ${constraintKey.leafKeys.size} leaf keys, " + "which is more than the maximum allowed number of keys " + "($MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT).") @@ -364,39 +430,18 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla if (HashAttachmentConstraint.disableHashConstraints && constraint is HashAttachmentConstraint) logger.warnOnce("Skipping hash constraints verification.") else if (!constraint.isSatisfiedBy(constraintAttachment)) - throw TransactionVerificationException.ContractConstraintRejection(ltx.id, contract) - } - } - - /** - * Placeholder function for the contract verification logic. - */ - abstract fun verifyContracts() -} - -class BasicVerifier(ltx: LedgerTransaction, transactionClassLoader: ClassLoader) : Verifier(ltx, transactionClassLoader) { - /** - * Check the transaction is contract-valid by running the verify() for each input and output state contract. - * If any contract fails to verify, the whole transaction is considered to be invalid. - * - * Note: Reference states are not verified. - */ - override fun verifyContracts() { - try { - ContractVerifier(transactionClassLoader).apply(ltx) - } catch (e: TransactionVerificationException.ContractRejection) { - logger.error("Error validating transaction ${ltx.id}.", e.cause) - throw e + throw ContractConstraintRejection(ltx.id, contract) } } } /** - * Verify all of the contracts on the given [LedgerTransaction]. + * Verify the given [LedgerTransaction]. This includes validating + * its contents, as well as executing all of its smart contracts. */ @Suppress("TooGenericExceptionCaught") @KeepForDJVM -class ContractVerifier(private val transactionClassLoader: ClassLoader) : Function { +class TransactionVerifier(private val transactionClassLoader: ClassLoader) : Function, Unit> { // This constructor is used inside the DJVM's sandbox. @Suppress("unused") constructor() : this(ClassLoader.getSystemClassLoader()) @@ -406,34 +451,62 @@ class ContractVerifier(private val transactionClassLoader: ClassLoader) : Functi return try { Class.forName(contractClassName, false, transactionClassLoader).asSubclass(Contract::class.java) } catch (e: Exception) { - throw TransactionVerificationException.ContractCreationError(id, contractClassName, e) + throw ContractCreationError(id, contractClassName, e) } } - override fun apply(ltx: LedgerTransaction) { - val contractClassNames = (ltx.inputs.map(StateAndRef::state) + ltx.outputs) + private fun generateContracts(ltx: LedgerTransaction): List { + return (ltx.inputs.map(StateAndRef::state) + ltx.outputs) .mapTo(LinkedHashSet(), TransactionState<*>::contract) - - contractClassNames.associateBy( - { it }, { createContractClass(ltx.id, it) } - ).map { (contractClassName, contractClass) -> - try { - /** - * This function must execute with the DJVM's sandbox, which does not - * permit user code to invoke [java.lang.Class.getDeclaredConstructor]. - * - * [Class.newInstance] is deprecated as of Java 9. - */ - @Suppress("deprecation") - contractClass.newInstance() - } catch (e: Exception) { - throw TransactionVerificationException.ContractCreationError(ltx.id, contractClassName, e) + .map { contractClassName -> + createContractClass(ltx.id, contractClassName) + }.map { contractClass -> + try { + /** + * This function must execute within the DJVM's sandbox, which does not + * permit user code to invoke [java.lang.reflect.Constructor.newInstance]. + * (This would be fixable now, provided the constructor is public.) + * + * [Class.newInstance] is deprecated as of Java 9. + */ + @Suppress("deprecation") + contractClass.newInstance() + } catch (e: Exception) { + throw ContractCreationError(ltx.id, contractClass.name, e) + } } + } + + private fun validateTransaction(ltx: LedgerTransaction) { + Validator(ltx, transactionClassLoader).validate() + } + + override fun apply(transactionFactory: Supplier) { + var firstLtx: LedgerTransaction? = null + + transactionFactory.get().let { ltx -> + firstLtx = ltx + + /** + * Check that this transaction is correctly formed. + * We only need to run these checks once. + */ + validateTransaction(ltx) + + /** + * Generate the list of unique contracts + * within this transaction. + */ + generateContracts(ltx) }.forEach { contract -> + val ltx = firstLtx ?: transactionFactory.get() + firstLtx = null try { + // Final step is to run the contract code. Having validated the + // transaction, we are now sure that we are running the correct code. contract.verify(ltx) } catch (e: Exception) { - throw TransactionVerificationException.ContractRejection(ltx.id, contract, e) + throw ContractRejection(ltx.id, contract, e) } } } diff --git a/core/src/main/kotlin/net/corda/core/internal/notary/NotaryService.kt b/core/src/main/kotlin/net/corda/core/internal/notary/NotaryService.kt index 216022523a..3f1842e25b 100644 --- a/core/src/main/kotlin/net/corda/core/internal/notary/NotaryService.kt +++ b/core/src/main/kotlin/net/corda/core/internal/notary/NotaryService.kt @@ -14,6 +14,14 @@ abstract class NotaryService : SingletonSerializeAsToken() { abstract val services: ServiceHub abstract val notaryIdentityKey: PublicKey + /** + * Mapping between @InitiatingFlow classes and factory methods that produce responder flows. + * Can be overridden in case of advanced notary service that serves both custom and standard flows. + */ + open val initiatingFlows = mapOf( + NotaryFlow.Client::class to ::createServiceFlow + ) + /** * Interfaces for the request and result formats of queries supported by notary services. To * implement a new query, you must: diff --git a/core/src/main/kotlin/net/corda/core/internal/notary/NotaryServiceFlow.kt b/core/src/main/kotlin/net/corda/core/internal/notary/NotaryServiceFlow.kt index 249c637ce5..3aa7d9cfe7 100644 --- a/core/src/main/kotlin/net/corda/core/internal/notary/NotaryServiceFlow.kt +++ b/core/src/main/kotlin/net/corda/core/internal/notary/NotaryServiceFlow.kt @@ -28,7 +28,11 @@ import java.time.Duration * @param etaThreshold If the ETA for processing the request, according to the service, is greater than this, notify the client. */ // See AbstractStateReplacementFlow.Acceptor for why it's Void? -abstract class NotaryServiceFlow(val otherSideSession: FlowSession, val service: SinglePartyNotaryService, private val etaThreshold: Duration) : FlowLogic() { +abstract class NotaryServiceFlow( + val otherSideSession: FlowSession, + val service: SinglePartyNotaryService, + private val etaThreshold: Duration +) : FlowLogic() { companion object { // TODO: Determine an appropriate limit and also enforce in the network parameters and the transaction builder. private const val maxAllowedInputsAndReferences = 10_000 diff --git a/core/src/main/kotlin/net/corda/core/internal/utilities/ZipBombDetector.kt b/core/src/main/kotlin/net/corda/core/internal/utilities/ZipBombDetector.kt new file mode 100644 index 0000000000..ef98fa64b0 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/utilities/ZipBombDetector.kt @@ -0,0 +1,63 @@ +package net.corda.core.internal.utilities + +import java.io.FilterInputStream +import java.io.InputStream +import java.util.zip.ZipInputStream + +object ZipBombDetector { + + private class CounterInputStream(source : InputStream) : FilterInputStream(source) { + private var byteCount : Long = 0 + + val count : Long + get() = byteCount + + override fun read(): Int { + return super.read().also { byte -> + if(byte >= 0) byteCount += 1 + } + } + + override fun read(b: ByteArray): Int { + return super.read(b).also { bytesRead -> + if(bytesRead > 0) byteCount += bytesRead + } + } + + override fun read(b: ByteArray, off: Int, len: Int): Int { + return super.read(b, off, len).also { bytesRead -> + if(bytesRead > 0) byteCount += bytesRead + } + } + } + + /** + * Check if a zip file is a potential malicious zip bomb + * @param source the zip archive file content + * @param maxUncompressedSize the maximum allowable uncompressed archive size + * @param maxCompressionRatio the maximum allowable compression ratio + * @return true if the zip file total uncompressed size exceeds [maxUncompressedSize] and the + * average entry compression ratio is larger than [maxCompressionRatio], false otherwise + */ + @Suppress("NestedBlockDepth") + fun scanZip(source : InputStream, maxUncompressedSize : Long, maxCompressionRatio : Float = 10.0f) : Boolean { + val counterInputStream = CounterInputStream(source) + var uncompressedByteCount : Long = 0 + val buffer = ByteArray(DEFAULT_BUFFER_SIZE) + ZipInputStream(counterInputStream).use { zipInputStream -> + while(true) { + zipInputStream.nextEntry ?: break + while(true) { + val read = zipInputStream.read(buffer) + if(read <= 0) break + uncompressedByteCount += read + if(uncompressedByteCount > maxUncompressedSize && + uncompressedByteCount.toFloat() / counterInputStream.count.toFloat() > maxCompressionRatio) { + return true + } + } + } + } + return false + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/node/NetworkParameters.kt b/core/src/main/kotlin/net/corda/core/node/NetworkParameters.kt index 4c85a4dc07..50fe5eb258 100644 --- a/core/src/main/kotlin/net/corda/core/node/NetworkParameters.kt +++ b/core/src/main/kotlin/net/corda/core/node/NetworkParameters.kt @@ -13,6 +13,8 @@ import net.corda.core.utilities.days import java.security.PublicKey import java.time.Duration import java.time.Instant +import java.util.Collections.unmodifiableList +import java.util.Collections.unmodifiableMap // DOCSTART 1 /** @@ -166,6 +168,38 @@ data class NetworkParameters( epoch=$epoch }""" } + + fun toImmutable(): NetworkParameters { + return NetworkParameters( + minimumPlatformVersion = minimumPlatformVersion, + notaries = unmodifiable(notaries), + maxMessageSize = maxMessageSize, + maxTransactionSize = maxTransactionSize, + modifiedTime = modifiedTime, + epoch = epoch, + whitelistedContractImplementations = unmodifiable(whitelistedContractImplementations) { entry -> + unmodifiableList(entry.value) + }, + eventHorizon = eventHorizon, + packageOwnership = unmodifiable(packageOwnership) + ) + } +} + +private fun unmodifiable(list: List): List { + return if (list.isEmpty()) { + emptyList() + } else { + unmodifiableList(list) + } +} + +private inline fun unmodifiable(map: Map, transform: (Map.Entry) -> V = Map.Entry::value): Map { + return if (map.isEmpty()) { + emptyMap() + } else { + unmodifiableMap(map.mapValues(transform)) + } } /** diff --git a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt index d63b63edf4..612e341a6f 100644 --- a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt +++ b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt @@ -64,7 +64,7 @@ interface ServicesForResolution { /** * Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState]. * - * @throws TransactionResolutionException if [stateRef] points to a non-existent transaction. + * @throws TransactionResolutionException if any of the [stateRefs] point to a non-existent transaction. */ // TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load // as the existing transaction store will become encrypted at some point diff --git a/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt b/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt index 3b6db58dcd..51d61cc214 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt @@ -1,3 +1,5 @@ +@file:Suppress("LongParameterList") + package net.corda.core.node.services import co.paralleluniverse.fibers.Suspendable @@ -197,8 +199,7 @@ class Vault(val states: Iterable>) { * 4) Status types used in this query: [StateStatus.UNCONSUMED], [StateStatus.CONSUMED], [StateStatus.ALL]. * 5) Other results as a [List] of any type (eg. aggregate function results with/without group by). * - * Note: currently otherResults are used only for Aggregate Functions (in which case, the states and statesMetadata - * results will be empty). + * Note: currently [otherResults] is used only for aggregate functions (in which case, [states] and [statesMetadata] will be empty). */ @CordaSerializable data class Page(val states: List>, @@ -213,11 +214,11 @@ class Vault(val states: Iterable>) { val contractStateClassName: String, val recordedTime: Instant, val consumedTime: Instant?, - val status: Vault.StateStatus, + val status: StateStatus, val notary: AbstractParty?, val lockId: String?, val lockUpdateTime: Instant?, - val relevancyStatus: Vault.RelevancyStatus? = null, + val relevancyStatus: RelevancyStatus? = null, val constraintInfo: ConstraintInfo? = null ) { fun copy( @@ -225,7 +226,7 @@ class Vault(val states: Iterable>) { contractStateClassName: String = this.contractStateClassName, recordedTime: Instant = this.recordedTime, consumedTime: Instant? = this.consumedTime, - status: Vault.StateStatus = this.status, + status: StateStatus = this.status, notary: AbstractParty? = this.notary, lockId: String? = this.lockId, lockUpdateTime: Instant? = this.lockUpdateTime @@ -237,11 +238,11 @@ class Vault(val states: Iterable>) { contractStateClassName: String = this.contractStateClassName, recordedTime: Instant = this.recordedTime, consumedTime: Instant? = this.consumedTime, - status: Vault.StateStatus = this.status, + status: StateStatus = this.status, notary: AbstractParty? = this.notary, lockId: String? = this.lockId, lockUpdateTime: Instant? = this.lockUpdateTime, - relevancyStatus: Vault.RelevancyStatus? + relevancyStatus: RelevancyStatus? ): StateMetadata { return StateMetadata(ref, contractStateClassName, recordedTime, consumedTime, status, notary, lockId, lockUpdateTime, relevancyStatus, ConstraintInfo(AlwaysAcceptAttachmentConstraint)) } @@ -249,9 +250,9 @@ class Vault(val states: Iterable>) { companion object { @Deprecated("No longer used. The vault does not emit empty updates") - val NoUpdate = Update(emptySet(), emptySet(), type = Vault.UpdateType.GENERAL, references = emptySet()) + val NoUpdate = Update(emptySet(), emptySet(), type = UpdateType.GENERAL, references = emptySet()) @Deprecated("No longer used. The vault does not emit empty updates") - val NoNotaryUpdate = Vault.Update(emptySet(), emptySet(), type = Vault.UpdateType.NOTARY_CHANGE, references = emptySet()) + val NoNotaryUpdate = Update(emptySet(), emptySet(), type = UpdateType.NOTARY_CHANGE, references = emptySet()) } } @@ -302,7 +303,7 @@ interface VaultService { fun whenConsumed(ref: StateRef): CordaFuture> { val query = QueryCriteria.VaultQueryCriteria( stateRefs = listOf(ref), - status = Vault.StateStatus.CONSUMED + status = StateStatus.CONSUMED ) val result = trackBy(query) val snapshot = result.snapshot.states @@ -358,8 +359,8 @@ interface VaultService { /** * Helper function to determine spendable states and soft locking them. * Currently performance will be worse than for the hand optimised version in - * [Cash.unconsumedCashStatesForSpending]. However, this is fully generic and can operate with custom [FungibleState] - * and [FungibleAsset] states. + * [net.corda.finance.workflows.asset.selection.AbstractCashSelection.unconsumedCashStatesForSpending]. However, this is fully generic + * and can operate with custom [FungibleState] and [FungibleAsset] states. * @param lockId The [FlowLogic.runId]'s [UUID] of the current flow used to soft lock the states. * @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the * [contractStateType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the diff --git a/core/src/main/kotlin/net/corda/core/serialization/CustomSerializationScheme.kt b/core/src/main/kotlin/net/corda/core/serialization/CustomSerializationScheme.kt new file mode 100644 index 0000000000..599d250e67 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/CustomSerializationScheme.kt @@ -0,0 +1,35 @@ +package net.corda.core.serialization + +import net.corda.core.utilities.ByteSequence +import java.io.NotSerializableException + +/*** + * Implement this interface to add your own Serialization Scheme. This is an experimental feature. All methods in this class MUST be + * thread safe i.e. methods from the same instance of this class can be called in different threads simultaneously. + */ +interface CustomSerializationScheme { + /** + * This method must return an id used to uniquely identify the Scheme. This should be unique within a network as serialized data might + * be sent over the wire. + */ + fun getSchemeId(): Int + + /** + * This method must deserialize the data stored [bytes] into an instance of [T]. + * + * @param bytes the serialized data. + * @param clazz the class to instantiate. + * @param context used to pass information about how the object should be deserialized. + */ + @Throws(NotSerializableException::class) + fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T + + /** + * This method must be able to serialize any object [T] into a ByteSequence. + * + * @param obj the object to be serialized. + * @param context used to pass information about how the object should be serialized. + */ + @Throws(NotSerializableException::class) + fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt index c97a511db2..77289d8c8e 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -13,6 +13,10 @@ import net.corda.core.utilities.sequence import java.io.NotSerializableException import java.sql.Blob +const val DESERIALIZATION_CACHE_PROPERTY = "DESERIALIZATION_CACHE" +const val AMQP_ENVELOPE_CACHE_PROPERTY = "AMQP_ENVELOPE_CACHE" +const val AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY = 256 + data class ObjectWithCompatibleContext(val obj: T, val context: SerializationContext) /** @@ -65,12 +69,16 @@ abstract class SerializationFactory { * Change the current context inside the block to that supplied. */ fun withCurrentContext(context: SerializationContext?, block: () -> T): T { - val priorContext = _currentContext.get() - if (context != null) _currentContext.set(context) - try { - return block() - } finally { - if (context != null) _currentContext.set(priorContext) + return if (context == null) { + block() + } else { + val priorContext = _currentContext.get() + _currentContext.set(context) + try { + block() + } finally { + _currentContext.set(priorContext) + } } } @@ -134,7 +142,7 @@ interface SerializationContext { */ val encodingWhitelist: EncodingWhitelist /** - * A map of any addition properties specific to the particular use case. + * A map of any additional properties specific to the particular use case. */ val properties: Map /** @@ -178,6 +186,11 @@ interface SerializationContext { */ fun withProperty(property: Any, value: Any): SerializationContext + /** + * Helper method to return a new context based on this context with the extra properties added. + */ + fun withProperties(extraProperties: Map): SerializationContext + /** * Helper method to return a new context based on this context with object references disabled. */ diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationSchemeContext.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationSchemeContext.kt new file mode 100644 index 0000000000..eb1709390a --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationSchemeContext.kt @@ -0,0 +1,30 @@ +package net.corda.core.serialization + +import net.corda.core.DoNotImplement + +/** + * This is used to pass information into [CustomSerializationScheme] about how the object should be (de)serialized. + * This context can change depending on the specific circumstances in the node when (de)serialization occurs. + */ +@DoNotImplement +interface SerializationSchemeContext { + /** + * The class loader to use for deserialization. This is guaranteed to be able to load all the required classes passed into + * [CustomSerializationScheme.deserialize]. + */ + val deserializationClassLoader: ClassLoader + /** + * A whitelist that contains (mostly for security purposes) which classes are authorised to be deserialized. + * A secure implementation will not instantiate any object which is not either whitelisted or annotated with [CordaSerializable] when + * deserializing. To catch classes missing from the whitelist as early as possible it is HIGHLY recommended to also check this + * whitelist when serializing (as well as deserializing) objects. + */ + val whitelist: ClassWhitelist + /** + * A map of any additional properties specific to the particular use case. If these properties are set via + * [toWireTransaction][net.corda.core.transactions.TransactionBuilder.toWireTransaction] then they might not be available when + * deserializing. If the properties are required when deserializing, they can be added into the blob when serializing and read back + * when deserializing. + */ + val properties: Map +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt index e93be2de5d..83de5fe059 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt @@ -9,21 +9,46 @@ import net.corda.core.contracts.TransactionVerificationException import net.corda.core.contracts.TransactionVerificationException.OverlappingAttachmentsException import net.corda.core.contracts.TransactionVerificationException.PackageOwnershipException import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.sha256 -import net.corda.core.internal.* +import net.corda.core.internal.JDK1_2_CLASS_FILE_FORMAT_MAJOR_VERSION +import net.corda.core.internal.JDK8_CLASS_FILE_FORMAT_MAJOR_VERSION +import net.corda.core.internal.JarSignatureCollector +import net.corda.core.internal.NamedCacheFactory +import net.corda.core.internal.PlatformVersionSwitches +import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.cordapp.targetPlatformVersion +import net.corda.core.internal.createInstancesOfClassesImplementing +import net.corda.core.internal.createSimpleCache +import net.corda.core.internal.toSynchronised import net.corda.core.node.NetworkParameters -import net.corda.core.serialization.* +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_PROPERTY +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationCustomSerializer +import net.corda.core.serialization.SerializationFactory +import net.corda.core.serialization.SerializationWhitelist +import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory.toUrl +import net.corda.core.serialization.withWhitelist import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug -import java.io.ByteArrayOutputStream +import net.corda.core.utilities.loggerFor import java.io.IOException import java.io.InputStream +import java.lang.ref.ReferenceQueue import java.lang.ref.WeakReference -import java.net.* +import java.net.URL +import java.net.URLClassLoader +import java.net.URLConnection +import java.net.URLStreamHandler +import java.net.URLStreamHandlerFactory +import java.security.MessageDigest import java.security.Permission -import java.util.* +import java.util.Locale +import java.util.ServiceLoader +import java.util.WeakHashMap +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicLong import java.util.function.Function /** @@ -51,12 +76,15 @@ class AttachmentsClassLoader(attachments: List, init { // Apply our own URLStreamHandlerFactory to resolve attachments setOrDecorateURLStreamHandlerFactory() + + // Allow AttachmentsClassLoader to be used concurrently. + registerAsParallelCapable() } // Jolokia and Json-simple are dependencies that were bundled by mistake within contract jars. // In the AttachmentsClassLoader we just block any class in those 2 packages. private val ignoreDirectories = listOf("org/jolokia/", "org/json/simple/") - private val ignorePackages = ignoreDirectories.map { it.replace("/", ".") } + private val ignorePackages = ignoreDirectories.map { it.replace('/', '.') } /** * Apply our custom factory either directly, if `URL.setURLStreamHandlerFactory` has not been called yet, @@ -128,6 +156,20 @@ class AttachmentsClassLoader(attachments: List, checkAttachments(attachments) } + private class AttachmentHashContext( + val txId: SecureHash, + val buffer: ByteArray = ByteArray(DEFAULT_BUFFER_SIZE)) + + private fun hash(inputStream : InputStream, ctx : AttachmentHashContext) : SecureHash.SHA256 { + val md = MessageDigest.getInstance(SecureHash.SHA2_256) + while(true) { + val read = inputStream.read(ctx.buffer) + if(read <= 0) break + md.update(ctx.buffer, 0, read) + } + return SecureHash.SHA256(md.digest()) + } + private fun isZipOrJar(attachment: Attachment) = attachment.openAsJAR().use { jar -> jar.nextEntry != null } @@ -146,10 +188,10 @@ class AttachmentsClassLoader(attachments: List, // TODO - investigate potential exploits. private fun shouldCheckForNoOverlap(path: String, targetPlatformVersion: Int): Boolean { require(path.toLowerCase() == path) - require(!path.contains("\\")) + require(!path.contains('\\')) return when { - path.endsWith("/") -> false // Directories (packages) can overlap. + path.endsWith('/') -> false // Directories (packages) can overlap. targetPlatformVersion < PlatformVersionSwitches.IGNORE_JOLOKIA_JSON_SIMPLE_IN_CORDAPPS && ignoreDirectories.any { path.startsWith(it) } -> false // Ignore jolokia and json-simple for old cordapps. path.endsWith(".class") -> true // All class files need to be unique. @@ -160,6 +202,7 @@ class AttachmentsClassLoader(attachments: List, } } + @Suppress("ThrowsCount", "ComplexMethod", "NestedBlockDepth") private fun checkAttachments(attachments: List) { require(attachments.isNotEmpty()) { "attachments list is empty" } @@ -188,7 +231,8 @@ class AttachmentsClassLoader(attachments: List, // attacks on externally connected systems that only consider type names, we allow people to formally // claim their parts of the Java package namespace via registration with the zone operator. - val classLoaderEntries = mutableMapOf() + val classLoaderEntries = mutableMapOf() + val ctx = AttachmentHashContext(sampleTxId) for (attachment in attachments) { // We may have been given an attachment loaded from the database in which case, important info like // signers is already calculated. @@ -206,10 +250,12 @@ class AttachmentsClassLoader(attachments: List, // signed by the owners of the packages, even if it's not. We'd eventually discover that fact // when trying to read the class file to use it, but if we'd made any decisions based on // perceived correctness of the signatures or package ownership already, that would be too late. - attachment.openAsJAR().use { JarSignatureCollector.collectSigners(it) } + attachment.openAsJAR().use(JarSignatureCollector::collectSigners) } + // Now open it again to compute the overlap and package ownership data. attachment.openAsJAR().use { jar -> + val targetPlatformVersion = jar.manifest?.targetPlatformVersion ?: 1 while (true) { val entry = jar.nextJarEntry ?: break @@ -250,13 +296,9 @@ class AttachmentsClassLoader(attachments: List, if (!shouldCheckForNoOverlap(path, targetPlatformVersion)) continue // This calculates the hash of the current entry because the JarInputStream returns only the current entry. - fun entryHash() = ByteArrayOutputStream().use { - jar.copyTo(it) - it.toByteArray() - }.sha256() + val currentHash = hash(jar, ctx) // If 2 entries are identical, it means the same file is present in both attachments, so that is ok. - val currentHash = entryHash() val previousFileHash = classLoaderEntries[path] when { previousFileHash == null -> { @@ -279,11 +321,11 @@ class AttachmentsClassLoader(attachments: List, * Required to prevent classes that were excluded from the no-overlap check from being loaded by contract code. * As it can lead to non-determinism. */ - override fun loadClass(name: String?): Class<*> { - if (ignorePackages.any { name!!.startsWith(it) }) { + override fun loadClass(name: String, resolve: Boolean): Class<*>? { + if (ignorePackages.any { name.startsWith(it) }) { throw ClassNotFoundException(name) } - return super.loadClass(name) + return super.loadClass(name, resolve) } } @@ -293,7 +335,8 @@ class AttachmentsClassLoader(attachments: List, */ @VisibleForTesting object AttachmentsClassLoaderBuilder { - const val CACHE_SIZE = 16 + private const val CACHE_SIZE = 16 + private const val STRONG_REFERENCE_TO_CACHED_SERIALIZATION_CONTEXT = "cachedSerializationContext" private val fallBackCache: AttachmentsClassLoaderCache = AttachmentsClassLoaderSimpleCacheImpl(CACHE_SIZE) @@ -309,20 +352,19 @@ object AttachmentsClassLoaderBuilder { isAttachmentTrusted: (Attachment) -> Boolean, parent: ClassLoader = ClassLoader.getSystemClassLoader(), attachmentsClassLoaderCache: AttachmentsClassLoaderCache?, - block: (ClassLoader) -> T): T { - val attachmentIds = attachments.map(Attachment::id).toSet() + block: (SerializationContext) -> T): T { + val attachmentIds = attachments.mapTo(LinkedHashSet(), Attachment::id) val cache = attachmentsClassLoaderCache ?: fallBackCache - val serializationContext = cache.computeIfAbsent(AttachmentsClassLoaderKey(attachmentIds, params), Function { + val cachedSerializationContext = cache.computeIfAbsent(AttachmentsClassLoaderKey(attachmentIds, params), Function { key -> // Create classloader and load serializers, whitelisted classes - val transactionClassLoader = AttachmentsClassLoader(attachments, params, txId, isAttachmentTrusted, parent) + val transactionClassLoader = AttachmentsClassLoader(attachments, key.params, txId, isAttachmentTrusted, parent) val serializers = try { createInstancesOfClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java, JDK1_2_CLASS_FILE_FORMAT_MAJOR_VERSION..JDK8_CLASS_FILE_FORMAT_MAJOR_VERSION) - } - catch(ex: UnsupportedClassVersionError) { - throw TransactionVerificationException.UnsupportedClassVersionError(txId, ex.message!!, ex) - } + } catch (ex: UnsupportedClassVersionError) { + throw TransactionVerificationException.UnsupportedClassVersionError(txId, ex.message!!, ex) + } val whitelistedClasses = ServiceLoader.load(SerializationWhitelist::class.java, transactionClassLoader) .flatMap(SerializationWhitelist::whitelist) @@ -338,9 +380,20 @@ object AttachmentsClassLoaderBuilder { .withoutCarpenter() }) + val serializationContext = cachedSerializationContext.withProperties(mapOf( + // Duplicate the SerializationContext from the cache and give + // it these extra properties, just for this transaction. + // However, keep a strong reference to the cached SerializationContext so we can + // leverage the power of WeakReferences in the AttachmentsClassLoaderCacheImpl to figure + // out when all these have gone out of scope by the BasicVerifier going out of scope. + AMQP_ENVELOPE_CACHE_PROPERTY to HashMap(AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY), + DESERIALIZATION_CACHE_PROPERTY to HashMap(), + STRONG_REFERENCE_TO_CACHED_SERIALIZATION_CONTEXT to cachedSerializationContext + )) + // Deserialize all relevant classes in the transaction classloader. return SerializationFactory.defaultFactory.withCurrentContext(serializationContext) { - block(serializationContext.deserializationClassLoader) + block(serializationContext) } } } @@ -352,6 +405,8 @@ object AttachmentsClassLoaderBuilder { object AttachmentURLStreamHandlerFactory : URLStreamHandlerFactory { internal const val attachmentScheme = "attachment" + private val uniqueness = AtomicLong(0) + private val loadedAttachments: AttachmentsHolder = AttachmentsHolderImpl() override fun createURLStreamHandler(protocol: String): URLStreamHandler? { @@ -362,14 +417,9 @@ object AttachmentURLStreamHandlerFactory : URLStreamHandlerFactory { @Synchronized fun toUrl(attachment: Attachment): URL { - val proposedURL = URL(attachmentScheme, "", -1, attachment.id.toString(), AttachmentURLStreamHandler) - val existingURL = loadedAttachments.getKey(proposedURL) - return if (existingURL == null) { - loadedAttachments[proposedURL] = attachment - proposedURL - } else { - existingURL - } + val uniqueURL = URL(attachmentScheme, "", -1, attachment.id.toString()+ "?" + uniqueness.getAndIncrement(), AttachmentURLStreamHandler) + loadedAttachments[uniqueURL] = attachment + return uniqueURL } @VisibleForTesting @@ -427,9 +477,52 @@ interface AttachmentsClassLoaderCache { @DeleteForDJVM class AttachmentsClassLoaderCacheImpl(cacheFactory: NamedCacheFactory) : SingletonSerializeAsToken(), AttachmentsClassLoaderCache { - private val cache: Cache = cacheFactory.buildNamed(Caffeine.newBuilder(), "AttachmentsClassLoader_cache") + private class ToBeClosed( + serializationContext: SerializationContext, + val classLoaderToClose: AutoCloseable, + val cacheKey: AttachmentsClassLoaderKey, + queue: ReferenceQueue + ) : WeakReference(serializationContext, queue) + + private val logger = loggerFor() + private val toBeClosed = ConcurrentHashMap.newKeySet() + private val expiryQueue = ReferenceQueue() + + @Suppress("TooGenericExceptionCaught") + private fun purgeExpiryQueue() { + // Close the AttachmentsClassLoader for every SerializationContext + // that has already been garbage-collected. + while (true) { + val head = expiryQueue.poll() as? ToBeClosed ?: break + if (!toBeClosed.remove(head)) { + logger.warn("Reaped unexpected serialization context for {}", head.cacheKey) + } + + try { + head.classLoaderToClose.close() + } catch (e: Exception) { + logger.warn("Error destroying serialization context for ${head.cacheKey}", e) + } + } + } + + private val cache: Cache = cacheFactory.buildNamed( + // Schedule for closing the deserialization classloaders when we evict them + // to release any resources they may be holding. + Caffeine.newBuilder().removalListener { key, context, _ -> + (context?.deserializationClassLoader as? AutoCloseable)?.also { autoCloseable -> + // ClassLoader to be closed once the BasicVerifier, which has a strong + // reference chain to this SerializationContext, has gone out of scope. + toBeClosed += ToBeClosed(context, autoCloseable, key!!, expiryQueue) + } + + // Reap any entries which have been garbage-collected. + purgeExpiryQueue() + }, "AttachmentsClassLoader_cache" + ) override fun computeIfAbsent(key: AttachmentsClassLoaderKey, mappingFunction: Function): SerializationContext { + purgeExpiryQueue() return cache.get(key, mappingFunction) ?: throw NullPointerException("null returned from cache mapping function") } } @@ -465,4 +558,4 @@ private class AttachmentURLConnection(url: URL, private val attachment: Attachme override fun connect() { connected = true } -} \ No newline at end of file +} diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/CustomSerializationSchemeUtils.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/CustomSerializationSchemeUtils.kt new file mode 100644 index 0000000000..b0588755aa --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/CustomSerializationSchemeUtils.kt @@ -0,0 +1,28 @@ +package net.corda.core.serialization.internal + +import net.corda.core.KeepForDJVM +import net.corda.core.serialization.SerializationMagic +import net.corda.core.utilities.ByteSequence +import java.nio.ByteBuffer + +class CustomSerializationSchemeUtils { + + @KeepForDJVM + companion object { + + private const val SERIALIZATION_SCHEME_ID_SIZE = 4 + private val PREFIX = "CUS".toByteArray() + + fun getCustomSerializationMagicFromSchemeId(schemeId: Int) : SerializationMagic { + return SerializationMagic.of(PREFIX + ByteBuffer.allocate(SERIALIZATION_SCHEME_ID_SIZE).putInt(schemeId).array()) + } + + fun getSchemeIdIfCustomSerializationMagic(magic: SerializationMagic): Int? { + return if (magic.take(PREFIX.size) != ByteSequence.of(PREFIX)) { + null + } else { + return magic.slice(start = PREFIX.size).int + } + } + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt b/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt index f1f53f90b8..911c67d0da 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt @@ -145,7 +145,7 @@ data class ContractUpgradeWireTransaction( private fun upgradedContract(className: ContractClassName, classLoader: ClassLoader): UpgradedContract = try { @Suppress("UNCHECKED_CAST") - classLoader.loadClass(className).asSubclass(UpgradedContract::class.java).getDeclaredConstructor().newInstance() as UpgradedContract + Class.forName(className, false, classLoader).asSubclass(UpgradedContract::class.java).getDeclaredConstructor().newInstance() as UpgradedContract } catch (e: Exception) { throw TransactionVerificationException.ContractCreationError(id, className, e) } @@ -166,9 +166,9 @@ data class ContractUpgradeWireTransaction( params, id, { (services as ServiceHubCoreInternal).attachmentTrustCalculator.calculate(it) }, - attachmentsClassLoaderCache = (services as ServiceHubCoreInternal).attachmentsClassLoaderCache) { transactionClassLoader -> + attachmentsClassLoaderCache = (services as ServiceHubCoreInternal).attachmentsClassLoaderCache) { serializationContext -> val resolvedInput = binaryInput.deserialize() - val upgradedContract = upgradedContract(upgradedContractClassName, transactionClassLoader) + val upgradedContract = upgradedContract(upgradedContractClassName, serializationContext.deserializationClassLoader) val outputState = calculateUpgradedState(resolvedInput, upgradedContract, upgradedAttachment) outputState.serialize() } @@ -311,8 +311,7 @@ private constructor( @CordaInternal internal fun loadUpgradedContract(upgradedContractClassName: ContractClassName, classLoader: ClassLoader): UpgradedContract { @Suppress("UNCHECKED_CAST") - return classLoader - .loadClass(upgradedContractClassName) + return Class.forName(upgradedContractClassName, false, classLoader) .asSubclass(Contract::class.java) .getConstructor() .newInstance() as UpgradedContract diff --git a/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt index 717b9b5937..25dfa7f293 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt @@ -18,21 +18,25 @@ import net.corda.core.crypto.DigestService import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowLogic import net.corda.core.identity.Party -import net.corda.core.internal.BasicVerifier +import net.corda.core.internal.AbstractVerifier import net.corda.core.internal.SerializedStateAndRef import net.corda.core.internal.Verifier import net.corda.core.internal.castIfPossible import net.corda.core.internal.deserialiseCommands import net.corda.core.internal.deserialiseComponentGroup +import net.corda.core.internal.eagerDeserialise import net.corda.core.internal.isUploaderTrusted import net.corda.core.internal.uncheckedCast import net.corda.core.node.NetworkParameters import net.corda.core.serialization.DeprecatedConstructorForDeserialization +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.internal.AttachmentsClassLoaderCache import net.corda.core.serialization.internal.AttachmentsClassLoaderBuilder import net.corda.core.utilities.contextLogger import java.util.Collections.unmodifiableList import java.util.function.Predicate +import java.util.function.Supplier /** * A LedgerTransaction is derived from a [WireTransaction]. It is the result of doing the following operations: @@ -90,7 +94,7 @@ private constructor( private val serializedInputs: List?, private val serializedReferences: List?, private val isAttachmentTrusted: (Attachment) -> Boolean, - private val verifierFactory: (LedgerTransaction, ClassLoader) -> Verifier, + private val verifierFactory: (LedgerTransaction, SerializationContext) -> Verifier, private val attachmentsClassLoaderCache: AttachmentsClassLoaderCache?, val digestService: DigestService ) : FullTransaction() { @@ -100,22 +104,23 @@ private constructor( */ @DeprecatedConstructorForDeserialization(1) private constructor( - inputs: List>, - outputs: List>, - commands: List>, - attachments: List, - id: SecureHash, - notary: Party?, - timeWindow: TimeWindow?, - privacySalt: PrivacySalt, - networkParameters: NetworkParameters?, - references: List>, - componentGroups: List?, - serializedInputs: List?, - serializedReferences: List?, - isAttachmentTrusted: (Attachment) -> Boolean, - verifierFactory: (LedgerTransaction, ClassLoader) -> Verifier, - attachmentsClassLoaderCache: AttachmentsClassLoaderCache?) : this( + inputs: List>, + outputs: List>, + commands: List>, + attachments: List, + id: SecureHash, + notary: Party?, + timeWindow: TimeWindow?, + privacySalt: PrivacySalt, + networkParameters: NetworkParameters?, + references: List>, + componentGroups: List?, + serializedInputs: List?, + serializedReferences: List?, + isAttachmentTrusted: (Attachment) -> Boolean, + verifierFactory: (LedgerTransaction, SerializationContext) -> Verifier, + attachmentsClassLoaderCache: AttachmentsClassLoaderCache? + ) : this( inputs, outputs, commands, attachments, id, notary, timeWindow, privacySalt, networkParameters, references, componentGroups, serializedInputs, serializedReferences, isAttachmentTrusted, verifierFactory, attachmentsClassLoaderCache, DigestService.sha2_256) @@ -124,8 +129,8 @@ private constructor( companion object { private val logger = contextLogger() - private fun protect(list: List?): List? { - return list?.run { + private fun protect(list: List): List { + return list.run { if (isEmpty()) { emptyList() } else { @@ -134,6 +139,8 @@ private constructor( } } + private fun protectOrNull(list: List?): List? = list?.let(::protect) + @CordaInternal internal fun create( inputs: List>, @@ -164,9 +171,9 @@ private constructor( privacySalt = privacySalt, networkParameters = networkParameters, references = references, - componentGroups = protect(componentGroups), - serializedInputs = protect(serializedInputs), - serializedReferences = protect(serializedReferences), + componentGroups = protectOrNull(componentGroups), + serializedInputs = protectOrNull(serializedInputs), + serializedReferences = protectOrNull(serializedReferences), isAttachmentTrusted = isAttachmentTrusted, verifierFactory = ::BasicVerifier, attachmentsClassLoaderCache = attachmentsClassLoaderCache, @@ -176,10 +183,11 @@ private constructor( /** * This factory function will create an instance of [LedgerTransaction] - * that will be used inside the DJVM sandbox. + * that will be used for contract verification. See [BasicVerifier] and + * [DeterministicVerifier][net.corda.node.internal.djvm.DeterministicVerifier]. */ @CordaInternal - fun createForSandbox( + fun createForContractVerify( inputs: List>, outputs: List>, commands: List>, @@ -188,28 +196,31 @@ private constructor( notary: Party?, timeWindow: TimeWindow?, privacySalt: PrivacySalt, - networkParameters: NetworkParameters, + networkParameters: NetworkParameters?, references: List>, digestService: DigestService): LedgerTransaction { return LedgerTransaction( - inputs = inputs, - outputs = outputs, - commands = commands, - attachments = attachments, + inputs = protect(inputs), + outputs = protect(outputs), + commands = protect(commands), + attachments = protect(attachments), id = id, notary = notary, timeWindow = timeWindow, privacySalt = privacySalt, networkParameters = networkParameters, - references = references, + references = protect(references), componentGroups = null, serializedInputs = null, serializedReferences = null, isAttachmentTrusted = { true }, - verifierFactory = ::BasicVerifier, + verifierFactory = ::NoOpVerifier, attachmentsClassLoaderCache = null, digestService = digestService - ) + // This check accesses input states and must run on the LedgerTransaction + // instance that is verified, not on the outer LedgerTransaction shell. + // All states must also deserialize using the correct SerializationContext. + ).also(LedgerTransaction::checkBaseInvariants) } } @@ -251,11 +262,17 @@ private constructor( getParamsWithGoo(), id, isAttachmentTrusted = isAttachmentTrusted, - attachmentsClassLoaderCache = attachmentsClassLoaderCache) { transactionClassLoader -> - // Create a copy of the outer LedgerTransaction which deserializes all fields inside the [transactionClassLoader]. + attachmentsClassLoaderCache = attachmentsClassLoaderCache) { serializationContext -> + + // Legacy check - warns if the LedgerTransaction was created incorrectly. + checkLtxForVerification() + + // Create a copy of the outer LedgerTransaction which deserializes all fields using + // the serialization context (or its deserializationClassloader). // Only the copy will be used for verification, and the outer shell will be discarded. // This artifice is required to preserve backwards compatibility. - verifierFactory(createLtxForVerification(), transactionClassLoader) + // NOTE: The Verifier creates the copies of the LedgerTransaction object now. + verifierFactory(this, serializationContext) } } @@ -272,7 +289,7 @@ private constructor( * Node without changing either the wire format or any public APIs. */ @CordaInternal - fun specialise(alternateVerifier: (LedgerTransaction, ClassLoader) -> Verifier): LedgerTransaction = LedgerTransaction( + fun specialise(alternateVerifier: (LedgerTransaction, SerializationContext) -> Verifier): LedgerTransaction = LedgerTransaction( inputs = inputs, outputs = outputs, commands = commands, @@ -287,7 +304,11 @@ private constructor( serializedInputs = serializedInputs, serializedReferences = serializedReferences, isAttachmentTrusted = isAttachmentTrusted, - verifierFactory = alternateVerifier, + verifierFactory = if (verifierFactory == ::NoOpVerifier) { + throw IllegalStateException("Cannot specialise transaction while verifying contracts") + } else { + alternateVerifier + }, attachmentsClassLoaderCache = attachmentsClassLoaderCache, digestService = digestService ) @@ -319,58 +340,12 @@ private constructor( } /** - * Create the [LedgerTransaction] instance that will be used by contract verification. - * - * This method needs to run in the special transaction attachments classloader context. */ - private fun createLtxForVerification(): LedgerTransaction { - val serializedInputs = this.serializedInputs - val serializedReferences = this.serializedReferences - val componentGroups = this.componentGroups - - val transaction= if (serializedInputs != null && serializedReferences != null && componentGroups != null) { - // Deserialize all relevant classes in the transaction classloader. - val deserializedInputs = serializedInputs.map { it.toStateAndRef() } - val deserializedReferences = serializedReferences.map { it.toStateAndRef() } - val deserializedOutputs = deserialiseComponentGroup(componentGroups, TransactionState::class, ComponentGroupEnum.OUTPUTS_GROUP, forceDeserialize = true) - val deserializedCommands = deserialiseCommands(componentGroups, forceDeserialize = true, digestService = digestService) - val authenticatedDeserializedCommands = deserializedCommands.map { cmd -> - @Suppress("DEPRECATION") // Deprecated feature. - val parties = commands.find { it.value.javaClass.name == cmd.value.javaClass.name }!!.signingParties - CommandWithParties(cmd.signers, parties, cmd.value) - } - - LedgerTransaction( - inputs = deserializedInputs, - outputs = deserializedOutputs, - commands = authenticatedDeserializedCommands, - attachments = this.attachments, - id = this.id, - notary = this.notary, - timeWindow = this.timeWindow, - privacySalt = this.privacySalt, - networkParameters = this.networkParameters, - references = deserializedReferences, - componentGroups = componentGroups, - serializedInputs = serializedInputs, - serializedReferences = serializedReferences, - isAttachmentTrusted = isAttachmentTrusted, - verifierFactory = verifierFactory, - attachmentsClassLoaderCache = attachmentsClassLoaderCache, - digestService = digestService - ) - } else { - // This branch is only present for backwards compatibility. + private fun checkLtxForVerification() { + if (serializedInputs == null || serializedReferences == null || componentGroups == null) { logger.warn("The LedgerTransaction should not be instantiated directly from client code. Please use WireTransaction.toLedgerTransaction." + "The result of the verify method might not be accurate.") - this } - - // This check accesses input states and must be run in this context. - // It must run on the instance that is verified, not on the outer LedgerTransaction shell. - transaction.checkBaseInvariants() - - return transaction } /** @@ -740,7 +715,7 @@ private constructor( componentGroups = null, serializedInputs = null, serializedReferences = null, - isAttachmentTrusted = { it.isUploaderTrusted() }, + isAttachmentTrusted = Attachment::isUploaderTrusted, verifierFactory = ::BasicVerifier, attachmentsClassLoaderCache = null ) @@ -770,7 +745,7 @@ private constructor( componentGroups = null, serializedInputs = null, serializedReferences = null, - isAttachmentTrusted = { it.isUploaderTrusted() }, + isAttachmentTrusted = Attachment::isUploaderTrusted, verifierFactory = ::BasicVerifier, attachmentsClassLoaderCache = null ) @@ -838,3 +813,80 @@ private constructor( ) } } + +/** + * This is the default [Verifier] that configures Corda + * to execute [Contract.verify(LedgerTransaction)]. + * + * THIS CLASS IS NOT PUBLIC API, AND IS DELIBERATELY PRIVATE! + */ +@CordaInternal +private class BasicVerifier( + ltx: LedgerTransaction, + private val serializationContext: SerializationContext +) : AbstractVerifier(ltx, serializationContext.deserializationClassLoader) { + + init { + // This is a sanity check: We should only instantiate this + // class from [LedgerTransaction.internalPrepareVerify]. + require(serializationContext === SerializationFactory.defaultFactory.currentContext) { + "BasicVerifier for TX ${ltx.id} created outside its SerializationContext" + } + + // Fetch these commands' signing parties from the database. + // Corda forbids database access during contract verification, + // and so we must load the commands here eagerly instead. + // THIS ALSO DESERIALISES THE COMMANDS USING THE WRONG CONTEXT + // BECAUSE THAT CONTEXT WAS CHOSEN WHEN THE LAZY MAP WAS CREATED, + // AND CHANGING THE DEFAULT CONTEXT HERE DOES NOT AFFECT IT. + ltx.commands.eagerDeserialise() + } + + override val transaction: Supplier + get() = Supplier(::createTransaction) + + private fun createTransaction(): LedgerTransaction { + // Deserialize all relevant classes using the serializationContext. + return SerializationFactory.defaultFactory.withCurrentContext(serializationContext) { + ltx.transform { componentGroups, serializedInputs, serializedReferences -> + val deserializedInputs = serializedInputs.map(SerializedStateAndRef::toStateAndRef) + val deserializedReferences = serializedReferences.map(SerializedStateAndRef::toStateAndRef) + val deserializedOutputs = deserialiseComponentGroup(componentGroups, TransactionState::class, ComponentGroupEnum.OUTPUTS_GROUP, forceDeserialize = true) + val deserializedCommands = deserialiseCommands(componentGroups, forceDeserialize = true, digestService = ltx.digestService) + val authenticatedDeserializedCommands = deserializedCommands.mapIndexed { idx, cmd -> + // Requires ltx.commands to have been deserialized already. + @Suppress("DEPRECATION") // Deprecated feature. + val parties = ltx.commands[idx].signingParties + CommandWithParties(cmd.signers, parties, cmd.value) + } + + LedgerTransaction.createForContractVerify( + inputs = deserializedInputs, + outputs = deserializedOutputs, + commands = authenticatedDeserializedCommands, + attachments = ltx.attachments, + id = ltx.id, + notary = ltx.notary, + timeWindow = ltx.timeWindow, + privacySalt = ltx.privacySalt, + networkParameters = ltx.networkParameters, + references = deserializedReferences, + digestService = ltx.digestService + ) + } + } + } +} + +/** + * A "do nothing" [Verifier] installed for contract verification. + * + * THIS CLASS IS NOT PUBLIC API, AND IS DELIBERATELY PRIVATE! + */ +@Suppress("unused_parameter") +@CordaInternal +private class NoOpVerifier(ltx: LedgerTransaction, serializationContext: SerializationContext) : Verifier { + // Invoking LedgerTransaction.verify() from Contract.verify(LedgerTransaction) + // will execute this function. But why would anyone do that?! + override fun verify() {} +} diff --git a/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt b/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt index 2b2f9655c2..71fb6c728b 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt @@ -16,14 +16,18 @@ import net.corda.core.node.ServicesForResolution import net.corda.core.node.ZoneVersionTooLowException import net.corda.core.node.services.AttachmentId import net.corda.core.node.services.KeyManagementService +import net.corda.core.serialization.CustomSerializationScheme import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationFactory +import net.corda.core.serialization.SerializationMagic +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getCustomSerializationMagicFromSchemeId import net.corda.core.utilities.contextLogger import java.security.PublicKey import java.time.Duration import java.time.Instant -import java.util.ArrayDeque -import java.util.UUID +import java.util.* import java.util.regex.Pattern import kotlin.collections.ArrayList import kotlin.collections.component1 @@ -140,6 +144,41 @@ open class TransactionBuilder( fun toWireTransaction(services: ServicesForResolution): WireTransaction = toWireTransactionWithContext(services, null) .apply { checkSupportedHashType() } + /** + * Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction. + * + * @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction]. + * This is an experimental feature. + * + * @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder]. + * + * @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4. + */ + @Throws(MissingContractAttachments::class) + fun toWireTransaction(services: ServicesForResolution, schemeId: Int): WireTransaction { + return toWireTransaction(services, schemeId, emptyMap()).apply { checkSupportedHashType() } + } + + /** + * Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction. + * + * @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction]. + * This is an experimental feature. + * + * @param [properties] a list of properties to add to the [SerializationSchemeContext] these properties can be accessed in [CustomSerializationScheme.serialize] + * when serializing the componentGroups of the wire transaction but might not be available when deserializing. + * + * @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder]. + * + * @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4. + */ + @Throws(MissingContractAttachments::class) + fun toWireTransaction(services: ServicesForResolution, schemeId: Int, properties: Map): WireTransaction { + val magic: SerializationMagic = getCustomSerializationMagicFromSchemeId(schemeId) + val serializationContext = SerializationDefaults.P2P_CONTEXT.withPreferredSerializationVersion(magic).withProperties(properties) + return toWireTransactionWithContext(services, serializationContext).apply { checkSupportedHashType() } + } + @CordaInternal internal fun toWireTransactionWithContext( services: ServicesForResolution, diff --git a/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt index fbcb012d10..5ff7cae23e 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt @@ -15,6 +15,7 @@ import net.corda.core.node.ServicesForResolution import net.corda.core.node.services.AttachmentId import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.DeprecatedConstructorForDeserialization +import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal.AttachmentsClassLoaderCache import net.corda.core.serialization.serialize @@ -154,7 +155,7 @@ class WireTransaction(componentGroups: List, val privacySalt: Pr resolveAttachment, { stateRef -> resolveStateRef(stateRef)?.serialize() }, { null }, - { it.isUploaderTrusted() }, + Attachment::isUploaderTrusted, null ) } @@ -187,19 +188,26 @@ class WireTransaction(componentGroups: List, val privacySalt: Pr ): LedgerTransaction { // Look up public keys to authenticated identities. val authenticatedCommands = commands.lazyMapped { cmd, _ -> - val parties = cmd.signers.mapNotNull { pk -> resolveIdentity(pk) } + val parties = cmd.signers.mapNotNull(resolveIdentity) CommandWithParties(cmd.signers, parties, cmd.value) } + // Ensure that the lazy mappings will use the correct SerializationContext. + val serializationFactory = SerializationFactory.defaultFactory + val serializationContext = serializationFactory.defaultContext + val toStateAndRef = { ssar: SerializedStateAndRef, _: Int -> + ssar.toStateAndRef(serializationFactory, serializationContext) + } + val serializedResolvedInputs = inputs.map { ref -> SerializedStateAndRef(resolveStateRefAsSerialized(ref) ?: throw TransactionResolutionException(ref.txhash), ref) } - val resolvedInputs = serializedResolvedInputs.lazyMapped { star, _ -> star.toStateAndRef() } + val resolvedInputs = serializedResolvedInputs.lazyMapped(toStateAndRef) val serializedResolvedReferences = references.map { ref -> SerializedStateAndRef(resolveStateRefAsSerialized(ref) ?: throw TransactionResolutionException(ref.txhash), ref) } - val resolvedReferences = serializedResolvedReferences.lazyMapped { star, _ -> star.toStateAndRef() } + val resolvedReferences = serializedResolvedReferences.lazyMapped(toStateAndRef) val resolvedAttachments = attachments.lazyMapped { att, _ -> resolveAttachment(att) ?: throw AttachmentResolutionException(att) } @@ -214,7 +222,7 @@ class WireTransaction(componentGroups: List, val privacySalt: Pr notary, timeWindow, privacySalt, - resolvedNetworkParameters, + resolvedNetworkParameters.toImmutable(), resolvedReferences, componentGroups, serializedResolvedInputs, @@ -318,7 +326,11 @@ class WireTransaction(componentGroups: List, val privacySalt: Pr * nothing about the rest. */ internal val availableComponentNonces: Map> by lazy { - componentGroups.associate { it.groupIndex to it.components.mapIndexed { internalIndex, internalIt -> digestService.componentHash(internalIt, privacySalt, it.groupIndex, internalIndex) } } + if(digestService.hashAlgorithm == SecureHash.SHA2_256) { + componentGroups.associate { it.groupIndex to it.components.mapIndexed { internalIndex, internalIt -> digestService.componentHash(internalIt, privacySalt, it.groupIndex, internalIndex) } } + } else { + componentGroups.associate { it.groupIndex to it.components.mapIndexed { internalIndex, _ -> digestService.computeNonce(privacySalt, it.groupIndex, internalIndex) } } + } } /** diff --git a/core/src/obfuscator/kotlin/net/corda/core/internal/utilities/TestResourceWriter.kt b/core/src/obfuscator/kotlin/net/corda/core/internal/utilities/TestResourceWriter.kt new file mode 100644 index 0000000000..a0ba712730 --- /dev/null +++ b/core/src/obfuscator/kotlin/net/corda/core/internal/utilities/TestResourceWriter.kt @@ -0,0 +1,54 @@ +package net.corda.core.internal.utilities + +import net.corda.core.obfuscator.XorOutputStream +import java.net.URL +import java.nio.file.Files +import java.nio.file.Paths +import java.util.zip.ZipEntry +import java.util.zip.ZipOutputStream + +object TestResourceWriter { + + private val externalZipBombUrls = arrayOf( + URL("https://www.bamsoftware.com/hacks/zipbomb/zbsm.zip"), + URL("https://www.bamsoftware.com/hacks/zipbomb/zblg.zip"), + URL("https://www.bamsoftware.com/hacks/zipbomb/zbxl.zip") + ) + + @JvmStatic + @Suppress("NestedBlockDepth", "MagicNumber") + fun main(vararg args : String) { + for(arg in args) { + /** + * Download zip bombs + */ + for(url in externalZipBombUrls) { + url.openStream().use { inputStream -> + val destination = Paths.get(arg).resolve(Paths.get(url.path + ".xor").fileName) + Files.newOutputStream(destination).buffered().let(::XorOutputStream).use { outputStream -> + inputStream.copyTo(outputStream) + } + } + } + /** + * Create a jar archive with a huge manifest file, used in unit tests to check that it is also identified as a zip bomb. + * This is because {@link java.util.jar.JarInputStream} + * eagerly loads the manifest file in memory + * which would make such a jar dangerous if used as an attachment + */ + val destination = Paths.get(arg).resolve(Paths.get("big-manifest.jar.xor").fileName) + ZipOutputStream(XorOutputStream((Files.newOutputStream(destination).buffered()))).use { zos -> + val zipEntry = ZipEntry("MANIFEST.MF") + zipEntry.method = ZipEntry.DEFLATED + zos.putNextEntry(zipEntry) + val buffer = ByteArray(0x100000) { 0x0 } + var written = 0L + while(written < 10_000_000_000) { + zos.write(buffer) + written += buffer.size + } + zos.closeEntry() + } + } + } +} \ No newline at end of file diff --git a/core/src/obfuscator/kotlin/net/corda/core/obfuscator/XorInputStream.kt b/core/src/obfuscator/kotlin/net/corda/core/obfuscator/XorInputStream.kt new file mode 100644 index 0000000000..9226be93fc --- /dev/null +++ b/core/src/obfuscator/kotlin/net/corda/core/obfuscator/XorInputStream.kt @@ -0,0 +1,30 @@ +package net.corda.core.obfuscator + +import java.io.FilterInputStream +import java.io.InputStream + +@Suppress("MagicNumber") +class XorInputStream(private val source : InputStream) : FilterInputStream(source) { + var prev : Int = 0 + + override fun read(): Int { + prev = source.read() xor prev + return prev - 0x80 + } + + override fun read(buffer: ByteArray): Int { + return read(buffer, 0, buffer.size) + } + + override fun read(buffer: ByteArray, off: Int, len: Int): Int { + var read = 0 + while(true) { + val b = source.read() + if(b < 0) break + buffer[off + read++] = ((b xor prev) - 0x80).toByte() + prev = b + if(read == len) break + } + return read + } +} \ No newline at end of file diff --git a/core/src/obfuscator/kotlin/net/corda/core/obfuscator/XorOutputStream.kt b/core/src/obfuscator/kotlin/net/corda/core/obfuscator/XorOutputStream.kt new file mode 100644 index 0000000000..4f4bdbab0a --- /dev/null +++ b/core/src/obfuscator/kotlin/net/corda/core/obfuscator/XorOutputStream.kt @@ -0,0 +1,30 @@ +package net.corda.core.obfuscator + +import java.io.FilterOutputStream +import java.io.OutputStream + +@Suppress("MagicNumber") +class XorOutputStream(private val destination : OutputStream) : FilterOutputStream(destination) { + var prev : Int = 0 + + override fun write(byte: Int) { + val b = (byte + 0x80) xor prev + destination.write(b) + prev = b + } + + override fun write(buffer: ByteArray) { + write(buffer, 0, buffer.size) + } + + override fun write(buffer: ByteArray, off: Int, len: Int) { + var written = 0 + while(true) { + val b = (buffer[written] + 0x80) xor prev + destination.write(b) + prev = b + ++written + if(written == len) break + } + } +} \ No newline at end of file diff --git a/core/src/test/README.md b/core/src/test/README.md index 7eb94e5eab..36b5e2dfa2 100644 --- a/core/src/test/README.md +++ b/core/src/test/README.md @@ -7,6 +7,18 @@ The Corda core module defines a lot of types and helpers that can only be exerci the context of a node. However, as everything else depends on the core module, we cannot pull the node into this module. Therefore, any tests that require further Corda dependencies need to be defined in the module `core-tests`, which has the full set of dependencies including `node-driver`. - +# ZipBomb tests + +There is a unit test that checks the zip bomb detector in `net.corda.core.internal.utilities.ZipBombDetector` works correctly. +This test (`core/src/test/kotlin/net/corda/core/internal/utilities/ZipBombDetectorTest.kt`) uses real zip bombs, provided by `https://www.bamsoftware.com/hacks/zipbomb/`. +As it is undesirable to have unit test depends on external internet resources we do not control, those files are included as resources in +`core/src/test/resources/zip/`, however some Windows antivirus software correctly identifies those files as zip bombs, +raising an alert to the user. To mitigate this, those files have been obfuscated using `net.corda.core.obfuscator.XorOutputStream` +(which simply XORs every byte of the file with the previous one, except for the first byte that is XORed with zero) +to prevent antivirus software from detecting them as zip bombs and are de-obfuscated on the fly in unit tests using +`net.corda.core.obfuscator.XorInputStream`. + +There is a dedicated Gradle task to re-download and re-obfuscate all the test resource files named `writeTestResources`, +its source code is in `core/src/obfuscator/kotlin/net/corda/core/internal/utilities/TestResourceWriter.kt` \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/crypto/Blake2s256DigestServiceTest.kt b/core/src/test/kotlin/net/corda/core/crypto/Blake2s256DigestServiceTest.kt index d2d51f3761..f0e9d9bee6 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/Blake2s256DigestServiceTest.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/Blake2s256DigestServiceTest.kt @@ -1,33 +1,19 @@ package net.corda.core.crypto import net.corda.core.crypto.internal.DigestAlgorithmFactory -import org.bouncycastle.crypto.digests.Blake2sDigest +import net.corda.core.internal.BLAKE2s256DigestAlgorithm import org.junit.Assert.assertArrayEquals import org.junit.Before import org.junit.Test import kotlin.test.assertEquals class Blake2s256DigestServiceTest { - class BLAKE2s256DigestService : DigestAlgorithm { - override val algorithm = "BLAKE_TEST" - - override val digestLength = 32 - - override fun digest(bytes: ByteArray): ByteArray { - val blake2s256 = Blake2sDigest(null, digestLength, null, "12345678".toByteArray()) - blake2s256.reset() - blake2s256.update(bytes, 0, bytes.size) - val hash = ByteArray(digestLength) - blake2s256.doFinal(hash, 0) - return hash - } - } private val service = DigestService("BLAKE_TEST") @Before fun before() { - DigestAlgorithmFactory.registerClass(BLAKE2s256DigestService::class.java.name) + DigestAlgorithmFactory.registerClass(BLAKE2s256DigestAlgorithm::class.java.name) } @Test(timeout = 300_000) diff --git a/core/src/test/kotlin/net/corda/core/internal/HashAgilityHelpers.kt b/core/src/test/kotlin/net/corda/core/internal/HashAgilityHelpers.kt new file mode 100644 index 0000000000..4795b05f75 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/internal/HashAgilityHelpers.kt @@ -0,0 +1,36 @@ +package net.corda.core.internal + +import net.corda.core.crypto.DigestAlgorithm +import net.corda.core.crypto.SecureHash +import org.bouncycastle.crypto.digests.Blake2sDigest + +/** + * A set of custom hash algorithms + */ + +open class BLAKE2s256DigestAlgorithm : DigestAlgorithm { + override val algorithm = "BLAKE_TEST" + + override val digestLength = 32 + + protected fun blake2sHash(bytes: ByteArray): ByteArray { + val blake2s256 = Blake2sDigest(null, digestLength, null, "12345678".toByteArray()) + blake2s256.reset() + blake2s256.update(bytes, 0, bytes.size) + val hash = ByteArray(digestLength) + blake2s256.doFinal(hash, 0) + return hash + } + + override fun digest(bytes: ByteArray): ByteArray = blake2sHash(bytes) +} + +class SHA256BLAKE2s256DigestAlgorithm : BLAKE2s256DigestAlgorithm() { + override val algorithm = "SHA256-BLAKE2S256-TEST" + + override fun digest(bytes: ByteArray): ByteArray = SecureHash.hashAs(SecureHash.SHA2_256, bytes).bytes + + override fun componentDigest(bytes: ByteArray): ByteArray = blake2sHash(bytes) + + override fun nonceDigest(bytes: ByteArray): ByteArray = blake2sHash(bytes) +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/internal/internalAccessTestHelpers.kt b/core/src/test/kotlin/net/corda/core/internal/internalAccessTestHelpers.kt index 16ce7444ad..16a6e6bef8 100644 --- a/core/src/test/kotlin/net/corda/core/internal/internalAccessTestHelpers.kt +++ b/core/src/test/kotlin/net/corda/core/internal/internalAccessTestHelpers.kt @@ -5,10 +5,12 @@ import net.corda.core.crypto.DigestService import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party import net.corda.core.node.NetworkParameters +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.internal.AttachmentsClassLoaderCache import net.corda.core.transactions.ComponentGroup import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.WireTransaction +import java.util.function.Supplier /** * A set of functions in core:test that allows testing of core internal classes in the core-tests project. @@ -18,6 +20,7 @@ fun WireTransaction.accessGroupHashes() = this.groupHashes fun WireTransaction.accessGroupMerkleRoots() = this.groupsMerkleRoots fun WireTransaction.accessAvailableComponentHashes() = this.availableComponentHashes +fun WireTransaction.accessAvailableComponentNonces() = this.availableComponentNonces @Suppress("LongParameterList") fun createLedgerTransaction( @@ -37,7 +40,17 @@ fun createLedgerTransaction( isAttachmentTrusted: (Attachment) -> Boolean, attachmentsClassLoaderCache: AttachmentsClassLoaderCache, digestService: DigestService = DigestService.default -): LedgerTransaction = LedgerTransaction.create(inputs, outputs, commands, attachments, id, notary, timeWindow, privacySalt, networkParameters, references, componentGroups, serializedInputs, serializedReferences, isAttachmentTrusted, attachmentsClassLoaderCache, digestService) +): LedgerTransaction = LedgerTransaction.create( + inputs, outputs, commands, attachments, id, notary, timeWindow, privacySalt, networkParameters, references, componentGroups, serializedInputs, serializedReferences, isAttachmentTrusted, attachmentsClassLoaderCache, digestService +).specialise(::PassthroughVerifier) fun createContractCreationError(txId: SecureHash, contractClass: String, cause: Throwable) = TransactionVerificationException.ContractCreationError(txId, contractClass, cause) fun createContractRejection(txId: SecureHash, contract: Contract, cause: Throwable) = TransactionVerificationException.ContractRejection(txId, contract, cause) + +/** + * Verify the [LedgerTransaction] we already have. + */ +private class PassthroughVerifier(ltx: LedgerTransaction, context: SerializationContext) : AbstractVerifier(ltx, context.deserializationClassLoader) { + override val transaction: Supplier + get() = Supplier { ltx } +} diff --git a/core/src/test/kotlin/net/corda/core/internal/utilities/ZipBombDetectorTest.kt b/core/src/test/kotlin/net/corda/core/internal/utilities/ZipBombDetectorTest.kt new file mode 100644 index 0000000000..131fb29736 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/internal/utilities/ZipBombDetectorTest.kt @@ -0,0 +1,70 @@ +package net.corda.core.internal.utilities + +import net.corda.core.obfuscator.XorInputStream +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(Parameterized::class) +class ZipBombDetectorTest(private val case : TestCase) { + + enum class TestCase( + val description : String, + val zipResource : String, + val maxUncompressedSize : Long, + val maxCompressionRatio : Float, + val expectedOutcome : Boolean + ) { + LEGIT_JAR("This project's jar file", "zip/core.jar", 128_000, 10f, false), + + // This is not detected as a zip bomb as ZipInputStream is unable to read all of its entries + // (https://stackoverflow.com/questions/69286786/zipinputstream-cannot-parse-a-281-tb-zip-bomb), + // so the total uncompressed size doesn't exceed maxUncompressedSize + SMALL_BOMB( + "A large (5.5 GB) zip archive", + "zip/zbsm.zip.xor", 64_000_000, 10f, false), + + // Decreasing maxUncompressedSize leads to a successful detection + SMALL_BOMB2( + "A large (5.5 GB) zip archive, with 1MB maxUncompressedSize", + "zip/zbsm.zip.xor", 1_000_000, 10f, true), + + // ZipInputStream is also unable to read all entries of zblg.zip, but since the first one is already bigger than 4GB, + // that is enough to exceed maxUncompressedSize + LARGE_BOMB( + "A huge (281 TB) Zip bomb, this is the biggest possible non-recursive non-Zip64 archive", + "zip/zblg.zip.xor", 64_000_000, 10f, true), + + //Same for this, but its entries are 22GB each + EXTRA_LARGE_BOMB( + "A humongous (4.5 PB) Zip64 bomb", + "zip/zbxl.zip.xor", 64_000_000, 10f, true), + + //This is a jar file containing a single 10GB manifest + BIG_MANIFEST( + "A jar file with a huge manifest", + "zip/big-manifest.jar.xor", 64_000_000, 10f, true); + + override fun toString() = description + } + + companion object { + @JvmStatic + @Parameterized.Parameters(name = "{0}") + fun generateTestCases(): Collection<*> { + return TestCase.values().toList() + } + } + + @Test(timeout=10_000) + fun test() { + (javaClass.classLoader.getResourceAsStream(case.zipResource) ?: + throw IllegalStateException("Missing test resource file ${case.zipResource}")) + .buffered() + .let(::XorInputStream) + .let { + Assert.assertEquals(case.expectedOutcome, ZipBombDetector.scanZip(it, case.maxUncompressedSize, case.maxCompressionRatio)) + } + } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/obfuscator/XorStreamTest.kt b/core/src/test/kotlin/net/corda/core/obfuscator/XorStreamTest.kt new file mode 100644 index 0000000000..9996a235a9 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/obfuscator/XorStreamTest.kt @@ -0,0 +1,50 @@ +package net.corda.core.obfuscator + +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.security.DigestInputStream +import java.security.DigestOutputStream +import java.security.MessageDigest +import java.util.* + +@RunWith(Parameterized::class) +class XorStreamTest(private val size : Int) { + private val random = Random(0) + + companion object { + @JvmStatic + @Parameterized.Parameters(name = "{0}") + fun generateTestCases(): Collection<*> { + return listOf(0, 16, 31, 127, 1000, 1024) + } + } + + @Test(timeout = 5000) + fun test() { + val baos = ByteArrayOutputStream(size) + val md = MessageDigest.getInstance("MD5") + val buffer = ByteArray(DEFAULT_BUFFER_SIZE) + DigestOutputStream(XorOutputStream(baos), md).use { outputStream -> + var written = 0 + while(written < size) { + random.nextBytes(buffer) + val bytesToWrite = (size - written).coerceAtMost(buffer.size) + outputStream.write(buffer, 0, bytesToWrite) + written += bytesToWrite + } + } + val digest = md.digest() + md.reset() + DigestInputStream(XorInputStream(ByteArrayInputStream(baos.toByteArray())), md).use { inputStream -> + while(true) { + val read = inputStream.read(buffer) + if(read <= 0) break + } + } + Assert.assertArrayEquals(digest, md.digest()) + } +} \ No newline at end of file diff --git a/core/src/test/resources/zip/big-manifest.jar.xor b/core/src/test/resources/zip/big-manifest.jar.xor new file mode 100644 index 0000000000..fc9e5f8fd2 Binary files /dev/null and b/core/src/test/resources/zip/big-manifest.jar.xor differ diff --git a/core/src/test/resources/zip/zblg.zip.xor b/core/src/test/resources/zip/zblg.zip.xor new file mode 100644 index 0000000000..dc6788a30c Binary files /dev/null and b/core/src/test/resources/zip/zblg.zip.xor differ diff --git a/core/src/test/resources/zip/zbsm.zip.xor b/core/src/test/resources/zip/zbsm.zip.xor new file mode 100644 index 0000000000..55c0f1b7d5 Binary files /dev/null and b/core/src/test/resources/zip/zbsm.zip.xor differ diff --git a/core/src/test/resources/zip/zbxl.zip.xor b/core/src/test/resources/zip/zbxl.zip.xor new file mode 100644 index 0000000000..168c7f2257 Binary files /dev/null and b/core/src/test/resources/zip/zbxl.zip.xor differ diff --git a/detekt-baseline.xml b/detekt-baseline.xml index 4398ca38b0..d6fdc13c8c 100644 --- a/detekt-baseline.xml +++ b/detekt-baseline.xml @@ -1162,6 +1162,7 @@ MatchingDeclarationName:NamedCache.kt$net.corda.core.internal.NamedCache.kt MatchingDeclarationName:NetParams.kt$net.corda.netparams.NetParams.kt MatchingDeclarationName:NetworkParametersServiceInternal.kt$net.corda.core.internal.NetworkParametersServiceInternal.kt + MatchingDeclarationName:NotaryQueries.kt$net.corda.nodeapi.notary.NotaryQueries.kt MatchingDeclarationName:OGSwapPricingCcpExample.kt$net.corda.vega.analytics.example.OGSwapPricingCcpExample.kt MatchingDeclarationName:OGSwapPricingExample.kt$net.corda.vega.analytics.example.OGSwapPricingExample.kt MatchingDeclarationName:PlatformSecureRandom.kt$net.corda.core.crypto.internal.PlatformSecureRandom.kt diff --git a/docker/src/bash/example-mini-network.sh b/docker/src/bash/example-mini-network.sh index 0f1e116d7a..9b08f0ee61 100755 --- a/docker/src/bash/example-mini-network.sh +++ b/docker/src/bash/example-mini-network.sh @@ -1,8 +1,8 @@ #!/usr/bin/env bash NODE_LIST=("dockerNode1" "dockerNode2" "dockerNode3") NETWORK_NAME=mininet -CORDAPP_VERSION="4.6-SNAPSHOT" -DOCKER_IMAGE_VERSION="corda-zulu-4.6-snapshot" +CORDAPP_VERSION="4.8-SNAPSHOT" +DOCKER_IMAGE_VERSION="corda-zulu-4.8-snapshot" mkdir cordapps rm -f cordapps/* diff --git a/docker/src/docker/Dockerfile b/docker/src/docker/Dockerfile index dd6b51db40..48a0e330ee 100644 --- a/docker/src/docker/Dockerfile +++ b/docker/src/docker/Dockerfile @@ -1,10 +1,11 @@ -FROM azul/zulu-openjdk:8u192 +FROM azul/zulu-openjdk:8u312 ## Remove Azul Zulu repo, as it is gone by now RUN rm -rf /etc/apt/sources.list.d/zulu.list ## Add packages, clean cache, create dirs, create corda user and change ownership RUN apt-get update && \ + apt-mark hold zulu8-jdk && \ apt-get -y upgrade && \ apt-get -y install bash curl unzip && \ rm -rf /var/lib/apt/lists/* && \ diff --git a/docker/src/docker/Dockerfile-debug b/docker/src/docker/Dockerfile-debug index 5b0c7bbb1f..a1175c989c 100644 --- a/docker/src/docker/Dockerfile-debug +++ b/docker/src/docker/Dockerfile-debug @@ -1,7 +1,8 @@ -FROM azul/zulu-openjdk:8u192 +FROM azul/zulu-openjdk:8u312 ## Add packages, clean cache, create dirs, create corda user and change ownership RUN apt-get update && \ + apt-mark hold zulu8-jdk && \ apt-get -y upgrade && \ apt-get -y install bash curl unzip netstat lsof telnet netcat && \ rm -rf /var/lib/apt/lists/* && \ diff --git a/docs/build.gradle b/docs/build.gradle index aa50560300..94fd4e6043 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -5,6 +5,10 @@ apply plugin: 'net.corda.plugins.publish-utils' apply plugin: 'maven-publish' apply plugin: 'com.jfrog.artifactory' +dependencies { + compile rootProject +} + def internalPackagePrefixes(sourceDirs) { def prefixes = [] // Kotlin allows packages to deviate from the directory structure, but let's assume they don't: @@ -36,10 +40,13 @@ task dokkaJavadoc(type: org.jetbrains.dokka.gradle.DokkaTask) { } [dokka, dokkaJavadoc].collect { - it.configure { + it.configuration { moduleName = 'corda' - processConfigurations = ['compile'] - sourceDirs = dokkaSourceDirs + dokkaSourceDirs.collect { sourceDir -> + sourceRoot { + path = sourceDir.path + } + } includes = ['packages.md'] jdkVersion = 8 externalDocumentationLink { @@ -52,7 +59,7 @@ task dokkaJavadoc(type: org.jetbrains.dokka.gradle.DokkaTask) { url = new URL("https://www.bouncycastle.org/docs/docs1.5on/") } internalPackagePrefixes.collect { packagePrefix -> - packageOptions { + perPackageOption { prefix = packagePrefix suppress = true } diff --git a/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt b/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt index 9ce80590a1..9affc6a0b1 100644 --- a/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt +++ b/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt @@ -21,14 +21,29 @@ import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.utilities.days import net.corda.core.utilities.hours -import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme +import net.corda.coretesting.internal.NettyTestClient +import net.corda.coretesting.internal.NettyTestHandler +import net.corda.coretesting.internal.NettyTestServer +import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.createDevNodeCa +import net.corda.nodeapi.internal.crypto.CertificateType +import net.corda.nodeapi.internal.crypto.X509CertificateFactory +import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME +import net.corda.nodeapi.internal.crypto.checkValidity +import net.corda.nodeapi.internal.crypto.getSupportedKey +import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore +import net.corda.nodeapi.internal.crypto.save +import net.corda.nodeapi.internal.crypto.toBc +import net.corda.nodeapi.internal.crypto.x509 +import net.corda.nodeapi.internal.crypto.x509Certificates import net.corda.nodeapi.internal.installDevNodeCaCertPath -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import net.corda.nodeapi.internal.registerDevP2pCertificates +import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.SerializationContextImpl import net.corda.serialization.internal.SerializationFactoryImpl @@ -37,25 +52,16 @@ import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.TestIdentity import net.corda.testing.driver.internal.incrementalPortAllocation -import net.corda.coretesting.internal.NettyTestClient -import net.corda.coretesting.internal.NettyTestHandler -import net.corda.coretesting.internal.NettyTestServer -import net.corda.testing.internal.createDevIntermediateCaCertPath -import net.corda.coretesting.internal.stubs.CertificateStoreStubs -import net.corda.nodeapi.internal.crypto.CertificateType -import net.corda.nodeapi.internal.crypto.X509CertificateFactory -import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.nodeapi.internal.crypto.checkValidity -import net.corda.nodeapi.internal.crypto.getSupportedKey -import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore -import net.corda.nodeapi.internal.crypto.save -import net.corda.nodeapi.internal.crypto.toBc -import net.corda.nodeapi.internal.crypto.x509 -import net.corda.nodeapi.internal.crypto.x509Certificates import net.corda.testing.internal.IS_OPENJ9 +import net.corda.testing.internal.createDevIntermediateCaCertPath import net.i2p.crypto.eddsa.EdDSAPrivateKey import org.assertj.core.api.Assertions.assertThat -import org.bouncycastle.asn1.x509.* +import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier +import org.bouncycastle.asn1.x509.BasicConstraints +import org.bouncycastle.asn1.x509.CRLDistPoint +import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.KeyUsage +import org.bouncycastle.asn1.x509.SubjectKeyIdentifier import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey import org.junit.Assume @@ -74,10 +80,19 @@ import java.security.PrivateKey import java.security.cert.CertPath import java.security.cert.X509Certificate import java.util.* -import javax.net.ssl.* +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLParameters +import javax.net.ssl.SSLServerSocket +import javax.net.ssl.SSLSocket import javax.security.auth.x500.X500Principal import kotlin.concurrent.thread -import kotlin.test.* +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.test.fail class X509UtilitiesTest { private companion object { @@ -295,15 +310,10 @@ class X509UtilitiesTest { sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.createTrustStore(rootCa.certificate) - val keyStore = sslConfig.keyStore.get() - val trustStore = sslConfig.trustStore.get() - val context = SSLContext.getInstance("TLS") - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get()) val trustManagers = trustMgrFactory.trustManagers context.init(keyManagers, trustManagers, newSecureRandom()) @@ -388,15 +398,8 @@ class X509UtilitiesTest { sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.createTrustStore(rootCa.certificate) - val keyStore = sslConfig.keyStore.get() - val trustStore = sslConfig.trustStore.get() - - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) - - val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustManagerFactory.init(trustStore) - + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) + val trustManagerFactory = trustManagerFactory(sslConfig.trustStore.get()) val sslServerContext = SslContextBuilder .forServer(keyManagerFactory) diff --git a/node-api/build.gradle b/node-api/build.gradle index b47bc040be..bb2280e23b 100644 --- a/node-api/build.gradle +++ b/node-api/build.gradle @@ -29,7 +29,7 @@ dependencies { // SQL connection pooling library compile "com.zaxxer:HikariCP:$hikari_version" - + // ClassGraph: classpath scanning compile "io.github.classgraph:classgraph:$class_graph_version" @@ -54,6 +54,9 @@ dependencies { testRuntimeOnly "org.junit.vintage:junit-vintage-engine:${junit_vintage_version}" testRuntimeOnly "org.junit.jupiter:junit-jupiter-engine:${junit_jupiter_version}" testRuntimeOnly "org.junit.platform:junit-platform-launcher:${junit_platform_version}" + + testCompile project(':node-driver') + // Unit testing helpers. testCompile "org.assertj:assertj-core:$assertj_version" testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version" diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt index 1206cbe8ec..74d580a827 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt @@ -5,7 +5,6 @@ import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.loggerFor import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_P2P_USER import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pConnectorTcpTransport -import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pConnectorTcpTransportFromList import net.corda.nodeapi.internal.config.MessagingServerConnectionConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration import org.apache.activemq.artemis.api.core.client.* @@ -25,7 +24,9 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration, private val confirmationWindowSize: Int = -1, private val messagingServerConnectionConfig: MessagingServerConnectionConfiguration? = null, private val backupServerAddressPool: List = emptyList(), - private val failoverCallback: ((FailoverEventType) -> Unit)? = null + private val failoverCallback: ((FailoverEventType) -> Unit)? = null, + private val threadPoolName: String = "ArtemisClient", + private val trace: Boolean = false ) : ArtemisSessionProvider { companion object { private val log = loggerFor() @@ -40,8 +41,10 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration, override fun start(): Started = synchronized(this) { check(started == null) { "start can't be called twice" } - val tcpTransport = p2pConnectorTcpTransport(serverAddress, config) - val backupTransports = p2pConnectorTcpTransportFromList(backupServerAddressPool, config) + val tcpTransport = p2pConnectorTcpTransport(serverAddress, config, threadPoolName = threadPoolName, trace = trace) + val backupTransports = backupServerAddressPool.mapIndexed { index, address -> + p2pConnectorTcpTransport(address, config, threadPoolName = "$threadPoolName-backup${index+1}", trace = trace) + } log.info("Connecting to message broker: $serverAddress") if (backupTransports.isNotEmpty()) { @@ -50,8 +53,6 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration, // If back-up artemis addresses are configured, the locator will be created using HA mode. @Suppress("SpreadOperator") val locator = ActiveMQClient.createServerLocator(backupTransports.isNotEmpty(), *(listOf(tcpTransport) + backupTransports).toTypedArray()).apply { - // Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this - // would be the default and the two lines below can be deleted. connectionTTL = 60000 clientFailureCheckPeriod = 30000 callFailoverTimeout = java.lang.Long.getLong(CORDA_ARTEMIS_CALL_TIMEOUT_PROP_NAME, CORDA_ARTEMIS_CALL_TIMEOUT_DEFAULT) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt index d3122c9dc8..84d63df5e2 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt @@ -1,19 +1,20 @@ +@file:Suppress("LongParameterList") + package net.corda.nodeapi.internal import net.corda.core.messaging.ClientRpcSslOptions import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.BrokerRpcSslOptions -import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier +import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.SslConfiguration +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.apache.activemq.artemis.api.core.TransportConfiguration -import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnectorFactory import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants -import java.nio.file.Path +import javax.net.ssl.TrustManagerFactory -// This avoids internal types from leaking in the public API. The "external" ArtemisTcpTransport delegates to this internal one. +@Suppress("LongParameterList") class ArtemisTcpTransport { companion object { val CIPHER_SUITES = listOf( @@ -23,65 +24,52 @@ class ArtemisTcpTransport { val TLS_VERSIONS = listOf("TLSv1.2") - internal fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort) = mapOf( + const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout" + const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory" + const val TRACE_NAME = "Corda-Trace" + const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName" + + // Turn on AMQP support, which needs the protocol jar on the classpath. + // Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop. + // It does not use AMQP messages for its own messages e.g. topology and heartbeats. + private const val P2P_PROTOCOLS = "CORE,AMQP" + private const val RPC_PROTOCOLS = "CORE" + + private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf( // Basic TCP target details. TransportConstants.HOST_PROP_NAME to hostAndPort.host, TransportConstants.PORT_PROP_NAME to hostAndPort.port, - - // Turn on AMQP support, which needs the protocol jar on the classpath. - // Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop. - // It does not use AMQP messages for its own messages e.g. topology and heartbeats. - // TODO further investigate how to ensure we use a well defined wire level protocol for Node to Node communications. - TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP", + TransportConstants.PROTOCOLS_PROP_NAME to protocols, TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null), - TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1), // turn off direct delivery in Artemis - this is latency optimisation that can lead to //hick-ups under high load (CORDA-1336) TransportConstants.DIRECT_DELIVER to false) - internal val defaultSSLOptions = mapOf( - TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","), - TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(",")) - - private fun SslConfiguration.toTransportOptions(): Map { - - val options = mutableMapOf() - (keyStore to trustStore).addToTransportOptions(options) - return options - } - - private fun Pair.addToTransportOptions(options: MutableMap) { - - val keyStore = first - val trustStore = second + private fun SslConfiguration.addToTransportOptions(options: MutableMap) { + if (keyStore != null || trustStore != null) { + options[TransportConstants.SSL_ENABLED_PROP_NAME] = true + options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true + } keyStore?.let { with (it) { path.requireOnDefaultFileSystem() - options.putAll(get().toKeyStoreTransportOptions(path)) + options[TransportConstants.KEYSTORE_PROVIDER_PROP_NAME] = "JKS" + options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path + options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password } } trustStore?.let { with (it) { path.requireOnDefaultFileSystem() - options.putAll(get().toTrustStoreTransportOptions(path)) + options[TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME] = "JKS" + options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path + options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password } } + options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER + options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT } - private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf( - TransportConstants.SSL_ENABLED_PROP_NAME to true, - TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to "JKS", - TransportConstants.KEYSTORE_PATH_PROP_NAME to path, - TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password, - TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true) - - private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf( - TransportConstants.SSL_ENABLED_PROP_NAME to true, - TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to "JKS", - TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path, - TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password, - TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true) - private fun ClientRpcSslOptions.toTransportOptions() = mapOf( TransportConstants.SSL_ENABLED_PROP_NAME to true, TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to trustStoreProvider, @@ -95,86 +83,164 @@ class ArtemisTcpTransport { TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to keyStorePassword, TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to false) - internal val acceptorFactoryClassName = "org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptorFactory" - internal val connectorFactoryClassName = NettyConnectorFactory::class.java.name - - fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true): TransportConfiguration { - - return p2pAcceptorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false) - } - - fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true, keyStoreProvider: String? = null): TransportConfiguration { - - return p2pConnectorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false, keyStoreProvider = keyStoreProvider) - } - - fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false): TransportConfiguration { - - val options = defaultArtemisOptions(hostAndPort).toMutableMap() + fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, + config: MutualSslConfiguration?, + trustManagerFactory: TrustManagerFactory? = config?.trustStore?.get()?.let(::trustManagerFactory), + enableSSL: Boolean = true, + threadPoolName: String = "P2PServer", + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { + val options = mutableMapOf() if (enableSSL) { - options.putAll(defaultSSLOptions) - (keyStore to trustStore).addToTransportOptions(options) - options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER + config?.addToTransportOptions(options) } - options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections - return TransportConfiguration(acceptorFactoryClassName, options) + return createAcceptorTransport( + hostAndPort, + P2P_PROTOCOLS, + options, + trustManagerFactory, + enableSSL, + threadPoolName, + trace, + remotingThreads + ) } - @Suppress("LongParameterList") - fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false, keyStoreProvider: String? = null): TransportConfiguration { - - val options = defaultArtemisOptions(hostAndPort).toMutableMap() + fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, + config: MutualSslConfiguration?, + enableSSL: Boolean = true, + threadPoolName: String = "P2PClient", + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { + val options = mutableMapOf() if (enableSSL) { - options.putAll(defaultSSLOptions) - (keyStore to trustStore).addToTransportOptions(options) - options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER - keyStoreProvider?.let { options.put(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME, keyStoreProvider) } + config?.addToTransportOptions(options) } - return TransportConfiguration(connectorFactoryClassName, options) + return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads) } - fun p2pConnectorTcpTransportFromList(hostAndPortList: List, config: MutualSslConfiguration?, enableSSL: Boolean = true, keyStoreProvider: String? = null): List = hostAndPortList.map { - p2pConnectorTcpTransport(it, config, enableSSL, keyStoreProvider) - } - - fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: BrokerRpcSslOptions?, enableSSL: Boolean = true): TransportConfiguration { - val options = defaultArtemisOptions(hostAndPort).toMutableMap() - + fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, + config: BrokerRpcSslOptions?, + enableSSL: Boolean = true, + threadPoolName: String = "RPCServer", + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { + val options = mutableMapOf() if (config != null && enableSSL) { config.keyStorePath.requireOnDefaultFileSystem() options.putAll(config.toTransportOptions()) - options.putAll(defaultSSLOptions) } - options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections - return TransportConfiguration(acceptorFactoryClassName, options) + return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, threadPoolName, trace, remotingThreads) } - fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: ClientRpcSslOptions?, enableSSL: Boolean = true): TransportConfiguration { - val options = defaultArtemisOptions(hostAndPort).toMutableMap() - + fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort, + config: ClientRpcSslOptions?, + enableSSL: Boolean = true, + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { + val options = mutableMapOf() if (config != null && enableSSL) { config.trustStorePath.requireOnDefaultFileSystem() options.putAll(config.toTransportOptions()) - options.putAll(defaultSSLOptions) } - return TransportConfiguration(connectorFactoryClassName, options) + return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads) } - fun rpcConnectorTcpTransportsFromList(hostAndPortList: List, config: ClientRpcSslOptions?, enableSSL: Boolean = true): List = hostAndPortList.map { - rpcConnectorTcpTransport(it, config, enableSSL) + fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, + config: SslConfiguration, + threadPoolName: String = "Internal-RPCClient", + trace: Boolean = false): TransportConfiguration { + val options = mutableMapOf() + config.addToTransportOptions(options) + return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, threadPoolName, trace, null) } - fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, keyStoreProvider: String? = null): TransportConfiguration { - return TransportConfiguration(connectorFactoryClassName, defaultArtemisOptions(hostAndPort) + defaultSSLOptions + config.toTransportOptions() + asMap(keyStoreProvider)) + fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, + config: SslConfiguration, + threadPoolName: String = "Internal-RPCServer", + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { + val options = mutableMapOf() + config.addToTransportOptions(options) + return createAcceptorTransport( + hostAndPort, + RPC_PROTOCOLS, + options, + trustManagerFactory(requireNotNull(config.trustStore).get()), + true, + threadPoolName, + trace, + remotingThreads + ) } - fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, keyStoreProvider: String? = null): TransportConfiguration { - return TransportConfiguration(acceptorFactoryClassName, defaultArtemisOptions(hostAndPort) + defaultSSLOptions + - config.toTransportOptions() + (TransportConstants.HANDSHAKE_TIMEOUT to 0) + asMap(keyStoreProvider)) + private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort, + protocols: String, + options: MutableMap, + trustManagerFactory: TrustManagerFactory?, + enableSSL: Boolean, + threadPoolName: String, + trace: Boolean, + remotingThreads: Int?): TransportConfiguration { + // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections + options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 + if (trustManagerFactory != null) { + // NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use + // more customised instances which use our revocation checkers, which we pass directly into NodeNettyAcceptorFactory. + // + // This, however, requires copying a lot of code from NettyAcceptor into NodeNettyAcceptor. The version of Artemis in + // Corda 4.9 solves this problem by introducing a "trustManagerFactoryPlugin" config option. + options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory + } + return createTransport( + "net.corda.node.services.messaging.NodeNettyAcceptorFactory", + hostAndPort, + protocols, + options, + enableSSL, + threadPoolName, + trace, + remotingThreads + ) } - private fun asMap(keyStoreProvider: String?): Map { - return keyStoreProvider?.let {mutableMapOf(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to it)} ?: emptyMap() + private fun createConnectorTransport(hostAndPort: NetworkHostAndPort, + protocols: String, + options: MutableMap, + enableSSL: Boolean, + threadPoolName: String, + trace: Boolean, + remotingThreads: Int?): TransportConfiguration { + return createTransport( + CordaNettyConnectorFactory::class.java.name, + hostAndPort, + protocols, + options, + enableSSL, + threadPoolName, + trace, + remotingThreads + ) + } + + private fun createTransport(className: String, + hostAndPort: NetworkHostAndPort, + protocols: String, + options: MutableMap, + enableSSL: Boolean, + threadPoolName: String, + trace: Boolean, + remotingThreads: Int?): TransportConfiguration { + options += defaultArtemisOptions(hostAndPort, protocols) + if (enableSSL) { + options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",") + options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",") + } + // By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357) + options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1 + options[THREAD_POOL_NAME_NAME] = threadPoolName + options[TRACE_NAME] = trace + return TransportConfiguration(className, options) } } -} \ No newline at end of file +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt index 23bb9d1428..a3c2109d32 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt @@ -1,8 +1,14 @@ @file:JvmName("ArtemisUtils") package net.corda.nodeapi.internal +import net.corda.core.internal.declaredField +import org.apache.activemq.artemis.utils.actors.ProcessorBase import java.nio.file.FileSystems import java.nio.file.Path +import java.util.concurrent.Executor +import java.util.concurrent.ThreadFactory +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.atomic.AtomicInteger /** * Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use. @@ -16,3 +22,29 @@ fun requireMessageSize(messageSize: Int, limit: Int) { require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" } } +val Executor.rootExecutor: Executor get() { + var executor: Executor = this + while (executor is ProcessorBase<*>) { + executor = executor.declaredField("delegate").value + } + return executor +} + +fun Executor.setThreadPoolName(threadPoolName: String) { + (rootExecutor as? ThreadPoolExecutor)?.let { it.threadFactory = NamedThreadFactory(threadPoolName, it.threadFactory) } +} + +private class NamedThreadFactory(poolName: String, private val delegate: ThreadFactory) : ThreadFactory { + companion object { + private val poolId = AtomicInteger(0) + } + + private val prefix = "$poolName-${poolId.incrementAndGet()}-" + private val nextId = AtomicInteger(0) + + override fun newThread(r: Runnable): Thread { + val thread = delegate.newThread(r) + thread.name = "$prefix${nextId.incrementAndGet()}" + return thread + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/CordaNettyConnectorFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/CordaNettyConnectorFactory.kt new file mode 100644 index 0000000000..a9bdc519a9 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/CordaNettyConnectorFactory.kt @@ -0,0 +1,73 @@ +package net.corda.nodeapi.internal + +import io.netty.channel.ChannelPipeline +import io.netty.handler.logging.LogLevel +import io.netty.handler.logging.LoggingHandler +import org.apache.activemq.artemis.core.protocol.core.impl.ActiveMQClientProtocolManager +import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnector +import org.apache.activemq.artemis.spi.core.remoting.BufferHandler +import org.apache.activemq.artemis.spi.core.remoting.ClientConnectionLifeCycleListener +import org.apache.activemq.artemis.spi.core.remoting.ClientProtocolManager +import org.apache.activemq.artemis.spi.core.remoting.Connector +import org.apache.activemq.artemis.spi.core.remoting.ConnectorFactory +import org.apache.activemq.artemis.utils.ConfigurationHelper +import java.util.concurrent.Executor +import java.util.concurrent.ScheduledExecutorService + +class CordaNettyConnectorFactory : ConnectorFactory { + override fun createConnector(configuration: MutableMap?, + handler: BufferHandler?, + listener: ClientConnectionLifeCycleListener?, + closeExecutor: Executor, + threadPool: Executor, + scheduledThreadPool: ScheduledExecutorService, + protocolManager: ClientProtocolManager?): Connector { + val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Connector", configuration) + setThreadPoolName(threadPool, closeExecutor, scheduledThreadPool, threadPoolName) + val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) + return NettyConnector( + configuration, + handler, + listener, + closeExecutor, + threadPool, + scheduledThreadPool, + MyClientProtocolManager("$threadPoolName-netty", trace) + ) + } + + override fun isReliable(): Boolean = false + + override fun getDefaults(): Map = NettyConnector.DEFAULT_CONFIG + + private fun setThreadPoolName(threadPool: Executor, closeExecutor: Executor, scheduledThreadPool: ScheduledExecutorService, name: String) { + threadPool.setThreadPoolName("$name-artemis") + // Artemis will actually wrap the same backing Executor to create multiple "OrderedExecutors". In this scenerio both the threadPool + // and the closeExecutor are the same when it comes to the pool names. If however they are different then given them separate names. + if (threadPool.rootExecutor !== closeExecutor.rootExecutor) { + closeExecutor.setThreadPoolName("$name-artemis-closer") + } + // The scheduler is separate + scheduledThreadPool.setThreadPoolName("$name-artemis-scheduler") + } + + + private class MyClientProtocolManager(private val threadPoolName: String, private val trace: Boolean) : ActiveMQClientProtocolManager() { + override fun addChannelHandlers(pipeline: ChannelPipeline) { + applyThreadPoolName() + super.addChannelHandlers(pipeline) + if (trace) { + pipeline.addLast("logger", LoggingHandler(LogLevel.INFO)) + } + } + + /** + * [NettyConnector.start] does not provide a way to configure the thread pool name, so we modify the thread name accordingly. + */ + private fun applyThreadPoolName() { + with(Thread.currentThread()) { + name = name.replace("nioEventLoopGroup", threadPoolName) // pool and thread numbers are preserved + } + } + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt new file mode 100644 index 0000000000..65d60ab38d --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt @@ -0,0 +1,32 @@ +@file:Suppress("LongParameterList", "MagicNumber") + +package net.corda.nodeapi.internal + +import io.netty.util.concurrent.DefaultThreadFactory +import net.corda.core.utilities.seconds +import java.time.Duration +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.TimeUnit + +/** + * Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0 + * threads. + */ +fun namedThreadPoolExecutor(maxPoolSize: Int, + corePoolSize: Int = 0, + idleKeepAlive: Duration = 30.seconds, + workQueue: BlockingQueue = LinkedBlockingQueue(), + poolName: String = "pool", + daemonThreads: Boolean = false, + threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor { + return ThreadPoolExecutor( + corePoolSize, + maxPoolSize, + idleKeepAlive.toNanos(), + TimeUnit.NANOSECONDS, + workQueue, + DefaultThreadFactory(poolName, daemonThreads, threadPriority) + ) +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt index 40523033f2..93ab5616de 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt @@ -5,22 +5,24 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder import io.netty.channel.EventLoop import io.netty.channel.EventLoopGroup import io.netty.channel.nio.NioEventLoopGroup +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.core.identity.CordaX500Name import net.corda.core.internal.VisibleForTesting import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger +import net.corda.nodeapi.internal.ArtemisConstants.MESSAGE_ID_KEY import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_P2P_USER import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders import net.corda.nodeapi.internal.ArtemisMessagingComponent.RemoteInboxAddress.Companion.translateLocalQueueToInboxAddress import net.corda.nodeapi.internal.ArtemisSessionProvider -import net.corda.nodeapi.internal.ArtemisConstants.MESSAGE_ID_KEY import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE @@ -29,6 +31,8 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage import org.apache.activemq.artemis.api.core.client.ClientSession import org.slf4j.MDC import rx.Subscription +import java.time.Duration +import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledFuture @@ -51,10 +55,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore, maxMessageSize: Int, revocationConfig: RevocationConfig, enableSNI: Boolean, - private val artemisMessageClientFactory: () -> ArtemisSessionProvider, + private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider, private val bridgeMetricsService: BridgeMetricsService? = null, trace: Boolean, - sslHandshakeTimeout: Long?, + sslHandshakeTimeout: Duration?, private val bridgeConnectionTTLSeconds: Int) : BridgeManager { private val lock = ReentrantLock() @@ -69,16 +73,18 @@ open class AMQPBridgeManager(keyStore: CertificateStore, override val enableSNI: Boolean, override val sourceX500Name: String? = null, override val trace: Boolean, - private val _sslHandshakeTimeout: Long?) : AMQPConfiguration { - override val sslHandshakeTimeout: Long + private val _sslHandshakeTimeout: Duration?) : AMQPConfiguration { + override val sslHandshakeTimeout: Duration get() = _sslHandshakeTimeout ?: super.sslHandshakeTimeout } private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig,useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout) private var sharedEventLoopGroup: EventLoopGroup? = null + private var sslDelegatedTaskExecutor: ExecutorService? = null private var artemis: ArtemisSessionProvider? = null companion object { + private val log = contextLogger() private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads" @@ -95,18 +101,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore, * however Artemis and the remote Corda instanced will deduplicate these messages. */ @Suppress("TooManyFunctions") - private class AMQPBridge(val sourceX500Name: String, - val queueName: String, - val targets: List, - val legalNames: Set, - private val amqpConfig: AMQPConfiguration, - sharedEventGroup: EventLoopGroup, - private val artemis: ArtemisSessionProvider, - private val bridgeMetricsService: BridgeMetricsService?, - private val bridgeConnectionTTLSeconds: Int) { - companion object { - private val log = contextLogger() - } + private inner class AMQPBridge(val sourceX500Name: String, + val queueName: String, + val targets: List, + val allowedRemoteLegalNames: Set, + private val amqpConfig: AMQPConfiguration) { private fun withMDC(block: () -> Unit) { val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap() @@ -114,7 +113,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, MDC.put("queueName", queueName) MDC.put("source", amqpConfig.sourceX500Name) MDC.put("targets", targets.joinToString(separator = ";") { it.toString() }) - MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() }) + MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() }) MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString()) block() } finally { @@ -132,13 +131,18 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) } - val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup) + val amqpClient = AMQPClient( + targets, + allowedRemoteLegalNames, + amqpConfig, + AMQPClient.NettyThreading.Shared(sharedEventLoopGroup!!, sslDelegatedTaskExecutor!!) + ) private var session: ClientSession? = null private var consumer: ClientConsumer? = null private var connectedSubscription: Subscription? = null @Volatile private var messagesReceived: Boolean = false - private val eventLoop: EventLoop = sharedEventGroup.next() + private val eventLoop: EventLoop = sharedEventLoopGroup!!.next() private var artemisState: ArtemisState = ArtemisState.STOPPED set(value) { logDebugWithMDC { "State change $field to $value" } @@ -150,32 +154,9 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private var scheduledExecutorService: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build()) - @Suppress("ClassNaming") - private sealed class ArtemisState { - object STARTING : ArtemisState() - data class STARTED(override val pending: ScheduledFuture) : ArtemisState() - - object CHECKING : ArtemisState() - object RESTARTED : ArtemisState() - object RECEIVING : ArtemisState() - - object AMQP_STOPPED : ArtemisState() - object AMQP_STARTING : ArtemisState() - object AMQP_STARTED : ArtemisState() - object AMQP_RESTARTED : ArtemisState() - - object STOPPING : ArtemisState() - object STOPPED : ArtemisState() - data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture) : ArtemisState() - - open val pending: ScheduledFuture? = null - - override fun toString(): String = javaClass.simpleName - } - private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) { val runnable = { - synchronized(artemis) { + synchronized(artemis!!) { try { val precedingState = artemisState artemisState.pending?.cancel(false) @@ -229,7 +210,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } ArtemisState.STOPPING } - bridgeMetricsService?.bridgeDisconnected(targets, legalNames) + bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames) connectedSubscription?.unsubscribe() connectedSubscription = null // Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first. @@ -241,7 +222,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, if (connected) { logInfoWithMDC("Bridge Connected") - bridgeMetricsService?.bridgeConnected(targets, legalNames) + bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames) if (bridgeConnectionTTLSeconds > 0) { // AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS, @@ -251,7 +232,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } } artemis(ArtemisState.STARTING) { - val startedArtemis = artemis.started + val startedArtemis = artemis!!.started if (startedArtemis == null) { logInfoWithMDC("Bridge Connected but Artemis is disconnected") ArtemisState.STOPPED @@ -284,7 +265,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, logInfoWithMDC("Bridge Disconnected") amqpRestartEvent?.cancel(false) if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) { - bridgeMetricsService?.bridgeDisconnected(targets, legalNames) + bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames) } artemis(ArtemisState.STOPPING) { precedingState: ArtemisState -> logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected") @@ -416,10 +397,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore, properties[key] = value } } - logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" } + logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" } val peerInbox = translateLocalQueueToInboxAddress(queueName) val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox, - legalNames.first().toString(), + allowedRemoteLegalNames.first().toString(), properties) sendableMessage.onComplete.then { logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" } @@ -455,6 +436,29 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } } + @Suppress("ClassNaming") + private sealed class ArtemisState { + object STARTING : ArtemisState() + data class STARTED(override val pending: ScheduledFuture) : ArtemisState() + + object CHECKING : ArtemisState() + object RESTARTED : ArtemisState() + object RECEIVING : ArtemisState() + + object AMQP_STOPPED : ArtemisState() + object AMQP_STARTING : ArtemisState() + object AMQP_STARTED : ArtemisState() + object AMQP_RESTARTED : ArtemisState() + + object STOPPING : ArtemisState() + object STOPPED : ArtemisState() + data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture) : ArtemisState() + + open val pending: ScheduledFuture? = null + + override fun toString(): String = javaClass.simpleName + } + override fun deployBridge(sourceX500Name: String, queueName: String, targets: List, legalNames: Set) { lock.withLock { val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() } @@ -465,8 +469,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) } - val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig, sharedEventLoopGroup!!, artemis!!, - bridgeMetricsService, bridgeConnectionTTLSeconds) + val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig) bridges += newBridge bridgeMetricsService?.bridgeCreated(targets, legalNames) newBridge @@ -484,7 +487,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, queueNamesToBridgesMap.remove(queueName) } bridge.stop() - bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames) + bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames) } } } @@ -495,15 +498,16 @@ open class AMQPBridgeManager(keyStore: CertificateStore, // queueNamesToBridgesMap returns a mutable list, .toList converts it to a immutable list so it won't be changed by the [destroyBridge] method. val bridges = queueNamesToBridgesMap[queueName]?.toList() destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList()) - bridges?.map { - it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false) - }?.toMap() ?: emptyMap() + bridges?.associate { + it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false) + } ?: emptyMap() } } override fun start() { - sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS) - val artemis = artemisMessageClientFactory() + sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("NettyBridge", Thread.MAX_PRIORITY)) + sslDelegatedTaskExecutor = sslDelegatedTaskExecutor("NettyBridge") + val artemis = artemisMessageClientFactory("ArtemisBridge") this.artemis = artemis artemis.start() } @@ -520,6 +524,8 @@ open class AMQPBridgeManager(keyStore: CertificateStore, sharedEventLoopGroup = null queueNamesToBridgesMap.clear() artemis?.stop() + sslDelegatedTaskExecutor?.shutdown() + sslDelegatedTaskExecutor = null } } } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt index 2a37649667..708588cb63 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt @@ -5,16 +5,13 @@ import net.corda.core.identity.CordaX500Name import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger -import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_CONTROL import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_NOTIFY import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEERS_PREFIX import net.corda.nodeapi.internal.ArtemisSessionProvider import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.crypto.x509 import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig @@ -27,6 +24,7 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage import org.apache.activemq.artemis.api.core.client.ClientSession import rx.Observable import rx.subjects.PublishSubject +import java.time.Duration import java.util.* class BridgeControlListener(private val keyStore: CertificateStore, @@ -36,10 +34,10 @@ class BridgeControlListener(private val keyStore: CertificateStore, maxMessageSize: Int, revocationConfig: RevocationConfig, enableSNI: Boolean, - private val artemisMessageClientFactory: () -> ArtemisSessionProvider, + private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider, bridgeMetricsService: BridgeMetricsService? = null, trace: Boolean = false, - sslHandshakeTimeout: Long? = null, + sslHandshakeTimeout: Duration? = null, bridgeConnectionTTLSeconds: Int = 0) : AutoCloseable { private val bridgeId: String = UUID.randomUUID().toString() private var bridgeControlQueue = "$BRIDGE_CONTROL.$bridgeId" @@ -57,13 +55,6 @@ class BridgeControlListener(private val keyStore: CertificateStore, private var controlConsumer: ClientConsumer? = null private var notifyConsumer: ClientConsumer? = null - constructor(config: MutualSslConfiguration, - p2pAddress: NetworkHostAndPort, - maxMessageSize: Int, - revocationConfig: RevocationConfig, - enableSNI: Boolean, - proxy: ProxyConfig? = null) : this(config.keyStore.get(), config.trustStore.get(), config.useOpenSsl, proxy, maxMessageSize, revocationConfig, enableSNI, { ArtemisMessagingClient(config, p2pAddress, maxMessageSize) }) - companion object { private val log = contextLogger() } @@ -88,7 +79,7 @@ class BridgeControlListener(private val keyStore: CertificateStore, bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId" bridgeManager.start() - val artemis = artemisMessageClientFactory() + val artemis = artemisMessageClientFactory("BridgeControl") this.artemis = artemis artemis.start() val artemisClient = artemis.started!! diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt index fc27029584..2dd9f8bff0 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt @@ -23,6 +23,7 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage import org.apache.activemq.artemis.api.core.client.ClientProducer import org.apache.activemq.artemis.api.core.client.ClientSession import org.slf4j.MDC +import java.time.Duration /** * The LoopbackBridgeManager holds the list of independent LoopbackBridge objects that actively loopback messages to local Artemis @@ -36,11 +37,11 @@ class LoopbackBridgeManager(keyStore: CertificateStore, maxMessageSize: Int, revocationConfig: RevocationConfig, enableSNI: Boolean, - private val artemisMessageClientFactory: () -> ArtemisSessionProvider, + private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider, private val bridgeMetricsService: BridgeMetricsService? = null, private val isLocalInbox: (String) -> Boolean, trace: Boolean, - sslHandshakeTimeout: Long? = null, + sslHandshakeTimeout: Duration? = null, bridgeConnectionTTLSeconds: Int = 0) : AMQPBridgeManager(keyStore, trustStore, useOpenSSL, proxyConfig, maxMessageSize, revocationConfig, enableSNI, artemisMessageClientFactory, bridgeMetricsService, @@ -203,7 +204,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore, override fun start() { super.start() - val artemis = artemisMessageClientFactory() + val artemis = artemisMessageClientFactory("LoopbackBridge") this.artemis = artemis artemis.start() } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/config/SslConfiguration.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/config/SslConfiguration.kt index e4433b4e00..fdb8e9aea0 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/config/SslConfiguration.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/config/SslConfiguration.kt @@ -1,16 +1,20 @@ package net.corda.nodeapi.internal.config +import net.corda.core.utilities.seconds +import java.time.Duration + interface SslConfiguration { val keyStore: FileBasedCertificateStoreSupplier? val trustStore: FileBasedCertificateStoreSupplier? val useOpenSsl: Boolean + val handshakeTimeout: Duration? companion object { - - fun mutual(keyStore: FileBasedCertificateStoreSupplier, trustStore: FileBasedCertificateStoreSupplier): MutualSslConfiguration { - - return MutualSslOptions(keyStore, trustStore) + fun mutual(keyStore: FileBasedCertificateStoreSupplier, + trustStore: FileBasedCertificateStoreSupplier, + handshakeTimeout: Duration? = null): MutualSslConfiguration { + return MutualSslOptions(keyStore, trustStore, handshakeTimeout) } } } @@ -21,9 +25,10 @@ interface MutualSslConfiguration : SslConfiguration { } private class MutualSslOptions(override val keyStore: FileBasedCertificateStoreSupplier, - override val trustStore: FileBasedCertificateStoreSupplier) : MutualSslConfiguration { + override val trustStore: FileBasedCertificateStoreSupplier, + override val handshakeTimeout: Duration?) : MutualSslConfiguration { override val useOpenSsl: Boolean = false } -const val DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS = 60000L // Set at least 3 times higher than sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT which is 15 sec - +@Suppress("MagicNumber") +val DEFAULT_SSL_HANDSHAKE_TIMEOUT: Duration = 60.seconds // Set at least 3 times higher than sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT which is 15 sec diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt index aff65d7987..79ae834a16 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt @@ -1,16 +1,41 @@ +@file:Suppress("MagicNumber", "TooGenericExceptionCaught") + package net.corda.nodeapi.internal.crypto import net.corda.core.CordaOID import net.corda.core.crypto.Crypto import net.corda.core.crypto.newSecureRandom -import net.corda.core.internal.* +import net.corda.core.internal.CertRole +import net.corda.core.internal.SignedDataWithCert +import net.corda.core.internal.reader +import net.corda.core.internal.signWithCert +import net.corda.core.internal.uncheckedCast +import net.corda.core.internal.validate +import net.corda.core.internal.writer import net.corda.core.utilities.days import net.corda.core.utilities.millis -import org.bouncycastle.asn1.* +import net.corda.core.utilities.toHex +import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString +import org.bouncycastle.asn1.ASN1EncodableVector +import org.bouncycastle.asn1.ASN1ObjectIdentifier +import org.bouncycastle.asn1.ASN1Sequence +import org.bouncycastle.asn1.DERSequence +import org.bouncycastle.asn1.DERUTF8String import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.style.BCStyle -import org.bouncycastle.asn1.x509.* +import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier +import org.bouncycastle.asn1.x509.BasicConstraints +import org.bouncycastle.asn1.x509.CRLDistPoint +import org.bouncycastle.asn1.x509.DistributionPoint +import org.bouncycastle.asn1.x509.DistributionPointName import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.GeneralName +import org.bouncycastle.asn1.x509.GeneralNames +import org.bouncycastle.asn1.x509.KeyPurposeId +import org.bouncycastle.asn1.x509.KeyUsage +import org.bouncycastle.asn1.x509.NameConstraints +import org.bouncycastle.asn1.x509.SubjectKeyIdentifier +import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo import org.bouncycastle.cert.X509CertificateHolder import org.bouncycastle.cert.X509v3CertificateBuilder import org.bouncycastle.cert.bc.BcX509ExtensionUtils @@ -28,8 +53,13 @@ import java.nio.file.Path import java.security.KeyPair import java.security.PublicKey import java.security.SignatureException -import java.security.cert.* +import java.security.cert.CertPath import java.security.cert.Certificate +import java.security.cert.CertificateException +import java.security.cert.CertificateFactory +import java.security.cert.TrustAnchor +import java.security.cert.X509CRL +import java.security.cert.X509Certificate import java.time.Duration import java.time.Instant import java.time.temporal.ChronoUnit @@ -355,7 +385,7 @@ object X509Utilities { private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) { if (crlDistPoint != null) { - val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) + val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier)) val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(crlIssuer)) } @@ -368,7 +398,6 @@ object X509Utilities { } } - @Suppress("MagicNumber") private fun generateCertificateSerialNumber(): BigInteger { val bytes = ByteArray(CERTIFICATE_SERIAL_NUMBER_LENGTH) newSecureRandom().nextBytes(bytes) @@ -376,6 +405,8 @@ object X509Utilities { bytes[0] = bytes[0].and(0x3F).or(0x40) return BigInteger(bytes) } + + fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string)) } // Assuming cert type to role is 1:1 @@ -408,6 +439,29 @@ fun PKCS10CertificationRequest.isSignatureValid(): Boolean { return this.isSignatureValid(JcaContentVerifierProviderBuilder().build(this.subjectPublicKeyInfo)) } +fun X509Certificate.toSimpleString(): String { + val bcCert = toBc() + val keyIdentifier = try { + SubjectKeyIdentifier.getInstance(bcCert.getExtension(Extension.subjectKeyIdentifier).parsedValue).keyIdentifier.toHex() + } catch (e: Exception) { + "null" + } + val authorityKeyIdentifier = try { + AuthorityKeyIdentifier.getInstance(bcCert.getExtension(Extension.authorityKeyIdentifier).parsedValue).keyIdentifier.toHex() + } catch (e: Exception) { + "null" + } + val subject = bcCert.subject + val issuer = bcCert.issuer + val role = CertRole.extract(this) + return "$subject[$keyIdentifier] issued by $issuer[$authorityKeyIdentifier] $role $serialNumber [${distributionPointsToString()}]" +} + +fun X509CRL.toSimpleString(): String { + val revokedSerialNumbers = revokedCertificates?.map { it.serialNumber } + return "$issuerX500Principal ${thisUpdate.toInstant()} ${nextUpdate.toInstant()} ${revokedSerialNumbers ?: "[]"}" +} + /** * Check certificate validity or print warning if expiry is within 30 days */ @@ -438,6 +492,8 @@ class X509CertificateFactory { fun generateCertPath(vararg certificates: X509Certificate): CertPath = generateCertPath(certificates.asList()) fun generateCertPath(certificates: List): CertPath = delegate.generateCertPath(certificates) + + fun generateCRL(input: InputStream): X509CRL = delegate.generateCRL(input) as X509CRL } enum class CertificateType(val keyUsage: KeyUsage, vararg val purposes: KeyPurposeId, val isCA: Boolean, val role: CertRole?) { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt index 5ce3db919c..7bb8e9ad39 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt @@ -115,11 +115,10 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, val transport = connection.transport as ProtonJTransport transport.protocolTracer = object : ProtocolTracer { override fun sentFrame(transportFrame: TransportFrame) { - logInfoWithMDC { "${transportFrame.body}" } + logInfoWithMDC { "sentFrame: ${transportFrame.body}" } } - override fun receivedFrame(transportFrame: TransportFrame) { - logInfoWithMDC { "${transportFrame.body}" } + logInfoWithMDC { "receivedFrame: ${transportFrame.body}" } } } } @@ -186,7 +185,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, } } - @Suppress("OverridingDeprecatedMember") + @Deprecated("Deprecated in Java") override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { logWarnWithMDC("Closing channel due to nonrecoverable exception ${cause.message}") if (log.isTraceEnabled) { @@ -298,16 +297,15 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, cause is ClosedChannelException -> logWarnWithMDC("SSL Handshake closed early.") cause is SslHandshakeTimeoutException -> logWarnWithMDC("SSL Handshake timed out") // Sadly the exception thrown by Netty wrapper requires that we check the message. - cause is SSLException && (cause.message?.contains("close_notify") == true) - -> logWarnWithMDC("Received close_notify during handshake") + cause is SSLException && (cause.message?.contains("close_notify") == true) -> logWarnWithMDC("Received close_notify during handshake") // io.netty.handler.ssl.SslHandler.setHandshakeFailureTransportFailure() cause is SSLException && (cause.message?.contains("writing TLS control frames") == true) -> logWarnWithMDC(cause.message!!) - else -> badCert = true } - logWarnWithMDC("Handshake failure: ${evt.cause().message}") if (log.isTraceEnabled) { - withMDC { log.trace("Handshake failure", evt.cause()) } + withMDC { log.trace("Handshake failure", cause) } + } else { + logWarnWithMDC("Handshake failure: ${cause.message}") } ctx.close() } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt index 4551608054..c502817029 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt @@ -1,7 +1,11 @@ package net.corda.nodeapi.internal.protonwrapper.netty import io.netty.bootstrap.Bootstrap -import io.netty.channel.* +import io.netty.channel.Channel +import io.netty.channel.ChannelFutureListener +import io.netty.channel.ChannelHandler +import io.netty.channel.ChannelInitializer +import io.netty.channel.EventLoopGroup import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel @@ -11,6 +15,7 @@ import io.netty.handler.proxy.HttpProxyHandler import io.netty.handler.proxy.Socks4ProxyHandler import io.netty.handler.proxy.Socks5ProxyHandler import io.netty.resolver.NoopAddressResolverGroup +import io.netty.util.concurrent.DefaultThreadFactory import io.netty.util.internal.logging.InternalLoggerFactory import io.netty.util.internal.logging.Slf4JLoggerFactory import net.corda.core.identity.CordaX500Name @@ -22,14 +27,16 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME import net.corda.nodeapi.internal.requireMessageSize +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import rx.Observable import rx.subjects.PublishSubject import java.lang.Long.min import java.net.InetSocketAddress +import java.util.concurrent.Executor +import java.util.concurrent.ExecutorService +import java.util.concurrent.ThreadPoolExecutor import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock -import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.TrustManagerFactory import kotlin.concurrent.withLock enum class ProxyVersion { @@ -53,10 +60,11 @@ data class ProxyConfig(val version: ProxyVersion, val proxyAddress: NetworkHostA * otherwise it creates a self-contained Netty thraed pool and socket objects. * Once connected it can accept application packets to send via the AMQP protocol. */ -class AMQPClient(val targets: List, +class AMQPClient(private val targets: List, val allowedRemoteLegalNames: Set, private val configuration: AMQPConfiguration, - private val sharedThreadPool: EventLoopGroup? = null) : AutoCloseable { + private val nettyThreading: NettyThreading = NettyThreading.NonShared("AMQPClient"), + private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable { companion object { init { InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) @@ -75,7 +83,6 @@ class AMQPClient(val targets: List, private val lock = ReentrantLock() @Volatile private var started: Boolean = false - private var workerGroup: EventLoopGroup? = null @Volatile private var clientChannel: Channel? = null // Offset into the list of targets, so that we can implement round-robin reconnect logic. @@ -109,24 +116,22 @@ class AMQPClient(val targets: List, retryInterval = min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER) } - private val connectListener = object : ChannelFutureListener { - override fun operationComplete(future: ChannelFuture) { - amqpActive = false - if (!future.isSuccess) { - log.info("Failed to connect to $currentTarget", future.cause()) + private val connectListener = ChannelFutureListener { future -> + amqpActive = false + if (!future.isSuccess) { + log.info("Failed to connect to $currentTarget", future.cause()) - if (started) { - workerGroup?.schedule({ - nextTarget() - restart() - }, retryInterval, TimeUnit.MILLISECONDS) - } - } else { - // Connection established successfully - clientChannel = future.channel() - clientChannel?.closeFuture()?.addListener(closeListener) - log.info("Connected to $currentTarget, Local address: $localAddressString") + if (started) { + nettyThreading.eventLoopGroup.schedule({ + nextTarget() + restart() + }, retryInterval, TimeUnit.MILLISECONDS) } + } else { + // Connection established successfully + clientChannel = future.channel() + clientChannel?.closeFuture()?.addListener(closeListener) + log.info("Connected to $currentTarget, Local address: $localAddressString") } } @@ -136,7 +141,7 @@ class AMQPClient(val targets: List, clientChannel = null if (started && !amqpActive) { log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" } - workerGroup?.schedule({ + nettyThreading.eventLoopGroup.schedule({ nextTarget() restart() }, retryInterval, TimeUnit.MILLISECONDS) @@ -144,17 +149,16 @@ class AMQPClient(val targets: List, } private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer() { - private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore) + private val trustManagerFactory = trustManagerFactoryWithRevocation( + parent.configuration.trustStore, + parent.configuration.revocationConfig, + parent.distPointCrlSource + ) private val conf = parent.configuration @Volatile private lateinit var amqpChannelHandler: AMQPChannelHandler - init { - keyManagerFactory.init(conf.keyStore) - trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, conf.revocationConfig)) - } - @Suppress("ComplexMethod") override fun initChannel(ch: SocketChannel) { val pipeline = ch.pipeline() @@ -194,14 +198,28 @@ class AMQPClient(val targets: List, val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration) val target = parent.currentTarget val handler = if (parent.configuration.useOpenSsl) { - createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc()) + createClientOpenSslHandler( + target, + parent.allowedRemoteLegalNames, + wrappedKeyManagerFactory, + trustManagerFactory, + ch.alloc(), + parent.nettyThreading.sslDelegatedTaskExecutor + ) } else { - createClientSslHelper(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory) + createClientSslHandler( + target, + parent.allowedRemoteLegalNames, + wrappedKeyManagerFactory, + trustManagerFactory, + parent.nettyThreading.sslDelegatedTaskExecutor + ) } - handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout + handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis() pipeline.addLast("sslHandler", handler) if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO)) - amqpChannelHandler = AMQPChannelHandler(false, + amqpChannelHandler = AMQPChannelHandler( + false, parent.allowedRemoteLegalNames, // Single entry, key can be anything. mapOf(DEFAULT to wrappedKeyManagerFactory), @@ -209,37 +227,41 @@ class AMQPClient(val targets: List, conf.password, conf.trace, false, - onOpen = { _, change -> - parent.run { - amqpActive = true - retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly - _onConnection.onNext(change) - } - }, - onClose = { _, change -> - if (parent.amqpChannelHandler == amqpChannelHandler) { - parent.run { - _onConnection.onNext(change) - if (change.badCert) { - log.error("Blocking future connection attempts to $target due to bad certificate on endpoint") - badCertTargets += target - } - - if (started && amqpActive) { - log.debug { "Scheduling restart of $currentTarget (AMQP active)" } - workerGroup?.schedule({ - nextTarget() - restart() - }, retryInterval, TimeUnit.MILLISECONDS) - } - amqpActive = false - } - } - }, - onReceive = { rcv -> parent._onReceive.onNext(rcv) }) + onOpen = { _, change -> onChannelOpen(change) }, + onClose = { _, change -> onChannelClose(change, target) }, + onReceive = parent._onReceive::onNext + ) parent.amqpChannelHandler = amqpChannelHandler pipeline.addLast(amqpChannelHandler) } + + private fun onChannelOpen(change: ConnectionChange) { + parent.run { + amqpActive = true + retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly + _onConnection.onNext(change) + } + } + + private fun onChannelClose(change: ConnectionChange, target: NetworkHostAndPort) { + if (parent.amqpChannelHandler != amqpChannelHandler) return + parent.run { + _onConnection.onNext(change) + if (change.badCert) { + log.error("Blocking future connection attempts to $target due to bad certificate on endpoint") + badCertTargets += target + } + + if (started && amqpActive) { + log.debug { "Scheduling restart of $currentTarget (AMQP active)" } + nettyThreading.eventLoopGroup.schedule({ + nextTarget() + restart() + }, retryInterval, TimeUnit.MILLISECONDS) + } + amqpActive = false + } + } } fun start() { @@ -249,7 +271,7 @@ class AMQPClient(val targets: List, return } log.info("Connect to: $currentTarget") - workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS) + (nettyThreading as? NettyThreading.NonShared)?.start() started = true restart() } @@ -261,7 +283,7 @@ class AMQPClient(val targets: List, } val bootstrap = Bootstrap() // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux - bootstrap.group(workerGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this)) + bootstrap.group(nettyThreading.eventLoopGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this)) // Delegate DNS Resolution to the proxy side, if we are using proxy. if (configuration.proxyConfig != null) { bootstrap.resolver(NoopAddressResolverGroup.INSTANCE) @@ -275,14 +297,12 @@ class AMQPClient(val targets: List, lock.withLock { log.info("Stopping connection to: $currentTarget, Local address: $localAddressString") started = false - if (sharedThreadPool == null) { - workerGroup?.shutdownGracefully() - workerGroup?.terminationFuture()?.sync() + if (nettyThreading is NettyThreading.NonShared) { + nettyThreading.stop() } else { clientChannel?.close()?.sync() } clientChannel = null - workerGroup = null log.info("Stopped connection to $currentTarget") } } @@ -323,4 +343,36 @@ class AMQPClient(val targets: List, private val _onConnection = PublishSubject.create().toSerialized() val onConnection: Observable get() = _onConnection -} \ No newline at end of file + + + sealed class NettyThreading { + abstract val eventLoopGroup: EventLoopGroup + abstract val sslDelegatedTaskExecutor: Executor + + class Shared(override val eventLoopGroup: EventLoopGroup, + override val sslDelegatedTaskExecutor: ExecutorService = sslDelegatedTaskExecutor("AMQPClient")) : NettyThreading() + + class NonShared(val threadPoolName: String) : NettyThreading() { + private var _eventLoopGroup: NioEventLoopGroup? = null + override val eventLoopGroup: EventLoopGroup get() = checkNotNull(_eventLoopGroup) + + private var _sslDelegatedTaskExecutor: ThreadPoolExecutor? = null + override val sslDelegatedTaskExecutor: ExecutorService get() = checkNotNull(_sslDelegatedTaskExecutor) + + fun start() { + check(_eventLoopGroup == null) + check(_sslDelegatedTaskExecutor == null) + _eventLoopGroup = NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY)) + _sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) + } + + fun stop() { + eventLoopGroup.shutdownGracefully() + eventLoopGroup.terminationFuture().sync() + sslDelegatedTaskExecutor.shutdown() + _eventLoopGroup = null + _sslDelegatedTaskExecutor = null + } + } + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPConfiguration.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPConfiguration.kt index db0dd8023c..c992dd55e4 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPConfiguration.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPConfiguration.kt @@ -2,7 +2,8 @@ package net.corda.nodeapi.internal.protonwrapper.netty import net.corda.nodeapi.internal.ArtemisMessagingComponent import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS +import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT +import java.time.Duration interface AMQPConfiguration { /** @@ -67,8 +68,8 @@ interface AMQPConfiguration { get() = false @JvmDefault - val sslHandshakeTimeout: Long - get() = DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS // Aligned with sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT + val sslHandshakeTimeout: Duration + get() = DEFAULT_SSL_HANDSHAKE_TIMEOUT // Aligned with sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT /** * An optional Health Check Phrase which if passed through the channel will cause AMQP Server to echo it back instead of doing normal pipeline processing diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt index 20834a2041..523cde184a 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt @@ -11,6 +11,7 @@ import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler +import io.netty.util.concurrent.DefaultThreadFactory import io.netty.util.internal.logging.InternalLoggerFactory import io.netty.util.internal.logging.Slf4JLoggerFactory import net.corda.core.utilities.NetworkHostAndPort @@ -20,15 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.requireMessageSize +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import org.apache.qpid.proton.engine.Delivery import rx.Observable import rx.subjects.PublishSubject import java.net.BindException import java.net.InetSocketAddress import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ExecutorService import java.util.concurrent.locks.ReentrantLock -import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.TrustManagerFactory import kotlin.concurrent.withLock /** @@ -36,37 +37,35 @@ import kotlin.concurrent.withLock */ class AMQPServer(val hostName: String, val port: Int, - private val configuration: AMQPConfiguration) : AutoCloseable { - + private val configuration: AMQPConfiguration, + private val threadPoolName: String = "AMQPServer", + private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON, + private val remotingThreads: Int? = null) : AutoCloseable { companion object { init { InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) } - private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads" - private val log = contextLogger() - private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4) + private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4) } private val lock = ReentrantLock() - @Volatile - private var stopping: Boolean = false private var bossGroup: EventLoopGroup? = null private var workerGroup: EventLoopGroup? = null private var serverChannel: Channel? = null + private var sslDelegatedTaskExecutor: ExecutorService? = null private val clientChannels = ConcurrentHashMap() private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer() { - private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore) + private val trustManagerFactory = trustManagerFactoryWithRevocation( + parent.configuration.trustStore, + parent.configuration.revocationConfig, + parent.distPointCrlSource + ) private val conf = parent.configuration - init { - keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray()) - trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, conf.revocationConfig)) - } - override fun initChannel(ch: SocketChannel) { val amqpConfiguration = parent.configuration val pipeline = ch.pipeline() @@ -75,7 +74,8 @@ class AMQPServer(val hostName: String, pipeline.addLast("sslHandler", sslHandler) if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO)) val suppressLogs = ch.remoteAddress()?.hostString in amqpConfiguration.silencedIPs - pipeline.addLast(AMQPChannelHandler(true, + pipeline.addLast(AMQPChannelHandler( + true, null, // Passing a mapping of legal names to key managers to be able to pick the correct one after // SNI completion event is fired up. @@ -84,36 +84,42 @@ class AMQPServer(val hostName: String, conf.password, conf.trace, suppressLogs, - onOpen = { channel, change -> - parent.run { - clientChannels[channel.remoteAddress()] = channel - _onConnection.onNext(change) - } - }, - onClose = { channel, change -> - parent.run { - val remoteAddress = channel.remoteAddress() - clientChannels.remove(remoteAddress) - _onConnection.onNext(change) - } - }, - onReceive = { rcv -> parent._onReceive.onNext(rcv) })) + onOpen = ::onChannelOpen, + onClose = ::onChannelClose, + onReceive = parent._onReceive::onNext + )) + } + + private fun onChannelOpen(channel: SocketChannel, change: ConnectionChange) { + parent.run { + clientChannels[channel.remoteAddress()] = channel + _onConnection.onNext(change) + } + } + + private fun onChannelClose(channel: SocketChannel, change: ConnectionChange) { + parent.run { + val remoteAddress = channel.remoteAddress() + clientChannels.remove(remoteAddress) + _onConnection.onNext(change) + } } private fun createSSLHandler(amqpConfig: AMQPConfiguration, ch: SocketChannel): Pair> { return if (amqpConfig.useOpenSsl && amqpConfig.enableSNI && amqpConfig.keyStore.aliases().size > 1) { val keyManagerFactoriesMap = splitKeystore(amqpConfig) // SNI matching needed only when multiple nodes exist behind the server. - Pair(createServerSNIOpenSslHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap) + Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap) } else { val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig) + val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor) val handler = if (amqpConfig.useOpenSsl) { - createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc()) + createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor) } else { // For javaSSL, SNI matching is handled at key manager level. - createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory) + createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor) } - handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout + handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis() Pair(handler, mapOf(DEFAULT to keyManagerFactory)) } } @@ -123,8 +129,13 @@ class AMQPServer(val hostName: String, lock.withLock { stop() - bossGroup = NioEventLoopGroup(1) - workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS) + sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) + + bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY)) + workerGroup = NioEventLoopGroup( + remotingThreads ?: DEFAULT_REMOTING_THREADS, + DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY) + ) val server = ServerBootstrap() // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux @@ -145,22 +156,19 @@ class AMQPServer(val hostName: String, fun stop() { lock.withLock { - try { - stopping = true - serverChannel?.apply { close() } - serverChannel = null + serverChannel?.close() + serverChannel = null - workerGroup?.shutdownGracefully() - workerGroup?.terminationFuture()?.sync() + workerGroup?.shutdownGracefully() + workerGroup?.terminationFuture()?.sync() + workerGroup = null - bossGroup?.shutdownGracefully() - bossGroup?.terminationFuture()?.sync() + bossGroup?.shutdownGracefully() + bossGroup?.terminationFuture()?.sync() + bossGroup = null - workerGroup = null - bossGroup = null - } finally { - stopping = false - } + sslDelegatedTaskExecutor?.shutdown() + sslDelegatedTaskExecutor = null } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt index 30e0445689..a853cbffc8 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt @@ -11,7 +11,7 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() { private val logger = LoggerFactory.getLogger(AllowAllRevocationChecker::class.java) - override fun check(cert: Certificate?, unresolvedCritExts: MutableCollection?) { + override fun check(cert: Certificate, unresolvedCritExts: Collection) { logger.debug {"Passing certificate check for: $cert"} // Nothing to do } @@ -20,7 +20,7 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() { return true } - override fun getSupportedExtensions(): MutableSet? { + override fun getSupportedExtensions(): Set? { return null } @@ -28,7 +28,9 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() { // Nothing to do } - override fun getSoftFailExceptions(): MutableList { - return LinkedList() + override fun getSoftFailExceptions(): List { + return Collections.emptyList() } + + override fun clone(): AllowAllRevocationChecker = this } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ExternalCrlSource.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/CrlSource.kt similarity index 70% rename from node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ExternalCrlSource.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/CrlSource.kt index 654ead24a0..a0cbd7079c 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ExternalCrlSource.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/CrlSource.kt @@ -3,10 +3,11 @@ package net.corda.nodeapi.internal.protonwrapper.netty import java.security.cert.X509CRL import java.security.cert.X509Certificate -interface ExternalCrlSource { +@FunctionalInterface +interface CrlSource { /** * Given certificate provides a set of CRLs, potentially performing remote communication. */ - fun fetch(certificate: X509Certificate) : Set -} \ No newline at end of file + fun fetch(certificate: X509Certificate): Set +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt index f444421128..4e1b4b1930 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt @@ -26,7 +26,7 @@ interface RevocationConfig { /** * CRLs are obtained from external source - * @see ExternalCrlSource + * @see CrlSource */ EXTERNAL_SOURCE, @@ -39,14 +39,9 @@ interface RevocationConfig { val mode: Mode /** - * Optional `ExternalCrlSource` which only makes sense with `mode` = `EXTERNAL_SOURCE` + * Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE` */ - val externalCrlSource: ExternalCrlSource? - - /** - * Creates a copy of `RevocationConfig` with ExternalCrlSource enriched - */ - fun enrichExternalCrlSource(sourceFunc: (() -> ExternalCrlSource)?): RevocationConfig + val externalCrlSource: CrlSource? } /** @@ -54,16 +49,7 @@ interface RevocationConfig { */ fun Boolean.toRevocationConfig() = if(this) RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL) else RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL) -data class RevocationConfigImpl(override val mode: RevocationConfig.Mode, override val externalCrlSource: ExternalCrlSource? = null) : RevocationConfig { - override fun enrichExternalCrlSource(sourceFunc: (() -> ExternalCrlSource)?): RevocationConfig { - return if(mode != RevocationConfig.Mode.EXTERNAL_SOURCE) { - this - } else { - assert(sourceFunc != null) { "There should be a way to obtain ExternalCrlSource" } - copy(externalCrlSource = sourceFunc!!()) - } - } -} +data class RevocationConfigImpl(override val mode: RevocationConfig.Mode, override val externalCrlSource: CrlSource? = null) : RevocationConfig class RevocationConfigParser : ConfigParser { override fun parse(config: Config): RevocationConfig { @@ -80,4 +66,4 @@ class RevocationConfigParser : ConfigParser { else -> throw IllegalArgumentException("Unsupported mode : '$mode'") } } -} \ No newline at end of file +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt index 98910a673f..dc207f2c7b 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt @@ -1,3 +1,5 @@ +@file:Suppress("ComplexMethod", "LongParameterList") + package net.corda.nodeapi.internal.protonwrapper.netty import io.netty.buffer.ByteBufAllocator @@ -13,31 +15,41 @@ import net.corda.core.identity.CordaX500Name import net.corda.core.internal.VisibleForTesting import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger -import net.corda.core.utilities.toHex +import net.corda.core.utilities.debug import net.corda.nodeapi.internal.ArtemisTcpTransport import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.crypto.toBc +import net.corda.nodeapi.internal.crypto.toSimpleString import net.corda.nodeapi.internal.crypto.x509 -import net.corda.nodeapi.internal.protonwrapper.netty.revocation.ExternalSourceRevocationChecker +import net.corda.nodeapi.internal.namedThreadPoolExecutor +import net.corda.nodeapi.internal.revocation.CordaRevocationChecker import org.bouncycastle.asn1.ASN1InputStream +import org.bouncycastle.asn1.ASN1Primitive import org.bouncycastle.asn1.DERIA5String import org.bouncycastle.asn1.DEROctetString -import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier +import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x509.CRLDistPoint import org.bouncycastle.asn1.x509.DistributionPointName import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.GeneralName import org.bouncycastle.asn1.x509.GeneralNames -import org.bouncycastle.asn1.x509.SubjectKeyIdentifier import org.slf4j.LoggerFactory -import java.io.ByteArrayInputStream import java.net.Socket +import java.net.URI import java.security.KeyStore -import java.security.cert.* -import java.util.* +import java.security.cert.CertificateException +import java.security.cert.PKIXBuilderParameters +import java.security.cert.X509CertSelector +import java.security.cert.X509Certificate import java.util.concurrent.Executor -import javax.net.ssl.* -import kotlin.system.measureTimeMillis +import java.util.concurrent.ThreadPoolExecutor +import javax.net.ssl.CertPathTrustManagerParameters +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.SNIHostName +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLEngine +import javax.net.ssl.TrustManagerFactory +import javax.net.ssl.X509ExtendedTrustManager +import javax.security.auth.x500.X500Principal private const val HOSTNAME_FORMAT = "%s.corda.net" internal const val DEFAULT = "default" @@ -46,65 +58,73 @@ internal const val DP_DEFAULT_ANSWER = "NO CRLDP ext" internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.protonwrapper.netty.SSLHelper") -fun X509Certificate.distributionPoints() : Set? { - logger.debug("Checking CRLDPs for $subjectX500Principal") +/** + * Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any. + */ +fun X509Certificate.distributionPoints(): Map?> { + logger.debug { "Checking CRLDPs for $subjectX500Principal" } val crldpExtBytes = getExtensionValue(Extension.cRLDistributionPoints.id) if (crldpExtBytes == null) { logger.debug(DP_DEFAULT_ANSWER) - return emptySet() + return emptyMap() } - val derObjCrlDP = ASN1InputStream(ByteArrayInputStream(crldpExtBytes)).readObject() + val derObjCrlDP = crldpExtBytes.toAsn1Object() val dosCrlDP = derObjCrlDP as? DEROctetString if (dosCrlDP == null) { logger.error("Expected to have DEROctetString, actual type: ${derObjCrlDP.javaClass}") - return emptySet() + return emptyMap() } - val crldpExtOctetsBytes = dosCrlDP.octets - val dpObj = ASN1InputStream(ByteArrayInputStream(crldpExtOctetsBytes)).readObject() - val distPoint = CRLDistPoint.getInstance(dpObj) - if (distPoint == null) { + val dpObj = dosCrlDP.octets.toAsn1Object() + val crlDistPoint = CRLDistPoint.getInstance(dpObj) + if (crlDistPoint == null) { logger.error("Could not instantiate CRLDistPoint, from: $dpObj") - return emptySet() + return emptyMap() } - val dpNames = distPoint.distributionPoints.mapNotNull { it.distributionPoint }.filter { it.type == DistributionPointName.FULL_NAME } - val generalNames = dpNames.flatMap { GeneralNames.getInstance(it.name).names.asList() } - return generalNames.filter { it.tagNo == GeneralName.uniformResourceIdentifier}.map { DERIA5String.getInstance(it.name).string }.toSet() -} - -fun X509Certificate.distributionPointsToString() : String { - return with(distributionPoints()) { - if(this == null || isEmpty()) { - DP_DEFAULT_ANSWER - } else { - sorted().joinToString() + val dpMap = HashMap?>() + for (distributionPoint in crlDistPoint.distributionPoints) { + val distributionPointName = distributionPoint.distributionPoint + if (distributionPointName?.type != DistributionPointName.FULL_NAME) continue + val issuerNames = distributionPoint.crlIssuer?.names?.mapNotNull { + if (it.tagNo == GeneralName.directoryName) { + X500Principal(X500Name.getInstance(it.name).encoded) + } else { + null + } + } + for (generalName in GeneralNames.getInstance(distributionPointName.name).names) { + if (generalName.tagNo == GeneralName.uniformResourceIdentifier) { + val uri = URI(DERIA5String.getInstance(generalName.name).string) + dpMap[uri] = issuerNames + } } } + return dpMap } +fun X509Certificate.distributionPointsToString(): String { + return with(distributionPoints().keys) { + if (isEmpty()) DP_DEFAULT_ANSWER else sorted().joinToString() + } +} + +fun ByteArray.toAsn1Object(): ASN1Primitive = ASN1InputStream(this).readObject() + fun certPathToString(certPath: Array?): String { if (certPath == null) { return "" } - val certs = certPath.map { - val bcCert = it.toBc() - val subject = bcCert.subject.toString() - val issuer = bcCert.issuer.toString() - val keyIdentifier = try { - SubjectKeyIdentifier.getInstance(bcCert.getExtension(Extension.subjectKeyIdentifier).parsedValue).keyIdentifier.toHex() - } catch (ex: Exception) { - "null" - } - val authorityKeyIdentifier = try { - AuthorityKeyIdentifier.getInstance(bcCert.getExtension(Extension.authorityKeyIdentifier).parsedValue).keyIdentifier.toHex() - } catch (ex: Exception) { - "null" - } - " $subject[$keyIdentifier] issued by $issuer[$authorityKeyIdentifier] [${it.distributionPointsToString()}]" - } - return certs.joinToString("\r\n") + return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" } +} + +/** + * Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3, + * which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor. + */ +fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor { + return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask") } @VisibleForTesting @@ -117,7 +137,7 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex if (chain == null) { return "" } - return chain.map { it.toString() }.joinToString(", ") + return chain.joinToString(", ") { it.toString() } } private fun logErrors(chain: Array?, block: () -> Unit) { @@ -169,37 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex } -private object LoggingImmediateExecutor : Executor { - - override fun execute(command: Runnable?) { - val log = LoggerFactory.getLogger(javaClass) - - if (command == null) { - log.error("SSL handler executor called with a null command") - throw NullPointerException("command") - } - - @Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions - try { - val commandName = command::class.qualifiedName?.let { "[$it]" } ?: "" - log.debug("Entering SSL command $commandName") - val elapsedTime = measureTimeMillis { command.run() } - log.debug("Exiting SSL command $elapsedTime millis") - if (elapsedTime > 100) { - log.info("Command: $commandName took $elapsedTime millis to execute") - } - } - catch (ex: Exception) { - log.error("Caught exception in SSL handler executor", ex) - throw ex - } - } -} - -internal fun createClientSslHelper(target: NetworkHostAndPort, - expectedRemoteLegalNames: Set, - keyManagerFactory: KeyManagerFactory, - trustManagerFactory: TrustManagerFactory): SslHandler { +internal fun createClientSslHandler(target: NetworkHostAndPort, + expectedRemoteLegalNames: Set, + keyManagerFactory: KeyManagerFactory, + trustManagerFactory: TrustManagerFactory, + delegateTaskExecutor: Executor): SslHandler { val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslEngine = sslContext.createSSLEngine(target.host, target.port) sslEngine.useClientMode = true @@ -211,15 +205,15 @@ internal fun createClientSslHelper(target: NetworkHostAndPort, sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslEngine.sslParameters = sslParameters } - @Suppress("DEPRECATION") - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } internal fun createClientOpenSslHandler(target: NetworkHostAndPort, expectedRemoteLegalNames: Set, keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory, - alloc: ByteBufAllocator): SslHandler { + alloc: ByteBufAllocator, + delegateTaskExecutor: Executor): SslHandler { val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build() val sslEngine = sslContext.newEngine(alloc, target.host, target.port) sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray() @@ -229,13 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort, sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslEngine.sslParameters = sslParameters } - @Suppress("DEPRECATION") - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } internal fun createServerSslHandler(keyStore: CertificateStore, keyManagerFactory: KeyManagerFactory, - trustManagerFactory: TrustManagerFactory): SslHandler { + trustManagerFactory: TrustManagerFactory, + delegateTaskExecutor: Executor): SslHandler { val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslEngine = sslContext.createSSLEngine() sslEngine.useClientMode = false @@ -246,65 +240,34 @@ internal fun createServerSslHandler(keyStore: CertificateStore, val sslParameters = sslEngine.sslParameters sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore)) sslEngine.sslParameters = sslParameters - @Suppress("DEPRECATION") - return SslHandler(sslEngine, false, LoggingImmediateExecutor) -} - -fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext { - val sslContext = SSLContext.getInstance("TLS") - val keyManagers = keyManagerFactory.keyManagers - val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java) - .map { LoggingTrustManagerWrapper(it) }.toTypedArray() - sslContext.init(keyManagers, trustManagers, newSecureRandom()) - return sslContext -} - -@VisibleForTesting -fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore, revocationConfig: RevocationConfig): ManagerFactoryParameters { - val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector()) - val revocationChecker = when (revocationConfig.mode) { - RevocationConfig.Mode.OFF -> AllowAllRevocationChecker // Custom PKIXRevocationChecker skipping CRL check - RevocationConfig.Mode.EXTERNAL_SOURCE -> { - require(revocationConfig.externalCrlSource != null) { "externalCrlSource must not be null" } - ExternalSourceRevocationChecker(revocationConfig.externalCrlSource!!) { Date() } // Custom PKIXRevocationChecker which uses `externalCrlSource` - } - else -> { - val certPathBuilder = CertPathBuilder.getInstance("PKIX") - val pkixRevocationChecker = certPathBuilder.revocationChecker as PKIXRevocationChecker - pkixRevocationChecker.options = EnumSet.of( - // Prefer CRL over OCSP - PKIXRevocationChecker.Option.PREFER_CRLS, - // Don't fall back to OCSP checking - PKIXRevocationChecker.Option.NO_FALLBACK) - if (revocationConfig.mode == RevocationConfig.Mode.SOFT_FAIL) { - // Allow revocation check to succeed if the revocation status cannot be determined for one of - // the following reasons: The CRL or OCSP response cannot be obtained because of a network error. - pkixRevocationChecker.options = pkixRevocationChecker.options + PKIXRevocationChecker.Option.SOFT_FAIL - } - pkixRevocationChecker - } - } - pkixParams.addCertPathChecker(revocationChecker) - return CertPathTrustManagerParameters(pkixParams) + return SslHandler(sslEngine, false, delegateTaskExecutor) } internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory, - alloc: ByteBufAllocator): SslHandler { - + alloc: ByteBufAllocator, + delegateTaskExecutor: Executor): SslHandler { val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build() val sslEngine = sslContext.newEngine(alloc) sslEngine.useClientMode = false - @Suppress("DEPRECATION") - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) +} + +fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext { + val sslContext = SSLContext.getInstance("TLS") + val trustManagers = trustManagerFactory + ?.trustManagers + ?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it } + ?.toTypedArray() + sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom()) + return sslContext } /** * Creates a special SNI handler used only when openSSL is used for AMQPServer */ -internal fun createServerSNIOpenSslHandler(keyManagerFactoriesMap: Map, +internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map, trustManagerFactory: TrustManagerFactory): SniHandler { - // Default value can be any in the map. val sslCtxBuilder = getServerSslContextBuilder(keyManagerFactoriesMap.values.first(), trustManagerFactory) val mapping = DomainWildcardMappingBuilder(sslCtxBuilder.build()) @@ -314,20 +277,19 @@ internal fun createServerSNIOpenSslHandler(keyManagerFactoriesMap: Map { val keyStore = config.keyStore.value.internal val password = config.keyStore.entryPassword.toCharArray() - return keyStore.aliases().toList().map { alias -> + return keyStore.aliases().toList().associate { alias -> val key = keyStore.getKey(alias, password) val certs = keyStore.getCertificateChain(alias) val x500Name = keyStore.getCertificate(alias).x509.subjectX500Principal @@ -338,14 +300,45 @@ internal fun splitKeystore(config: AMQPConfiguration): Map AllowAllRevocationChecker + RevocationConfig.Mode.EXTERNAL_SOURCE -> { + val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) { + "externalCrlSource must be specfied for EXTERNAL_SOURCE" + } + CordaRevocationChecker(externalCrlSource, softFail = true) + } + RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true) + RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false) + } + val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector()) + pkixParams.addCertPathChecker(revocationChecker) + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams)) + return trustManagerFactory +} /** * Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/revocation/ExternalSourceRevocationChecker.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/revocation/ExternalSourceRevocationChecker.kt deleted file mode 100644 index 23af94ca3d..0000000000 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/revocation/ExternalSourceRevocationChecker.kt +++ /dev/null @@ -1,88 +0,0 @@ -package net.corda.nodeapi.internal.protonwrapper.netty.revocation - -import net.corda.core.utilities.contextLogger -import net.corda.nodeapi.internal.protonwrapper.netty.ExternalCrlSource -import org.bouncycastle.asn1.x509.Extension -import java.security.cert.CRLReason -import java.security.cert.CertPathValidatorException -import java.security.cert.Certificate -import java.security.cert.CertificateRevokedException -import java.security.cert.PKIXRevocationChecker -import java.security.cert.X509CRL -import java.security.cert.X509Certificate -import java.util.* - -/** - * Implementation of [PKIXRevocationChecker] which determines whether certificate is revoked using [externalCrlSource] which knows how to - * obtain a set of CRLs for a given certificate from an external source - */ -class ExternalSourceRevocationChecker(private val externalCrlSource: ExternalCrlSource, private val dateSource: () -> Date) : PKIXRevocationChecker() { - - companion object { - private val logger = contextLogger() - } - - override fun check(cert: Certificate, unresolvedCritExts: MutableCollection?) { - val x509Certificate = cert as X509Certificate - checkApprovedCRLs(x509Certificate, externalCrlSource.fetch(x509Certificate)) - } - - /** - * Borrowed from `RevocationChecker.checkApprovedCRLs()` - */ - @Suppress("NestedBlockDepth") - @Throws(CertPathValidatorException::class) - private fun checkApprovedCRLs(cert: X509Certificate, approvedCRLs: Set) { - // See if the cert is in the set of approved crls. - logger.debug("ExternalSourceRevocationChecker.checkApprovedCRLs() cert SN: ${cert.serialNumber}") - - for (crl in approvedCRLs) { - val entry = crl.getRevokedCertificate(cert) - if (entry != null) { - logger.debug("ExternalSourceRevocationChecker.checkApprovedCRLs() CRL entry: $entry") - - /* - * Abort CRL validation and throw exception if there are any - * unrecognized critical CRL entry extensions (see section - * 5.3 of RFC 5280). - */ - val unresCritExts = entry.criticalExtensionOIDs - if (unresCritExts != null && !unresCritExts.isEmpty()) { - /* remove any that we will process */ - unresCritExts.remove(Extension.cRLDistributionPoints.id) - unresCritExts.remove(Extension.certificateIssuer.id) - if (!unresCritExts.isEmpty()) { - throw CertPathValidatorException( - "Unrecognized critical extension(s) in revoked CRL entry: $unresCritExts") - } - } - - val reasonCode = entry.revocationReason ?: CRLReason.UNSPECIFIED - val revocationDate = entry.revocationDate - if (revocationDate.before(dateSource())) { - val t = CertificateRevokedException( - revocationDate, reasonCode, - crl.issuerX500Principal, mutableMapOf()) - throw CertPathValidatorException( - t.message, t, null, -1, CertPathValidatorException.BasicReason.REVOKED) - } - } - } - } - - override fun isForwardCheckingSupported(): Boolean { - return true - } - - override fun getSupportedExtensions(): MutableSet? { - return null - } - - override fun init(forward: Boolean) { - // Nothing to do - } - - override fun getSoftFailExceptions(): MutableList { - return LinkedList() - } -} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt new file mode 100644 index 0000000000..ee589e73a9 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt @@ -0,0 +1,119 @@ +package net.corda.nodeapi.internal.revocation + +import com.github.benmanes.caffeine.cache.Caffeine +import com.github.benmanes.caffeine.cache.LoadingCache +import net.corda.core.internal.readFully +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.debug +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.minutes +import net.corda.core.utilities.seconds +import net.corda.nodeapi.internal.crypto.X509CertificateFactory +import net.corda.nodeapi.internal.crypto.toSimpleString +import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource +import net.corda.nodeapi.internal.protonwrapper.netty.distributionPoints +import java.net.URI +import java.security.cert.X509CRL +import java.security.cert.X509Certificate +import java.time.Duration +import javax.security.auth.x500.X500Principal + +/** + * [CrlSource] which downloads CRLs from the distribution points in the X509 certificate and caches them. + */ +@Suppress("TooGenericExceptionCaught") +class CertDistPointCrlSource(cacheSize: Long = DEFAULT_CACHE_SIZE, + cacheExpiry: Duration = DEFAULT_CACHE_EXPIRY, + private val connectTimeout: Duration = DEFAULT_CONNECT_TIMEOUT, + private val readTimeout: Duration = DEFAULT_READ_TIMEOUT) : CrlSource { + companion object { + private val logger = contextLogger() + + // The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a + // node handshake, we want to keep the total timeout within that. + private val DEFAULT_CONNECT_TIMEOUT = 9.seconds + private val DEFAULT_READ_TIMEOUT = 9.seconds + private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore) + private val DEFAULT_CACHE_EXPIRY = 5.minutes + + val SINGLETON = CertDistPointCrlSource( + cacheSize = java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE), + cacheExpiry = java.lang.Long.getLong("net.corda.dpcrl.cache.expiry")?.let(Duration::ofMillis) ?: DEFAULT_CACHE_EXPIRY, + connectTimeout = java.lang.Long.getLong("net.corda.dpcrl.connect.timeout")?.let(Duration::ofMillis) ?: DEFAULT_CONNECT_TIMEOUT, + readTimeout = java.lang.Long.getLong("net.corda.dpcrl.read.timeout")?.let(Duration::ofMillis) ?: DEFAULT_READ_TIMEOUT + ) + } + + private val cache: LoadingCache = Caffeine.newBuilder() + .maximumSize(cacheSize) + .expireAfterWrite(cacheExpiry) + .build(::retrieveCRL) + + private fun retrieveCRL(uri: URI): X509CRL { + val start = System.currentTimeMillis() + val bytes = try { + val conn = uri.toURL().openConnection() + conn.connectTimeout = connectTimeout.toMillis().toInt() + conn.readTimeout = readTimeout.toMillis().toInt() + // Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes + // in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException + // into CRLException and drops the cause chain. + conn.getInputStream().readFully() + } catch (e: Exception) { + if (logger.isDebugEnabled) { + logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e) + } + throw e + } + val duration = System.currentTimeMillis() - start + val crl = try { + X509CertificateFactory().generateCRL(bytes.inputStream()) + } catch (e: Exception) { + if (logger.isDebugEnabled) { + logger.debug("Invalid CRL from $uri (${duration}ms)", e) + } + throw e + } + logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" } + return crl + } + + fun clearCache() { + cache.invalidateAll() + } + + override fun fetch(certificate: X509Certificate): Set { + val approvedCRLs = HashSet() + var exception: Exception? = null + for ((distPointUri, issuerNames) in certificate.distributionPoints()) { + try { + val possibleCRL = getPossibleCRL(distPointUri) + if (verifyCRL(possibleCRL, certificate, issuerNames)) { + approvedCRLs += possibleCRL + } + } catch (e: Exception) { + if (exception == null) { + exception = e + } else { + exception.addSuppressed(e) + } + } + } + // Only throw if no CRLs are retrieved + if (exception != null && approvedCRLs.isEmpty()) { + throw exception + } else { + return approvedCRLs + } + } + + private fun getPossibleCRL(uri: URI): X509CRL { + return cache[uri]!! + } + + // DistributionPointFetcher.verifyCRL + private fun verifyCRL(crl: X509CRL, certificate: X509Certificate, distPointIssuerNames: List?): Boolean { + val crlIssuer = crl.issuerX500Principal + return distPointIssuerNames?.any { it == crlIssuer } ?: (certificate.issuerX500Principal == crlIssuer) + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationChecker.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationChecker.kt new file mode 100644 index 0000000000..1e0a3ecf53 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationChecker.kt @@ -0,0 +1,126 @@ +package net.corda.nodeapi.internal.revocation + +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.debug +import net.corda.nodeapi.internal.crypto.toSimpleString +import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource +import org.bouncycastle.asn1.x509.Extension +import java.security.cert.CRLReason +import java.security.cert.CertPathValidatorException +import java.security.cert.CertPathValidatorException.BasicReason +import java.security.cert.Certificate +import java.security.cert.CertificateRevokedException +import java.security.cert.PKIXRevocationChecker +import java.security.cert.X509CRL +import java.security.cert.X509Certificate +import java.util.* +import kotlin.collections.ArrayList + +/** + * Custom [PKIXRevocationChecker] which delegates to a plugable [CrlSource] to retrieve the CRLs for certificate revocation checks. + */ +class CordaRevocationChecker(private val crlSource: CrlSource, + private val softFail: Boolean, + private val dateSource: () -> Date = ::Date) : PKIXRevocationChecker() { + companion object { + private val logger = contextLogger() + } + + private val softFailExceptions = ArrayList() + + override fun check(cert: Certificate, unresolvedCritExts: MutableCollection?) { + cert as X509Certificate + checkApprovedCRLs(cert, getCRLs(cert)) + } + + @Suppress("TooGenericExceptionCaught") + private fun getCRLs(cert: X509Certificate): Set { + val crls = try { + crlSource.fetch(cert) + } catch (e: Exception) { + if (softFail) { + addSoftFailException(e) + return emptySet() + } else { + throw undeterminedRevocationException("Unable to retrieve CRLs for cert ${cert.serialNumber}", e) + } + } + if (crls.isNotEmpty() || softFail) { + return crls + } + // Note, the JDK tries to find a valid CRL from a different signing key before giving up (RevocationChecker.verifyWithSeparateSigningKey) + throw undeterminedRevocationException("Could not find any valid CRLs for cert ${cert.serialNumber}", null) + } + + /** + * Borrowed from `RevocationChecker.checkApprovedCRLs()` + */ + @Suppress("NestedBlockDepth") + private fun checkApprovedCRLs(cert: X509Certificate, approvedCRLs: Set) { + // See if the cert is in the set of approved crls. + logger.debug { "Check cert ${cert.serialNumber} against CRLs ${approvedCRLs.map { it.toSimpleString() }}" } + + for (crl in approvedCRLs) { + val entry = crl.getRevokedCertificate(cert) + if (entry != null) { + /* + * Abort CRL validation and throw exception if there are any + * unrecognized critical CRL entry extensions (see section + * 5.3 of RFC 5280). + */ + val unresCritExts = entry.criticalExtensionOIDs + if (unresCritExts != null && unresCritExts.isNotEmpty()) { + /* remove any that we will process */ + unresCritExts.remove(Extension.cRLDistributionPoints.id) + unresCritExts.remove(Extension.certificateIssuer.id) + if (unresCritExts.isNotEmpty()) { + throw CertPathValidatorException("Unrecognized critical extension(s) in revoked CRL entry: $unresCritExts") + } + } + + val reasonCode = entry.revocationReason ?: CRLReason.UNSPECIFIED + val revocationDate = entry.revocationDate + if (revocationDate.before(dateSource())) { + val t = CertificateRevokedException(revocationDate, reasonCode, crl.issuerX500Principal, emptyMap()) + throw CertPathValidatorException(t.message, t, null, -1, BasicReason.REVOKED) + } + } + } + } + + /** + * This is set to false intentionally for security reasons. + * It ensures that certificates are provided in reverse direction (from most-trusted CA to target certificate) + * after the necessary validation checks have already been performed. + * + * If that wasn't the case, we could be reaching out to CRL endpoints for invalid certificates, which would open security holes + * e.g. systems that are not part of a Corda network could force a Corda firewall to initiate outbound requests to systems under their control. + */ + override fun isForwardCheckingSupported(): Boolean { + return false + } + + override fun getSupportedExtensions(): Set? { + return null + } + + override fun init(forward: Boolean) { + if (forward) { + throw CertPathValidatorException("Forward checking not allowed") + } + softFailExceptions.clear() + } + + override fun getSoftFailExceptions(): List { + return Collections.unmodifiableList(softFailExceptions) + } + + private fun addSoftFailException(e: Exception) { + logger.debug("Soft fail exception", e) + softFailExceptions += undeterminedRevocationException(e.message, e) + } + + private fun undeterminedRevocationException(message: String?, cause: Throwable?): CertPathValidatorException { + return CertPathValidatorException(message, cause, null, -1, BasicReason.UNDETERMINED_REVOCATION_STATUS) + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapter.kt new file mode 100644 index 0000000000..f656f81502 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapter.kt @@ -0,0 +1,47 @@ +package net.corda.nodeapi.internal.serialization + +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getCustomSerializationMagicFromSchemeId +import net.corda.core.utilities.ByteSequence +import net.corda.serialization.internal.CordaSerializationMagic +import net.corda.serialization.internal.SerializationScheme +import java.io.ByteArrayOutputStream +import java.io.NotSerializableException + +class CustomSerializationSchemeAdapter(private val customScheme: CustomSerializationScheme): SerializationScheme { + + val serializationSchemeMagic = getCustomSerializationMagicFromSchemeId(customScheme.getSchemeId()) + + override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean { + return magic == serializationSchemeMagic + } + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + val readMagic = byteSequence.take(serializationSchemeMagic.size) + if (readMagic != serializationSchemeMagic) { + throw NotSerializableException("Scheme ${customScheme::class.java} is incompatible with blob." + + " Magic from blob = $readMagic (Expected = $serializationSchemeMagic)") + } + return customScheme.deserialize( + byteSequence.subSequence(serializationSchemeMagic.size, byteSequence.size - serializationSchemeMagic.size), + clazz, + SerializationSchemeContextAdapter(context) + ) + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + val stream = ByteArrayOutputStream() + stream.write(serializationSchemeMagic.bytes) + stream.write(customScheme.serialize(obj, SerializationSchemeContextAdapter(context)).bytes) + return SerializedBytes(stream.toByteArray()) + } + + private class SerializationSchemeContextAdapter(context: SerializationContext) : SerializationSchemeContext { + override val deserializationClassLoader = context.deserializationClassLoader + override val whitelist = context.whitelist + override val properties = context.properties + } +} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/Kryo.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/Kryo.kt index 7866d51e08..b594954ef5 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/Kryo.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/Kryo.kt @@ -30,7 +30,7 @@ import java.security.PublicKey import java.security.cert.CertPath import java.security.cert.CertificateFactory import java.security.cert.X509Certificate -import java.util.* +import java.util.Collections import javax.annotation.concurrent.ThreadSafe import kotlin.reflect.KClass import kotlin.reflect.KMutableProperty @@ -509,6 +509,7 @@ class ThrowableSerializer(kryo: Kryo, type: Class) : Serializer @ThreadSafe @SuppressWarnings("ALL") object LazyMappedListSerializer : Serializer>() { - override fun write(kryo: Kryo, output: Output, obj: List<*>) = kryo.writeClassAndObject(output, obj.toList()) - override fun read(kryo: Kryo, input: Input, type: Class>) = kryo.readClassAndObject(input) as List<*> + // Using a MutableList so that Kryo will always write an instance of java.util.ArrayList. + override fun write(kryo: Kryo, output: Output, obj: List<*>) = kryo.writeClassAndObject(output, obj.toMutableList()) + override fun read(kryo: Kryo, input: Input, type: Class>) = kryo.readClassAndObject(input) as? List<*> } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt index 06698d99ad..178682e088 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt @@ -90,7 +90,7 @@ object KryoCheckpointSerializer : CheckpointSerializer { */ private fun getInputClassForCustomSerializer(classLoader: ClassLoader, customSerializer: CustomSerializerCheckpointAdaptor<*, *>): Class<*> { val typeNameWithoutGenerics = customSerializer.cordappType.typeName.substringBefore('<') - return classLoader.loadClass(typeNameWithoutGenerics) + return Class.forName(typeNameWithoutGenerics, false, classLoader) } /** diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/notary/NotaryQueries.kt b/node-api/src/main/kotlin/net/corda/nodeapi/notary/NotaryQueries.kt new file mode 100644 index 0000000000..af97ba89c9 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/notary/NotaryQueries.kt @@ -0,0 +1,30 @@ +package net.corda.nodeapi.notary + +import net.corda.core.contracts.StateRef +import net.corda.core.internal.notary.NotaryService +import net.corda.core.serialization.CordaSerializable +import java.time.Instant + +/** + * Implementations of queries supported by notary services + */ +class SpentStateQuery { + @CordaSerializable + data class Request(val stateRef: StateRef, + val maxResults: Int, + val successOnly: Boolean, + val startTime: Instant?, + val endTime: Instant?, + val lastTxId: String?) : NotaryService.Query.Request + + @CordaSerializable + data class Result(val spendEvents: List, + val moreResults: Boolean): NotaryService.Query.Result + + @CordaSerializable + data class SpendEventDetails(val requestTimestamp: Instant, + val transactionId: String, + val result: String, + val requestingPartyName: String?, + val workerNodeX500Name: String?) +} \ No newline at end of file diff --git a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/DummyCustomSerializationSchemeInJava.java b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/DummyCustomSerializationSchemeInJava.java new file mode 100644 index 0000000000..3be21e0b86 --- /dev/null +++ b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/DummyCustomSerializationSchemeInJava.java @@ -0,0 +1,30 @@ +package net.corda.nodeapi.internal.serialization; + +import net.corda.core.serialization.SerializationSchemeContext; +import net.corda.core.serialization.CustomSerializationScheme; +import net.corda.core.serialization.SerializedBytes; +import net.corda.core.utilities.ByteSequence; + +public class DummyCustomSerializationSchemeInJava implements CustomSerializationScheme { + + public class DummyOutput {} + + static final int testMagic = 7; + + @Override + public int getSchemeId() { + return testMagic; + } + + @Override + @SuppressWarnings("unchecked") + public T deserialize(ByteSequence bytes, Class clazz, SerializationSchemeContext context) { + return (T)new DummyOutput(); + } + + @Override + public SerializedBytes serialize(T obj, SerializationSchemeContext context) { + byte[] myBytes = {0xA, 0xA}; + return new SerializedBytes<>(myBytes); + } +} \ No newline at end of file diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt index 21c7bc8a94..e951c587c2 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt @@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.assertj.core.api.Assertions import org.junit.Rule import org.junit.Test @@ -161,11 +162,9 @@ class TlsDiffAlgorithmsTest(private val serverAlgo: String, private val clientAl private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { return SSLContext.getInstance("TLS").apply { - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(keyStore) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(trustStore) val trustManagers = trustMgrFactory.trustManagers init(keyManagers, trustManagers, newSecureRandom()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt index 46b6bf381c..0bb81e5627 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt @@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.assertj.core.api.Assertions import org.junit.Ignore import org.junit.Rule @@ -18,7 +19,6 @@ import java.io.IOException import java.net.InetAddress import java.net.InetSocketAddress import javax.net.ssl.* -import javax.net.ssl.SNIHostName import kotlin.concurrent.thread import kotlin.test.assertEquals import kotlin.test.assertFalse @@ -209,11 +209,9 @@ class TlsDiffProtocolsTest(private val serverAlgo: String, private val clientAlg private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { return SSLContext.getInstance("TLS").apply { - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(keyStore) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(trustStore) val trustManagers = trustMgrFactory.trustManagers init(keyManagers, trustManagers, newSecureRandom()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt index 782d8b8abe..12eb6d3e35 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt @@ -1,5 +1,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty +import io.netty.util.concurrent.ImmediateExecutor import net.corda.core.crypto.SecureHash import net.corda.core.identity.CordaX500Name import net.corda.core.utilities.NetworkHostAndPort @@ -8,10 +9,9 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS +import net.corda.testing.internal.fixedCrlSource import org.junit.Test -import javax.net.ssl.KeyManagerFactory import javax.net.ssl.SNIHostName -import javax.net.ssl.TrustManagerFactory import kotlin.test.assertEquals class SSLHelperTest { @@ -20,15 +20,21 @@ class SSLHelperTest { val legalName = CordaX500Name("Test", "London", "GB") val sslConfig = configureTestSSL(legalName) - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) - val keyStore = sslConfig.keyStore - keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false)) - val trustStore = sslConfig.trustStore - trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL))) + val trustManagerFactory = trustManagerFactoryWithRevocation( + sslConfig.trustStore.get(), + RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL), + fixedCrlSource(emptySet()) + ) - val sslHandler = createClientSslHelper(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory) + val sslHandler = createClientSslHandler( + NetworkHostAndPort("localhost", 1234), + setOf(legalName), + keyManagerFactory, + trustManagerFactory, + ImmediateExecutor.INSTANCE + ) val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase() // These hardcoded values must not be changed, something is broken if you have to change these hardcoded values. diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt new file mode 100644 index 0000000000..66c17e4a39 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt @@ -0,0 +1,50 @@ +package net.corda.nodeapi.internal.revocation + +import net.corda.core.crypto.Crypto +import net.corda.core.utilities.NetworkHostAndPort +import net.corda.nodeapi.internal.DEV_INTERMEDIATE_CA +import net.corda.testing.node.internal.network.CrlServer +import org.assertj.core.api.Assertions.assertThat +import org.bouncycastle.jce.provider.BouncyCastleProvider +import org.junit.After +import org.junit.Before +import org.junit.Test + +class CertDistPointCrlSourceTest { + private lateinit var crlServer: CrlServer + + @Before + fun setUp() { + // Do not use Security.addProvider(BouncyCastleProvider()) to avoid EdDSA signature disruption in other tests. + Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME) + crlServer = CrlServer(NetworkHostAndPort("localhost", 0)) + crlServer.start() + } + + @After + fun tearDown() { + if (::crlServer.isInitialized) { + crlServer.close() + } + } + + @Test(timeout=300_000) + fun `happy path`() { + val crlSource = CertDistPointCrlSource() + + with(crlSource.fetch(crlServer.intermediateCa.certificate)) { + assertThat(size).isEqualTo(1) + assertThat(single().revokedCertificates).isNull() + } + + crlSource.clearCache() + + crlServer.revokedIntermediateCerts += DEV_INTERMEDIATE_CA.certificate + with(crlSource.fetch(crlServer.intermediateCa.certificate)) { + assertThat(size).isEqualTo(1) + val revokedCertificates = single().revokedCertificates + // This also tests clearCache() works. + assertThat(revokedCertificates.map { it.serialNumber }).containsExactly(DEV_INTERMEDIATE_CA.certificate.serialNumber) + } + } +} diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/revocation/ExternalSourceRevocationCheckerTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt similarity index 64% rename from node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/revocation/ExternalSourceRevocationCheckerTest.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt index 7be350a525..6dbfcd4515 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/revocation/ExternalSourceRevocationCheckerTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt @@ -1,26 +1,27 @@ -package net.corda.nodeapi.internal.protonwrapper.netty.revocation +package net.corda.nodeapi.internal.revocation import net.corda.core.utilities.Try import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.nodeapi.internal.protonwrapper.netty.ExternalCrlSource +import net.corda.testing.internal.fixedCrlSource import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory import org.junit.Test import java.math.BigInteger - import java.security.cert.X509CRL import java.security.cert.X509Certificate -import java.sql.Date +import java.time.LocalDate +import java.time.ZoneOffset +import java.util.* import kotlin.test.assertEquals import kotlin.test.assertTrue -class ExternalSourceRevocationCheckerTest { +class CordaRevocationCheckerTest { @Test(timeout=300_000) fun checkRevoked() { - val checkResult = performCheckOnDate(Date.valueOf("2019-09-27")) + val checkResult = performCheckOnDate(LocalDate.of(2019, 9, 27)) val failedChecks = checkResult.filterNot { it.second.isSuccess } assertEquals(1, failedChecks.size) assertEquals(BigInteger.valueOf(8310484079152632582), failedChecks.first().first.serialNumber) @@ -28,11 +29,11 @@ class ExternalSourceRevocationCheckerTest { @Test(timeout=300_000) fun checkTooEarly() { - val checkResult = performCheckOnDate(Date.valueOf("2019-08-27")) + val checkResult = performCheckOnDate(LocalDate.of(2019, 8, 27)) assertTrue(checkResult.all { it.second.isSuccess }) } - private fun performCheckOnDate(date: Date): List>> { + private fun performCheckOnDate(date: LocalDate): List>> { val certStore = CertificateStore.fromResource( "net/corda/nodeapi/internal/protonwrapper/netty/sslkeystore_Revoked.jks", DEV_CA_KEY_STORE_PASS, DEV_CA_PRIVATE_KEY_PASS) @@ -40,16 +41,15 @@ class ExternalSourceRevocationCheckerTest { val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl") val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL - //val crlHolder = X509CRLHolder(resourceAsStream) - //crlHolder.revokedCertificates as X509CRLEntryHolder - - val instance = ExternalSourceRevocationChecker(object : ExternalCrlSource { - override fun fetch(certificate: X509Certificate): Set = setOf(crl) - }) { date } + val checker = CordaRevocationChecker( + crlSource = fixedCrlSource(setOf(crl)), + softFail = true, + dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) } + ) return certStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS).map { - Pair(it, Try.on { instance.check(it, mutableListOf()) }) + Pair(it, Try.on { checker.check(it, mutableListOf()) }) } } } diff --git a/node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt similarity index 59% rename from node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt index 2e984eb3b5..675f222f92 100644 --- a/node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt @@ -1,20 +1,16 @@ -package net.corda.node.internal.artemis +package net.corda.nodeapi.internal.revocation import net.corda.core.crypto.Crypto -import net.corda.core.utilities.days -import net.corda.node.internal.artemis.CertificateChainCheckPolicy.RevocationCheck +import net.corda.nodeapi.internal.config.CertificateStore +import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateType +import net.corda.nodeapi.internal.crypto.X509KeyStore import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation +import net.corda.testing.core.createCRL import org.bouncycastle.asn1.x500.X500Name -import org.bouncycastle.asn1.x509.CRLReason -import org.bouncycastle.asn1.x509.Extension -import org.bouncycastle.asn1.x509.ExtensionsGenerator -import org.bouncycastle.asn1.x509.GeneralName -import org.bouncycastle.asn1.x509.GeneralNames -import org.bouncycastle.asn1.x509.IssuingDistributionPoint -import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder -import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import org.junit.Before import org.junit.Rule import org.junit.Test @@ -22,15 +18,18 @@ import org.junit.rules.TemporaryFolder import org.junit.runner.RunWith import org.junit.runners.Parameterized import java.io.File +import java.security.KeyPair import java.security.KeyStore import java.security.PrivateKey +import java.security.cert.CertificateException import java.security.cert.X509Certificate import java.util.* +import javax.net.ssl.X509TrustManager import javax.security.auth.x500.X500Principal -import kotlin.test.assertFails +import kotlin.test.assertFailsWith @RunWith(Parameterized::class) -class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { +class RevocationTest(private val revocationMode: RevocationConfig.Mode) { companion object { @JvmStatic @Parameterized.Parameters(name = "revocationMode = {0}") @@ -45,8 +44,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { private lateinit var doormanCRL: File private lateinit var tlsCRL: File - private val keyStore = KeyStore.getInstance("JKS") - private val trustStore = KeyStore.getInstance("JKS") + private lateinit var trustManager: X509TrustManager private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) @@ -61,9 +59,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { private lateinit var tlsCert: X509Certificate private val chain - get() = listOf(tlsCert, nodeCACert, doormanCert, rootCert).map { - javax.security.cert.X509Certificate.getInstance(it.encoded) - }.toTypedArray() + get() = arrayOf(tlsCert, nodeCACert, doormanCert, rootCert) @Before fun before() { @@ -74,10 +70,18 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair) tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair) + val trustStore = KeyStore.getInstance("JKS") trustStore.load(null, null) trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert) trustStore.setCertificateEntry("cordarootca", rootCert) + val trustManagerFactory = trustManagerFactoryWithRevocation( + CertificateStore.of(X509KeyStore(trustStore, "pass"), "pass", "pass"), + RevocationConfigImpl(revocationMode), + CertDistPointCrlSource() + ) + trustManager = trustManagerFactory.trustManagers.single() as X509TrustManager + doormanCert = X509Utilities.createCertificate( CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public, crlDistPoint = rootCRL.toURI().toString() @@ -91,43 +95,34 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) ) - rootCRL.createCRL(rootCert, rootKeyPair.private, false) - doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false) - tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true) + rootCRL.writeCRL(rootCert, rootKeyPair.private, false) + doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false) + tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true) } - private fun File.createCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) { - val builder = JcaX509v2CRLBuilder(certificate.subjectX500Principal, Date()) - builder.setNextUpdate(Date.from(Date().toInstant() + 7.days)) - builder.addExtension(Extension.issuingDistributionPoint, true, IssuingDistributionPoint(null, indirect, false)) - revoked.forEach { - val extensionsGenerator = ExtensionsGenerator() - extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(CRLReason.keyCompromise)) - // Certificate issuer is required for indirect CRL - val certificateIssuerName = X500Name.getInstance(it.issuerX500Principal.encoded) - extensionsGenerator.addExtension(Extension.certificateIssuer, true, GeneralNames(GeneralName(certificateIssuerName))) - builder.addCRLEntry(it.serialNumber, Date(), extensionsGenerator.generate()) - } - val holder = builder.build(JcaContentSignerBuilder("SHA256withECDSA").setProvider(Crypto.findProvider("BC")).build(privateKey)) - outputStream().use { it.write(holder.encoded) } + private fun File.writeCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) { + val crl = createCRL( + CertificateAndKeyPair(certificate, KeyPair(certificate.publicKey, privateKey)), + revoked.asList(), + indirect = indirect + ) + writeBytes(crl.encoded) } - private fun assertFailsFor(vararg modes: RevocationConfig.Mode, block: () -> Unit) { - if (revocationMode in modes) assertFails(block) else block() + private fun assertFailsFor(vararg modes: RevocationConfig.Mode) { + if (revocationMode in modes) assertFailsWith(CertificateException::class, ::doRevocationCheck) else doRevocationCheck() } @Test(timeout = 300_000) fun `ok with empty CRLs`() { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) + doRevocationCheck() } @Test(timeout = 300_000) fun `soft fail with revoked TLS certificate`() { - tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert) + tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert) - assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -138,9 +133,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -150,9 +143,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown") ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -162,9 +153,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = tlsCRL.toURI().toString() ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -174,18 +163,16 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public, crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) ) - tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert) + tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert) - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) + doRevocationCheck() } @Test(timeout = 300_000) fun `soft fail with revoked node CA certificate`() { - doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, nodeCACert) + doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, nodeCACert) - assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -195,9 +182,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman" ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -207,8 +192,12 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public, crlDistPoint = doormanCRL.toURI().toString() ) - doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, otherCert) + doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, otherCert) - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) + doRevocationCheck() + } + + private fun doRevocationCheck() { + trustManager.checkClientTrusted(chain, "ECDHE_ECDSA") } } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapterTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapterTests.kt new file mode 100644 index 0000000000..2d4f751ddf --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapterTests.kt @@ -0,0 +1,93 @@ +package net.corda.nodeapi.internal.serialization + +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.core.utilities.ByteSequence +import net.corda.nodeapi.internal.serialization.testutils.serializationContext +import org.junit.Test +import org.junit.jupiter.api.Assertions.assertTrue +import java.io.NotSerializableException +import kotlin.test.assertFailsWith + +class CustomSerializationSchemeAdapterTests { + + companion object { + const val DEFAULT_SCHEME_ID = 7 + } + + class DummyInputClass + class DummyOutputClass + + class SingleInputAndOutputScheme(private val schemeId: Int = DEFAULT_SCHEME_ID): CustomSerializationScheme { + + override fun getSchemeId(): Int { + return schemeId + } + + override fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T { + @Suppress("UNCHECKED_CAST") + return DummyOutputClass() as T + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + assertTrue(obj is DummyInputClass) + return ByteSequence.of(ByteArray(2) { 0x2 }) + } + } + + class SameBytesInputAndOutputsAndScheme: CustomSerializationScheme { + + private val expectedBytes = "123456789".toByteArray() + + override fun getSchemeId(): Int { + return DEFAULT_SCHEME_ID + } + + override fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T { + bytes.open().use { + val data = ByteArray(expectedBytes.size) { 0 } + it.read(data) + assertTrue(data.contentEquals(expectedBytes)) + } + @Suppress("UNCHECKED_CAST") + return DummyOutputClass() as T + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + return ByteSequence.of(expectedBytes) + } + } + + @Test(timeout=300_000) + fun `CustomSerializationSchemeAdapter calls the correct methods in CustomSerializationScheme`() { + val scheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme()) + val serializedData = scheme.serialize(DummyInputClass(), serializationContext) + val roundTripped = scheme.deserialize(serializedData, Any::class.java, serializationContext) + assertTrue(roundTripped is DummyOutputClass) + } + + @Test(timeout=300_000) + fun `CustomSerializationSchemeAdapter can adapt a Java implementation`() { + val scheme = CustomSerializationSchemeAdapter(DummyCustomSerializationSchemeInJava()) + val serializedData = scheme.serialize(DummyInputClass(), serializationContext) + val roundTripped = scheme.deserialize(serializedData, Any::class.java, serializationContext) + assertTrue(roundTripped is DummyCustomSerializationSchemeInJava.DummyOutput) + } + + @Test(timeout=300_000) + fun `CustomSerializationSchemeAdapter validates the magic`() { + val inScheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme()) + val serializedData = inScheme.serialize(DummyInputClass(), serializationContext) + val outScheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme(8)) + assertFailsWith { + outScheme.deserialize(serializedData, DummyOutputClass::class.java, serializationContext) + } + } + + @Test(timeout=300_000) + fun `CustomSerializationSchemeAdapter preserves the serialized bytes between deserialize and serialize`() { + val scheme = CustomSerializationSchemeAdapter(SameBytesInputAndOutputsAndScheme()) + val serializedData = scheme.serialize(Any(), serializationContext) + scheme.deserialize(serializedData, Any::class.java, serializationContext) + } +} \ No newline at end of file diff --git a/node/djvm/src/main/kotlin/net/corda/node/djvm/AttachmentBuilder.kt b/node/djvm/src/main/kotlin/net/corda/node/djvm/AttachmentBuilder.kt index f3a205ba38..561b5bcd76 100644 --- a/node/djvm/src/main/kotlin/net/corda/node/djvm/AttachmentBuilder.kt +++ b/node/djvm/src/main/kotlin/net/corda/node/djvm/AttachmentBuilder.kt @@ -3,6 +3,7 @@ package net.corda.node.djvm import net.corda.core.contracts.Attachment import net.corda.core.contracts.BrokenAttachmentException +import net.corda.core.contracts.ContractAttachment import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party import java.io.InputStream @@ -16,6 +17,12 @@ private const val ID_IDX = 2 private const val ATTACHMENT_IDX = 3 private const val STREAMER_IDX = 4 +private const val CONTRACT_IDX = 5 +private const val ADDITIONAL_CONTRACT_IDX = 6 +private const val UPLOADER_IDX = 7 +private const val CONTRACT_SIGNER_KEYS_IDX = 8 +private const val VERSION_IDX = 9 + class AttachmentBuilder : Function?, List?> { private val attachments = mutableListOf() @@ -28,17 +35,30 @@ class AttachmentBuilder : Function?, List?> { } override fun apply(inputs: Array?): List? { + @Suppress("unchecked_cast") return if (inputs == null) { unmodifiable(attachments) } else { - @Suppress("unchecked_cast") - attachments.add(SandboxAttachment( + var attachment: Attachment = SandboxAttachment( signerKeys = inputs[SIGNER_KEYS_IDX] as List, size = inputs[SIZE_IDX] as Int, id = inputs[ID_IDX] as SecureHash, attachment = inputs[ATTACHMENT_IDX], streamer = inputs[STREAMER_IDX] as Function - )) + ) + + if (inputs.size > VERSION_IDX) { + attachment = ContractAttachment.create( + attachment = attachment, + contract = inputs[CONTRACT_IDX] as String, + additionalContracts = (inputs[ADDITIONAL_CONTRACT_IDX] as Array).toSet(), + uploader = inputs[UPLOADER_IDX] as? String, + signerKeys = inputs[CONTRACT_SIGNER_KEYS_IDX] as List, + version = inputs[VERSION_IDX] as Int + ) + } + + attachments.add(attachment) null } } @@ -47,7 +67,7 @@ class AttachmentBuilder : Function?, List?> { /** * This represents an [Attachment] from within the sandbox. */ -class SandboxAttachment( +private class SandboxAttachment( override val signerKeys: List, override val size: Int, override val id: SecureHash, diff --git a/node/djvm/src/main/kotlin/net/corda/node/djvm/CommandBuilder.kt b/node/djvm/src/main/kotlin/net/corda/node/djvm/CommandBuilder.kt index 311f8e69ff..247fef3ec6 100644 --- a/node/djvm/src/main/kotlin/net/corda/node/djvm/CommandBuilder.kt +++ b/node/djvm/src/main/kotlin/net/corda/node/djvm/CommandBuilder.kt @@ -5,35 +5,43 @@ import net.corda.core.contracts.CommandWithParties import net.corda.core.internal.lazyMapped import java.security.PublicKey import java.util.function.Function +import java.util.function.Supplier -class CommandBuilder : Function, List>> { +class CommandBuilder : Function, Supplier>>> { @Suppress("unchecked_cast") - override fun apply(inputs: Array): List> { - val signers = inputs[0] as? List> ?: emptyList() - val commandsData = inputs[1] as? List ?: emptyList() + override fun apply(inputs: Array): Supplier>> { + val signersProvider = inputs[0] as? Supplier>> ?: Supplier(::emptyList) + val commandsDataProvider = inputs[1] as? Supplier> ?: Supplier(::emptyList) val partialMerkleLeafIndices = inputs[2] as? IntArray /** * This logic has been lovingly reproduced from [net.corda.core.internal.deserialiseCommands]. */ - return if (partialMerkleLeafIndices != null) { - check(commandsData.size <= signers.size) { - "Invalid Transaction. Fewer Signers (${signers.size}) than CommandData (${commandsData.size}) objects" - } - if (partialMerkleLeafIndices.isNotEmpty()) { - check(partialMerkleLeafIndices.max()!! < signers.size) { - "Invalid Transaction. A command with no corresponding signer detected" + return Supplier { + val signers = signersProvider.get() + val commandsData = commandsDataProvider.get() + + if (partialMerkleLeafIndices != null) { + check(commandsData.size <= signers.size) { + "Invalid Transaction. Fewer Signers (${signers.size}) than CommandData (${commandsData.size}) objects" + } + if (partialMerkleLeafIndices.isNotEmpty()) { + check(partialMerkleLeafIndices.max()!! < signers.size) { + "Invalid Transaction. A command with no corresponding signer detected" + } + } + commandsData.lazyMapped { commandData, index -> + // Deprecated signingParties property not supported. + CommandWithParties(signers[partialMerkleLeafIndices[index]], emptyList(), commandData) + } + } else { + check(commandsData.size == signers.size) { + "Invalid Transaction. Sizes of CommandData (${commandsData.size}) and Signers (${signers.size}) do not match" + } + commandsData.lazyMapped { commandData, index -> + // Deprecated signingParties property not supported. + CommandWithParties(signers[index], emptyList(), commandData) } - } - commandsData.lazyMapped { commandData, index -> - CommandWithParties(signers[partialMerkleLeafIndices[index]], emptyList(), commandData) - } - } else { - check(commandsData.size == signers.size) { - "Invalid Transaction. Sizes of CommandData (${commandsData.size}) and Signers (${signers.size}) do not match" - } - commandsData.lazyMapped { commandData, index -> - CommandWithParties(signers[index], emptyList(), commandData) } } } diff --git a/node/djvm/src/main/kotlin/net/corda/node/djvm/ComponentBuilder.kt b/node/djvm/src/main/kotlin/net/corda/node/djvm/ComponentBuilder.kt index 78c3efc737..f0e2e476aa 100644 --- a/node/djvm/src/main/kotlin/net/corda/node/djvm/ComponentBuilder.kt +++ b/node/djvm/src/main/kotlin/net/corda/node/djvm/ComponentBuilder.kt @@ -5,19 +5,22 @@ import net.corda.core.internal.TransactionDeserialisationException import net.corda.core.internal.lazyMapped import net.corda.core.utilities.OpaqueBytes import java.util.function.Function +import java.util.function.Supplier -class ComponentBuilder : Function, List<*>> { +class ComponentBuilder : Function, Supplier>> { @Suppress("unchecked_cast", "TooGenericExceptionCaught") - override fun apply(inputs: Array): List<*> { + override fun apply(inputs: Array): Supplier> { val deserializer = inputs[0] as Function val groupType = inputs[1] as ComponentGroupEnum val components = (inputs[2] as Array).map(::OpaqueBytes) - return components.lazyMapped { component, index -> - try { - deserializer.apply(component.bytes) - } catch (e: Exception) { - throw TransactionDeserialisationException(groupType, index, e) + return Supplier { + components.lazyMapped { component, index -> + try { + deserializer.apply(component.bytes) + } catch (e: Exception) { + throw TransactionDeserialisationException(groupType, index, e) + } } } } diff --git a/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxFactory.kt b/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxFactory.kt deleted file mode 100644 index f12f0cb108..0000000000 --- a/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxFactory.kt +++ /dev/null @@ -1,54 +0,0 @@ -@file:JvmName("LtxConstants") -package net.corda.node.djvm - -import net.corda.core.contracts.Attachment -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.CommandWithParties -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.PrivacySalt -import net.corda.core.contracts.StateAndRef -import net.corda.core.contracts.StateRef -import net.corda.core.contracts.TimeWindow -import net.corda.core.contracts.TransactionState -import net.corda.core.crypto.DigestService -import net.corda.core.crypto.SecureHash -import net.corda.core.identity.Party -import net.corda.core.node.NetworkParameters -import net.corda.core.transactions.LedgerTransaction -import java.util.function.Function - -private const val TX_INPUTS = 0 -private const val TX_OUTPUTS = 1 -private const val TX_COMMANDS = 2 -private const val TX_ATTACHMENTS = 3 -private const val TX_ID = 4 -private const val TX_NOTARY = 5 -private const val TX_TIME_WINDOW = 6 -private const val TX_PRIVACY_SALT = 7 -private const val TX_NETWORK_PARAMETERS = 8 -private const val TX_REFERENCES = 9 -private const val TX_DIGEST_SERVICE = 10 - -class LtxFactory : Function, LedgerTransaction> { - - @Suppress("unchecked_cast") - override fun apply(txArgs: Array): LedgerTransaction { - return LedgerTransaction.createForSandbox( - inputs = (txArgs[TX_INPUTS] as Array>).map { it.toStateAndRef() }, - outputs = (txArgs[TX_OUTPUTS] as? List>) ?: emptyList(), - commands = (txArgs[TX_COMMANDS] as? List>) ?: emptyList(), - attachments = (txArgs[TX_ATTACHMENTS] as? List) ?: emptyList(), - id = txArgs[TX_ID] as SecureHash, - notary = txArgs[TX_NOTARY] as? Party, - timeWindow = txArgs[TX_TIME_WINDOW] as? TimeWindow, - privacySalt = txArgs[TX_PRIVACY_SALT] as PrivacySalt, - networkParameters = txArgs[TX_NETWORK_PARAMETERS] as NetworkParameters, - references = (txArgs[TX_REFERENCES] as Array>).map { it.toStateAndRef() }, - digestService = if (txArgs.size > TX_DIGEST_SERVICE) (txArgs[TX_DIGEST_SERVICE] as DigestService) else DigestService.sha2_256 - ) - } - - private fun Array<*>.toStateAndRef(): StateAndRef { - return StateAndRef(this[0] as TransactionState<*>, this[1] as StateRef) - } -} diff --git a/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxSupplierFactory.kt b/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxSupplierFactory.kt new file mode 100644 index 0000000000..fd982d610a --- /dev/null +++ b/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxSupplierFactory.kt @@ -0,0 +1,73 @@ +@file:JvmName("LtxTools") +package net.corda.node.djvm + +import net.corda.core.contracts.Attachment +import net.corda.core.contracts.CommandData +import net.corda.core.contracts.CommandWithParties +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.PrivacySalt +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TimeWindow +import net.corda.core.contracts.TransactionState +import net.corda.core.crypto.DigestService +import net.corda.core.crypto.SecureHash +import net.corda.core.identity.Party +import net.corda.core.node.NetworkParameters +import net.corda.core.transactions.LedgerTransaction +import java.util.function.Function +import java.util.function.Supplier + +private const val TX_INPUTS = 0 +private const val TX_OUTPUTS = 1 +private const val TX_COMMANDS = 2 +private const val TX_ATTACHMENTS = 3 +private const val TX_ID = 4 +private const val TX_NOTARY = 5 +private const val TX_TIME_WINDOW = 6 +private const val TX_PRIVACY_SALT = 7 +private const val TX_NETWORK_PARAMETERS = 8 +private const val TX_REFERENCES = 9 +private const val TX_DIGEST_SERVICE = 10 + +class LtxSupplierFactory : Function, Supplier> { + @Suppress("unchecked_cast") + override fun apply(txArgs: Array): Supplier { + val inputProvider = (txArgs[TX_INPUTS] as Function>>) + .andThen(Function(Array>::toContractStatesAndRef)) + .toSupplier() + val outputProvider = txArgs[TX_OUTPUTS] as? Supplier>> ?: Supplier(::emptyList) + val commandsProvider = txArgs[TX_COMMANDS] as Supplier>> + val referencesProvider = (txArgs[TX_REFERENCES] as Function>>) + .andThen(Function(Array>::toContractStatesAndRef)) + .toSupplier() + val networkParameters = (txArgs[TX_NETWORK_PARAMETERS] as? NetworkParameters)?.toImmutable() + return Supplier { + LedgerTransaction.createForContractVerify( + inputs = inputProvider.get(), + outputs = outputProvider.get(), + commands = commandsProvider.get(), + attachments = txArgs[TX_ATTACHMENTS] as? List ?: emptyList(), + id = txArgs[TX_ID] as SecureHash, + notary = txArgs[TX_NOTARY] as? Party, + timeWindow = txArgs[TX_TIME_WINDOW] as? TimeWindow, + privacySalt = txArgs[TX_PRIVACY_SALT] as PrivacySalt, + networkParameters = networkParameters, + references = referencesProvider.get(), + digestService = txArgs[TX_DIGEST_SERVICE] as DigestService + ) + } + } +} + +private fun Function.toSupplier(): Supplier { + return Supplier { apply(null) } +} + +private fun Array>.toContractStatesAndRef(): List> { + return map(Array::toStateAndRef) +} + +private fun Array<*>.toStateAndRef(): StateAndRef { + return StateAndRef(this[0] as TransactionState<*>, this[1] as StateRef) +} diff --git a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineErrorHandlingTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineErrorHandlingTest.kt index 6a43de1ea8..3e0882e62e 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineErrorHandlingTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineErrorHandlingTest.kt @@ -153,13 +153,15 @@ abstract class StateMachineErrorHandlingTest { runnable: Int = 0, failed: Int = 0, completed: Int = 0, - hospitalized: Int = 0 + hospitalized: Int = 0, + killed: Int = 0 ) { val counts = startFlow(StateMachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) assertEquals(runnable, counts.runnable, "There should be $runnable runnable checkpoints") assertEquals(failed, counts.failed, "There should be $failed failed checkpoints") assertEquals(completed, counts.completed, "There should be $completed completed checkpoints") assertEquals(hospitalized, counts.hospitalized, "There should be $hospitalized hospitalized checkpoints") + assertEquals(killed, counts.killed, "There should be $killed killed checkpoints") } internal fun CordaRPCOps.assertNumberOfCheckpointsAllZero() = assertNumberOfCheckpoints() @@ -189,6 +191,7 @@ abstract class StateMachineErrorHandlingTest { class ThrowAnErrorFlow : FlowLogic() { @Suspendable override fun call(): String { + sleep(1.seconds) throwException() return "cant get here" } @@ -219,7 +222,8 @@ abstract class StateMachineErrorHandlingTest { runnable = getNumberOfCheckpointsWithStatus(Checkpoint.FlowStatus.RUNNABLE), failed = getNumberOfCheckpointsWithStatus(Checkpoint.FlowStatus.FAILED), completed = getNumberOfCheckpointsWithStatus(Checkpoint.FlowStatus.COMPLETED), - hospitalized = getNumberOfCheckpointsWithStatus(Checkpoint.FlowStatus.HOSPITALIZED) + hospitalized = getNumberOfCheckpointsWithStatus(Checkpoint.FlowStatus.HOSPITALIZED), + killed = getNumberOfCheckpointsWithStatus(Checkpoint.FlowStatus.KILLED) ) private fun getNumberOfCheckpointsWithStatus(status: Checkpoint.FlowStatus): Int { @@ -243,7 +247,8 @@ abstract class StateMachineErrorHandlingTest { val runnable: Int = 0, val failed: Int = 0, val completed: Int = 0, - val hospitalized: Int = 0 + val hospitalized: Int = 0, + val killed: Int = 0 ) // Internal use for testing only!! diff --git a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineFlowInitErrorHandlingTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineFlowInitErrorHandlingTest.kt index 93f92aa81d..24f06c9600 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineFlowInitErrorHandlingTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineFlowInitErrorHandlingTest.kt @@ -1056,4 +1056,158 @@ class StateMachineFlowInitErrorHandlingTest : StateMachineErrorHandlingTest() { charlie.rpc.assertHospitalCounts(discharged = 3) } } + + /** + * Throws an exception when calling [FlowStateMachineImpl.recordDuration] to cause an unexpected error during flow initialisation. + * + * The hospital has the flow's medical history updated with the new failure added to it. As the failure occurred before the original + * checkpoint was persisted, there is no checkpoint to update in the database. + */ + @Test(timeout = 300_000) + fun `unexpected error during flow initialisation that gets caught by default exception handler puts flow into in-memory overnight observation`() { + startDriver { + val (charlie, alice, port) = createNodeAndBytemanNode(CHARLIE_NAME, ALICE_NAME) + val rules = """ + RULE Throw exception + CLASS ${FlowStateMachineImpl::class.java.name} + METHOD openThreadLocalWormhole + AT ENTRY + IF readCounter("counter") < 1 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("die dammit die") + ENDRULE + """.trimIndent() + + submitBytemanRules(rules, port) + + executor.execute { + alice.rpc.startFlow( + ::SendAMessageFlow, + charlie.nodeInfo.singleIdentity() + ) + } + + Thread.sleep(10.seconds.toMillis()) + + val (discharge, observation) = alice.rpc.startFlow(::GetHospitalCountersFlow).returnValue.get() + assertEquals(0, discharge) + assertEquals(1, observation) + assertEquals(1, alice.rpc.stateMachinesSnapshot().size) + // The flow failed during flow initialisation before committing the original checkpoint + // therefore there is no checkpoint to update the status of + alice.rpc.assertNumberOfCheckpoints(hospitalized = 0) + } + } + + /** + * Throws an exception when calling [FlowStateMachineImpl.logFlowError] to cause an unexpected error after the flow has properly + * initialised. + * + * The hospital has the flow's medical history updated with the new failure added to it. The status of the checkpoint is also set to + * [Checkpoint.FlowStatus.HOSPITALIZED] to reflect this information in the database. + */ + @Test(timeout = 300_000) + fun `unexpected error after flow initialisation that gets caught by default exception handler puts flow into overnight observation and reflected in database`() { + startDriver { + val (alice, port) = createBytemanNode(ALICE_NAME) + val rules = """ + RULE Throw exception + CLASS ${FlowStateMachineImpl::class.java.name} + METHOD logFlowError + AT ENTRY + IF readCounter("counter") < 1 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("die dammit die") + ENDRULE + """.trimIndent() + + submitBytemanRules(rules, port) + + assertFailsWith { + alice.rpc.startFlow(::ThrowAnErrorFlow).returnValue.getOrThrow(30.seconds) + } + + val (discharge, observation) = alice.rpc.startFlow(::GetHospitalCountersFlow).returnValue.get() + assertEquals(0, discharge) + assertEquals(1, observation) + assertEquals(1, alice.rpc.stateMachinesSnapshot().size) + alice.rpc.assertNumberOfCheckpoints(hospitalized = 1) + } + } + + /** + * Throws an exception when calling [FlowStateMachineImpl.logFlowError] to cause an unexpected error after the flow has properly + * initialised. When updating the status of the flow to [Checkpoint.FlowStatus.HOSPITALIZED] an error occurs. + * + * The update is rescheduled and tried again. This is done separate from the fiber. + */ + @Test(timeout = 300_000) + fun `unexpected error after flow initialisation that gets caught by default exception handler retries the status update if it fails`() { + startDriver { + val (alice, port) = createBytemanNode(ALICE_NAME) + val rules = """ + RULE Throw exception + CLASS ${FlowStateMachineImpl::class.java.name} + METHOD logFlowError + AT ENTRY + IF readCounter("counter") < 1 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("die dammit die") + ENDRULE + + RULE Throw exception when updating status + INTERFACE ${CheckpointStorage::class.java.name} + METHOD updateStatus + AT ENTRY + IF readCounter("counter") < 2 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("should be a sql exception") + ENDRULE + """.trimIndent() + + submitBytemanRules(rules, port) + + assertFailsWith { + alice.rpc.startFlow(::ThrowAnErrorFlow).returnValue.getOrThrow(50.seconds) + } + + val (discharge, observation) = alice.rpc.startFlow(::GetHospitalCountersFlow).returnValue.get() + assertEquals(0, discharge) + assertEquals(1, observation) + assertEquals(1, alice.rpc.stateMachinesSnapshot().size) + alice.rpc.assertNumberOfCheckpoints(hospitalized = 1) + } + } + + /** + * Throws an exception when calling [FlowStateMachineImpl.recordDuration] to cause an unexpected error after a flow has returned its + * result to the client. + * + * As the flow has already returned its result to the client, then the status of the flow has already been updated correctly and now the + * flow has experienced an unexpected error. There is no need to change the status as the flow has already finished. + */ + @Test(timeout = 300_000) + fun `unexpected error after flow has returned result to client that gets caught by default exception handler does nothing except log`() { + startDriver { + val (charlie, alice, port) = createNodeAndBytemanNode(CHARLIE_NAME, ALICE_NAME) + val rules = """ + RULE Throw exception + CLASS ${FlowStateMachineImpl::class.java.name} + METHOD recordDuration + AT ENTRY + IF readCounter("counter") < 1 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("die dammit die") + ENDRULE + """.trimIndent() + + submitBytemanRules(rules, port) + + alice.rpc.startFlow( + ::SendAMessageFlow, + charlie.nodeInfo.singleIdentity() + ).returnValue.getOrThrow(30.seconds) + + val (discharge, observation) = alice.rpc.startFlow(::GetHospitalCountersFlow).returnValue.get() + assertEquals(0, discharge) + assertEquals(0, observation) + assertEquals(0, alice.rpc.stateMachinesSnapshot().size) + alice.rpc.assertNumberOfCheckpoints() + } + } } \ No newline at end of file diff --git a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineGeneralErrorHandlingTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineGeneralErrorHandlingTest.kt index 894a66692f..282c3d9bb4 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineGeneralErrorHandlingTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineGeneralErrorHandlingTest.kt @@ -697,6 +697,58 @@ class StateMachineGeneralErrorHandlingTest : StateMachineErrorHandlingTest() { } } + /** + * Throws an exception when calling [FlowStateMachineImpl.logFlowError] to cause an unexpected error after the flow has properly + * initialised, placing the flow into a dead state. + * + * On shutdown this flow will still terminate correctly and not prevent the node from shutting down. + */ + @Suppress("TooGenericExceptionCaught") + @Test(timeout = 300_000) + fun `a dead flow can be shutdown`() { + startDriver { + val (alice, port) = createBytemanNode(ALICE_NAME) + val rules = """ + RULE Throw exception + CLASS ${FlowStateMachineImpl::class.java.name} + METHOD logFlowError + AT ENTRY + IF readCounter("counter") < 1 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("die dammit die") + ENDRULE + + RULE Log that state machine has ended + CLASS $stateMachineManagerClassName + METHOD stop + AT EXIT + IF true + DO traceln("State machine shutdown") + ENDRULE + """.trimIndent() + + submitBytemanRules(rules, port) + + assertFailsWith { + alice.rpc.startFlow(::ThrowAnErrorFlow).returnValue.getOrThrow(50.seconds) + } + + val (discharge, observation) = alice.rpc.startFlow(::GetHospitalCountersFlow).returnValue.get() + assertEquals(0, discharge) + assertEquals(1, observation) + assertEquals(1, alice.rpc.stateMachinesSnapshot().size) + alice.rpc.assertNumberOfCheckpoints(hospitalized = 1) + + try { + // This actually shuts down the node + alice.rpc.shutdown() + } catch(e: Exception) { + // Exception gets thrown due to shutdown + } + Thread.sleep(30.seconds.toMillis()) + alice.assertBytemanOutput("State machine shutdown", 1) + } + } + @StartableByRPC class SleepCatchAndRethrowFlow : FlowLogic() { @Suspendable diff --git a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineKillFlowErrorHandlingTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineKillFlowErrorHandlingTest.kt index f39005c476..2a4814fe00 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineKillFlowErrorHandlingTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StateMachineKillFlowErrorHandlingTest.kt @@ -5,6 +5,7 @@ import net.corda.core.flows.FlowLogic import net.corda.core.flows.KilledFlowException import net.corda.core.flows.StartableByRPC import net.corda.core.messaging.startFlow +import net.corda.core.messaging.startFlowWithClientId import net.corda.core.messaging.startTrackedFlow import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.getOrThrow @@ -140,6 +141,168 @@ class StateMachineKillFlowErrorHandlingTest : StateMachineErrorHandlingTest() { } } + /** + * Throws an exception when calling [FlowStateMachineImpl.logFlowError] to cause an unexpected error after the flow has properly + * initialised, placing the flow into a dead state. + * + * The flow is then manually killed which triggers the flow to go through the normal kill flow process. + */ + @Test(timeout = 300_000) + fun `a dead flow can be killed`() { + startDriver { + val (alice, port) = createBytemanNode(ALICE_NAME) + val rules = """ + RULE Throw exception + CLASS ${FlowStateMachineImpl::class.java.name} + METHOD logFlowError + AT ENTRY + IF readCounter("counter") < 1 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("die dammit die") + ENDRULE + """.trimIndent() + + submitBytemanRules(rules, port) + + val handle = alice.rpc.startFlow(::ThrowAnErrorFlow) + val id = handle.id + + assertFailsWith { + handle.returnValue.getOrThrow(20.seconds) + } + + val (discharge, observation) = alice.rpc.startFlow(::GetHospitalCountersFlow).returnValue.get() + assertEquals(0, discharge) + assertEquals(1, observation) + assertEquals(1, alice.rpc.stateMachinesSnapshot().size) + alice.rpc.assertNumberOfCheckpoints(hospitalized = 1) + + val killed = alice.rpc.killFlow(id) + + assertTrue(killed) + + Thread.sleep(20.seconds.toMillis()) + + assertEquals(0, alice.rpc.stateMachinesSnapshot().size) + alice.rpc.assertNumberOfCheckpointsAllZero() + } + } + + /** + * Throws an exception when calling [FlowStateMachineImpl.logFlowError] to cause an unexpected error after the flow has properly + * initialised, placing the flow into a dead state. + * + * The flow is then manually killed which triggers the flow to go through the normal kill flow process. + * + * Since the flow was started with a client id, record of the [KilledFlowException] should exists in the database. + */ + @Test(timeout = 300_000) + fun `a dead flow that was started with a client id can be killed`() { + startDriver { + val (alice, port) = createBytemanNode(ALICE_NAME) + val rules = """ + RULE Throw exception + CLASS ${FlowStateMachineImpl::class.java.name} + METHOD logFlowError + AT ENTRY + IF readCounter("counter") < 1 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("die dammit die") + ENDRULE + """.trimIndent() + + submitBytemanRules(rules, port) + + val handle = alice.rpc.startFlowWithClientId("my id", ::ThrowAnErrorFlow) + val id = handle.id + + assertFailsWith { + handle.returnValue.getOrThrow(20.seconds) + } + + val (discharge, observation) = alice.rpc.startFlow(::GetHospitalCountersFlow).returnValue.get() + assertEquals(0, discharge) + assertEquals(1, observation) + assertEquals(1, alice.rpc.stateMachinesSnapshot().size) + alice.rpc.assertNumberOfCheckpoints(hospitalized = 1) + + val killed = alice.rpc.killFlow(id) + + assertTrue(killed) + + Thread.sleep(20.seconds.toMillis()) + + assertEquals(0, alice.rpc.stateMachinesSnapshot().size) + alice.rpc.assertNumberOfCheckpoints(killed = 1) + // Exception thrown by flow + assertFailsWith { + alice.rpc.reattachFlowWithClientId("my id")?.returnValue?.getOrThrow(20.seconds) + } + } + } + + /** + * Throws an exception when calling [FlowStateMachineImpl.logFlowError] to cause an unexpected error after the flow has properly + * initialised, placing the flow into a dead state. + * + * The flow is then manually killed which triggers the flow to go through the normal kill flow process. + */ + @Test(timeout = 300_000) + fun `a dead flow that is killed and fails again will forcibly kill itself`() { + startDriver { + val (alice, port) = createBytemanNode(ALICE_NAME) + val rules = """ + RULE Throw exception + CLASS ${FlowStateMachineImpl::class.java.name} + METHOD logFlowError + AT ENTRY + IF readCounter("counter") == 0 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("die dammit die") + ENDRULE + + RULE Throw exception 2 + CLASS ${TransitionExecutorImpl::class.java.name} + METHOD executeTransition + AT ENTRY + IF readCounter("counter") == 1 + DO incrementCounter("counter"); traceln("Throwing exception"); throw new java.lang.RuntimeException("die again") + ENDRULE + + RULE Log that removeFlow is called + CLASS $stateMachineManagerClassName + METHOD removeFlow + AT EXIT + IF true + DO traceln("removeFlow called") + ENDRULE + + RULE Log that killFlowForcibly is called + CLASS $stateMachineManagerClassName + METHOD killFlowForcibly + AT EXIT + IF true + DO traceln("killFlowForcibly called") + ENDRULE + """.trimIndent() + + submitBytemanRules(rules, port) + + val handle = alice.rpc.startFlow(::ThrowAnErrorFlow) + val id = handle.id + + assertFailsWith { + handle.returnValue.getOrThrow(20.seconds) + } + + assertTrue(alice.rpc.killFlow(id)) + + Thread.sleep(20.seconds.toMillis()) + + alice.assertBytemanOutput("removeFlow called", 1) + alice.assertBytemanOutput("killFlowForcibly called", 1) + assertEquals(0, alice.rpc.stateMachinesSnapshot().size) + alice.rpc.assertNumberOfCheckpointsAllZero() + } + } + @StartableByRPC class SleepFlow : FlowLogic() { diff --git a/node/src/integration-test/java/net/corda/serialization/reproduction/GenericReturnFailureReproductionIntegrationTest.java b/node/src/integration-test/java/net/corda/serialization/reproduction/GenericReturnFailureReproductionIntegrationTest.java index f76f55bc49..5c7f59a2bc 100644 --- a/node/src/integration-test/java/net/corda/serialization/reproduction/GenericReturnFailureReproductionIntegrationTest.java +++ b/node/src/integration-test/java/net/corda/serialization/reproduction/GenericReturnFailureReproductionIntegrationTest.java @@ -1,6 +1,5 @@ package net.corda.serialization.reproduction; -import com.google.common.io.LineProcessor; import net.corda.client.rpc.CordaRPCClient; import net.corda.core.concurrent.CordaFuture; import net.corda.node.services.Permissions; diff --git a/node/src/integration-test/kotlin/net/corda/contracts/multiple/evil/EvilContract.kt b/node/src/integration-test/kotlin/net/corda/contracts/multiple/evil/EvilContract.kt new file mode 100644 index 0000000000..a5c547bd2d --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/multiple/evil/EvilContract.kt @@ -0,0 +1,46 @@ +package net.corda.contracts.multiple.evil + +import net.corda.contracts.multiple.vulnerable.MutableDataObject +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract.VulnerablePurchase +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract.VulnerableState +import net.corda.core.contracts.CommandData +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.identity.AbstractParty +import net.corda.core.transactions.LedgerTransaction + +@Suppress("unused") +class EvilContract : Contract { + override fun verify(tx: LedgerTransaction) { + val vulnerableStates = tx.outputsOfType(VulnerableState::class.java) + val vulnerablePurchases = tx.commandsOfType(VulnerablePurchase::class.java) + + val addExtras = tx.commandsOfType(AddExtra::class.java) + addExtras.forEach { extra -> + val extraValue = extra.value.payment.value + + // And our extra value to every vulnerable output state. + vulnerableStates.forEach { state -> + state.data?.also { data -> + data.value += extraValue + } + } + + // Add our extra value to every vulnerable command too. + vulnerablePurchases.forEach { purchase -> + purchase.value.payment.value += extraValue + } + } + } + + class EvilState(val owner: AbstractParty) : ContractState { + override val participants: List = listOf(owner) + + @Override + override fun toString(): String { + return "Money For Nothing!" + } + } + + class AddExtra(val payment: MutableDataObject) : CommandData +} diff --git a/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/MutableDataObject.kt b/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/MutableDataObject.kt new file mode 100644 index 0000000000..50ffebfa17 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/MutableDataObject.kt @@ -0,0 +1,14 @@ +package net.corda.contracts.multiple.vulnerable + +import net.corda.core.serialization.CordaSerializable + +@CordaSerializable +data class MutableDataObject(var value: Long) : Comparable { + override fun toString(): String { + return "$value data points" + } + + override fun compareTo(other: MutableDataObject): Int { + return value.compareTo(other.value) + } +} diff --git a/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/VulnerablePaymentContract.kt b/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/VulnerablePaymentContract.kt new file mode 100644 index 0000000000..6e90bb5725 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/VulnerablePaymentContract.kt @@ -0,0 +1,43 @@ +package net.corda.contracts.multiple.vulnerable + +import net.corda.core.contracts.CommandData +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.requireThat +import net.corda.core.identity.AbstractParty +import net.corda.core.transactions.LedgerTransaction + +@Suppress("unused") +class VulnerablePaymentContract : Contract { + companion object { + const val BASE_PAYMENT = 2000L + } + + override fun verify(tx: LedgerTransaction) { + val states = tx.outputsOfType() + requireThat { + "Requires at least one data state" using states.isNotEmpty() + } + val purchases = tx.commandsOfType() + requireThat { + "Requires at least one purchase" using purchases.isNotEmpty() + } + for (purchase in purchases) { + val payment = purchase.value.payment + requireThat { + "Purchase payment of $payment should be at least $BASE_PAYMENT" using (payment.value >= BASE_PAYMENT) + } + } + } + + class VulnerableState(val owner: AbstractParty, val data: MutableDataObject?) : ContractState { + override val participants: List = listOf(owner) + + @Override + override fun toString(): String { + return data.toString() + } + } + + class VulnerablePurchase(val payment: MutableDataObject) : CommandData +} diff --git a/node/src/integration-test/kotlin/net/corda/contracts/mutator/MutatorContract.kt b/node/src/integration-test/kotlin/net/corda/contracts/mutator/MutatorContract.kt new file mode 100644 index 0000000000..239525c576 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/mutator/MutatorContract.kt @@ -0,0 +1,115 @@ +package net.corda.contracts.mutator + +import net.corda.core.contracts.CommandData +import net.corda.core.contracts.CommandWithParties +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.requireSingleCommand +import net.corda.core.contracts.requireThat +import net.corda.core.identity.AbstractParty +import net.corda.core.internal.Verifier +import net.corda.core.serialization.SerializationContext +import net.corda.core.transactions.LedgerTransaction + +class MutatorContract : Contract { + override fun verify(tx: LedgerTransaction) { + tx.transform { componentGroups, serializedInputs, serializedReferences -> + requireThat { + "component groups are protected" using componentGroups.isImmutableAnd(isEmpty = true) + "serialized inputs are protected" using serializedInputs.isImmutableAnd(isEmpty = true) + "serialized references are protected" using serializedReferences.isImmutableAnd(isEmpty = true) + } + } + + requireThat { + "Cannot add/remove inputs" using tx.inputs.isImmutable() + "Cannot add/remove outputs" using failToMutateOutputs(tx) + "Cannot add/remove commands" using failToMutateCommands(tx) + "Cannot add/remove references" using tx.references.isImmutable() + "Cannot add/remove attachments" using tx.attachments.isImmutableAnd(isEmpty = false) + "Cannot specialise transaction" using failToSpecialise(tx) + } + + requireNotNull(tx.networkParameters).also { networkParameters -> + requireThat { + "Cannot add/remove notaries" using networkParameters.notaries.isImmutableAnd(isEmpty = false) + "Cannot add/remove package ownerships" using networkParameters.packageOwnership.isImmutable() + "Cannot add/remove whitelisted contracts" using networkParameters.whitelistedContractImplementations.isImmutable() + } + } + } + + private fun List<*>.isImmutableAnd(isEmpty: Boolean): Boolean { + return isImmutable() && (this.isEmpty() == isEmpty) + } + + private fun List<*>.isImmutable(): Boolean { + return try { + @Suppress("platform_class_mapped_to_kotlin") + (this as java.util.List<*>).clear() + false + } catch (e: UnsupportedOperationException) { + true + } + } + + private fun failToMutateOutputs(tx: LedgerTransaction): Boolean { + val output = tx.outputsOfType().single() + val mutableOutputs = tx.outputs as MutableList> + return try { + mutableOutputs += TransactionState(MutateState(output.owner), MutatorContract::class.java.name, tx.notary!!, 0) + false + } catch (e: UnsupportedOperationException) { + true + } + } + + private fun failToMutateCommands(tx: LedgerTransaction): Boolean { + val mutate = tx.commands.requireSingleCommand() + val mutableCommands = tx.commands as MutableList> + return try { + mutableCommands += CommandWithParties(mutate.signers, emptyList(), MutateCommand()) + false + } catch (e: UnsupportedOperationException) { + true + } + } + + private fun Map<*, *>.isImmutable(): Boolean { + return try { + @Suppress("platform_class_mapped_to_kotlin") + (this as java.util.Map<*, *>).clear() + false + } catch (e: UnsupportedOperationException) { + true + } + } + + private fun failToSpecialise(ltx: LedgerTransaction): Boolean { + return try { + ltx.specialise(::ExtraSpecialise) + false + } catch (e: IllegalStateException) { + true + } + } + + private class ExtraSpecialise(private val ltx: LedgerTransaction, private val ctx: SerializationContext) : Verifier { + override fun verify() { + ltx.inputStates.forEach(::println) + println(ctx.deserializationClassLoader) + } + } + + class MutateState(val owner: AbstractParty) : ContractState { + override val participants: List = listOf(owner) + + @Override + override fun toString(): String { + return "All change!" + } + } + + class MutateCommand : CommandData +} diff --git a/node/src/integration-test/kotlin/net/corda/flows/multiple/evil/EvilFlow.kt b/node/src/integration-test/kotlin/net/corda/flows/multiple/evil/EvilFlow.kt new file mode 100644 index 0000000000..0c627c7e46 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/flows/multiple/evil/EvilFlow.kt @@ -0,0 +1,39 @@ +package net.corda.flows.multiple.evil + +import co.paralleluniverse.fibers.Suspendable +import net.corda.contracts.multiple.evil.EvilContract.EvilState +import net.corda.contracts.multiple.evil.EvilContract.AddExtra +import net.corda.contracts.multiple.vulnerable.MutableDataObject +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract.VulnerablePurchase +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract.VulnerableState +import net.corda.core.contracts.Command +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StartableByRPC +import net.corda.core.transactions.TransactionBuilder + +@StartableByRPC +class EvilFlow( + private val purchase: MutableDataObject +) : FlowLogic() { + private companion object { + private val NOTHING = MutableDataObject(0) + } + + @Suspendable + override fun call(): SecureHash { + val notary = serviceHub.networkMapCache.notaryIdentities[0] + val stx = serviceHub.signInitialTransaction( + TransactionBuilder(notary) + // Add Evil objects first, so that Corda will verify EvilContract first. + .addCommand(Command(AddExtra(purchase), ourIdentity.owningKey)) + .addOutputState(EvilState(ourIdentity)) + + // Now add the VulnerablePaymentContract objects with NO PAYMENT! + .addCommand(Command(VulnerablePurchase(NOTHING), ourIdentity.owningKey)) + .addOutputState(VulnerableState(ourIdentity, NOTHING)) + ) + stx.verify(serviceHub, checkSufficientSignatures = false) + return stx.id + } +} diff --git a/node/src/integration-test/kotlin/net/corda/flows/mutator/MutatorFlow.kt b/node/src/integration-test/kotlin/net/corda/flows/mutator/MutatorFlow.kt new file mode 100644 index 0000000000..c131dfa433 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/flows/mutator/MutatorFlow.kt @@ -0,0 +1,26 @@ +package net.corda.flows.mutator + +import co.paralleluniverse.fibers.Suspendable +import net.corda.contracts.mutator.MutatorContract.MutateCommand +import net.corda.contracts.mutator.MutatorContract.MutateState +import net.corda.core.contracts.Command +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StartableByRPC +import net.corda.core.transactions.TransactionBuilder + +@StartableByRPC +class MutatorFlow : FlowLogic() { + @Suspendable + override fun call(): SecureHash { + val notary = serviceHub.networkMapCache.notaryIdentities[0] + val stx = serviceHub.signInitialTransaction( + TransactionBuilder(notary) + // Create some content for the LedgerTransaction. + .addOutputState(MutateState(ourIdentity)) + .addCommand(Command(MutateCommand(), ourIdentity.owningKey)) + ) + stx.verify(serviceHub, checkSufficientSignatures = false) + return stx.id + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/CashIssueAndPaymentTest.kt b/node/src/integration-test/kotlin/net/corda/node/CashIssueAndPaymentTest.kt new file mode 100644 index 0000000000..2da38e1509 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/CashIssueAndPaymentTest.kt @@ -0,0 +1,68 @@ +package net.corda.node + +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.loggerFor +import net.corda.finance.DOLLARS +import net.corda.finance.flows.CashIssueAndPaymentFlow +import net.corda.node.services.config.NodeConfiguration +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.core.singleIdentity +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.internal.findCordapp +import org.junit.Test +import org.junit.jupiter.api.assertDoesNotThrow + +/** + * Execute a flow with sub-flows, including the finality flow. + * This operation should checkpoint, and have its checkpoint restored. + */ +@Suppress("FunctionName") +class CashIssueAndPaymentTest { + companion object { + private val logger = loggerFor() + + private val configOverrides = mapOf(NodeConfiguration::reloadCheckpointAfterSuspend.name to true) + private val CASH_AMOUNT = 500.DOLLARS + + fun parametersFor(runInProcess: Boolean = false): DriverParameters { + return DriverParameters( + systemProperties = mapOf("co.paralleluniverse.fibers.verifyInstrumentation" to "false"), + portAllocation = incrementalPortAllocation(), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + notaryCustomOverrides = configOverrides, + cordappsForAllNodes = listOf( + findCordapp("net.corda.finance.contracts"), + findCordapp("net.corda.finance.workflows") + ) + ) + } + } + + @Test(timeout = 300_000) + fun `test can issue cash`() { + driver(parametersFor()) { + val alice = startNode(providedName = ALICE_NAME, customOverrides = configOverrides).getOrThrow() + val aliceParty = alice.nodeInfo.singleIdentity() + val notaryParty = notaryHandles.single().identity + val result = assertDoesNotThrow { + alice.rpc.startFlow(::CashIssueAndPaymentFlow, + CASH_AMOUNT, + OpaqueBytes.of(0x01), + aliceParty, + false, + notaryParty + ).use { flowHandle -> + flowHandle.returnValue.getOrThrow() + } + } + logger.info("TXN={}, recipient={}", result.stx, result.recipient) + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/ContractCannotMutateTransactionTest.kt b/node/src/integration-test/kotlin/net/corda/node/ContractCannotMutateTransactionTest.kt new file mode 100644 index 0000000000..62a92dc14f --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/ContractCannotMutateTransactionTest.kt @@ -0,0 +1,48 @@ +package net.corda.node + +import net.corda.client.rpc.CordaRPCClient +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.loggerFor +import net.corda.flows.mutator.MutatorFlow +import net.corda.node.services.Permissions +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.User +import net.corda.testing.node.internal.cordappWithPackages +import org.junit.Test + +class ContractCannotMutateTransactionTest { + companion object { + private val logger = loggerFor() + private val user = User("u", "p", setOf(Permissions.all())) + private val mutatorFlowCorDapp = cordappWithPackages("net.corda.flows.mutator").signed() + private val mutatorContractCorDapp = cordappWithPackages("net.corda.contracts.mutator").signed() + + fun driverParameters(runInProcess: Boolean = false): DriverParameters { + return DriverParameters( + portAllocation = incrementalPortAllocation(), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + cordappsForAllNodes = listOf(mutatorContractCorDapp, mutatorFlowCorDapp) + ) + } + } + + @Test(timeout = 300_000) + fun testContractCannotModifyTransaction() { + driver(driverParameters()) { + val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val txID = CordaRPCClient(hostAndPort = alice.rpcAddress) + .start(user.username, user.password) + .use { client -> + client.proxy.startFlow(::MutatorFlow).returnValue.getOrThrow() + } + logger.info("TX-ID: {}", txID) + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/ContractWithCordappFixupTest.kt b/node/src/integration-test/kotlin/net/corda/node/ContractWithCordappFixupTest.kt index 50e5b1b1bd..77f267aed7 100644 --- a/node/src/integration-test/kotlin/net/corda/node/ContractWithCordappFixupTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/ContractWithCordappFixupTest.kt @@ -35,11 +35,11 @@ class ContractWithCordappFixupTest { val dependentContractCorDapp = cordappWithPackages("net.corda.contracts.fixup.dependent").signed() val standaloneContractCorDapp = cordappWithPackages("net.corda.contracts.fixup.standalone").signed() - fun driverParameters(cordapps: List): DriverParameters { + fun driverParameters(cordapps: List, runInProcess: Boolean = false): DriverParameters { return DriverParameters( portAllocation = incrementalPortAllocation(), - startNodesInProcess = false, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = cordapps, systemProperties = mapOf("net.corda.transactionbuilder.missingclass.disabled" to true.toString()) ) diff --git a/node/src/integration-test/kotlin/net/corda/node/ContractWithCustomSerializerTest.kt b/node/src/integration-test/kotlin/net/corda/node/ContractWithCustomSerializerTest.kt index ffb2d297b1..442214e13e 100644 --- a/node/src/integration-test/kotlin/net/corda/node/ContractWithCustomSerializerTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/ContractWithCustomSerializerTest.kt @@ -46,7 +46,7 @@ class ContractWithCustomSerializerTest(private val runInProcess: Boolean) { driver(DriverParameters( portAllocation = incrementalPortAllocation(), startNodesInProcess = runInProcess, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = listOf( cordappWithPackages("net.corda.flows.serialization.custom").signed(), cordappWithPackages("net.corda.contracts.serialization.custom").signed() diff --git a/node/src/integration-test/kotlin/net/corda/node/ContractWithGenericTypeTest.kt b/node/src/integration-test/kotlin/net/corda/node/ContractWithGenericTypeTest.kt index d23c137dda..4dfb1f17e6 100644 --- a/node/src/integration-test/kotlin/net/corda/node/ContractWithGenericTypeTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/ContractWithGenericTypeTest.kt @@ -31,11 +31,11 @@ class ContractWithGenericTypeTest { @JvmField val user = User("u", "p", setOf(Permissions.all())) - fun parameters(): DriverParameters { + fun parameters(runInProcess: Boolean = false): DriverParameters { return DriverParameters( portAllocation = incrementalPortAllocation(), - startNodesInProcess = false, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = listOf( cordappWithPackages("net.corda.flows.serialization.generics").signed(), cordappWithPackages("net.corda.contracts.serialization.generics").signed() diff --git a/node/src/integration-test/kotlin/net/corda/node/ContractWithMissingCustomSerializerTest.kt b/node/src/integration-test/kotlin/net/corda/node/ContractWithMissingCustomSerializerTest.kt index 2110ff3cfe..78ee896844 100644 --- a/node/src/integration-test/kotlin/net/corda/node/ContractWithMissingCustomSerializerTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/ContractWithMissingCustomSerializerTest.kt @@ -45,7 +45,7 @@ class ContractWithMissingCustomSerializerTest(private val runInProcess: Boolean) return DriverParameters( portAllocation = incrementalPortAllocation(), startNodesInProcess = runInProcess, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = cordapps ) } diff --git a/node/src/integration-test/kotlin/net/corda/node/ContractWithSerializationWhitelistTest.kt b/node/src/integration-test/kotlin/net/corda/node/ContractWithSerializationWhitelistTest.kt index 2a9ae80195..9c6d809d77 100644 --- a/node/src/integration-test/kotlin/net/corda/node/ContractWithSerializationWhitelistTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/ContractWithSerializationWhitelistTest.kt @@ -43,7 +43,7 @@ class ContractWithSerializationWhitelistTest(private val runInProcess: Boolean) return DriverParameters( portAllocation = incrementalPortAllocation(), startNodesInProcess = runInProcess, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = listOf(contractCordapp, workflowCordapp) ) } diff --git a/node/src/integration-test/kotlin/net/corda/node/CustomSerializationSchemeDriverTest.kt b/node/src/integration-test/kotlin/net/corda/node/CustomSerializationSchemeDriverTest.kt new file mode 100644 index 0000000000..26a2039406 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/CustomSerializationSchemeDriverTest.kt @@ -0,0 +1,338 @@ +package net.corda.node + +import co.paralleluniverse.fibers.Suspendable +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import de.javakaffee.kryoserializers.ArraysAsListSerializer +import net.corda.core.contracts.AlwaysAcceptAttachmentConstraint +import net.corda.core.contracts.BelongsToContract +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.TypeOnlyCommandData +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.SignableData +import net.corda.core.crypto.SignatureMetadata +import net.corda.core.flows.CollectSignaturesFlow +import net.corda.core.flows.FinalityFlow +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.ReceiveFinalityFlow +import net.corda.core.flows.SignTransactionFlow +import net.corda.core.flows.StartableByRPC +import net.corda.core.identity.AbstractParty +import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.transpose +import net.corda.core.internal.copyBytes +import net.corda.core.messaging.startFlow +import net.corda.core.node.ServiceHub +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getSchemeIdIfCustomSerializationMagic +import net.corda.core.transactions.LedgerTransaction +import net.corda.core.transactions.SignedTransaction +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.serialization.internal.CordaSerializationMagic +import net.corda.serialization.internal.SerializationFactoryImpl +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.core.TestIdentity +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.NodeParameters +import net.corda.testing.driver.driver +import net.corda.testing.node.internal.enclosedCordapp +import org.junit.Test +import org.objenesis.instantiator.ObjectInstantiator +import org.objenesis.strategy.InstantiatorStrategy +import org.objenesis.strategy.StdInstantiatorStrategy +import java.io.ByteArrayOutputStream +import java.lang.reflect.Modifier +import java.security.PublicKey +import java.util.* +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class CustomSerializationSchemeDriverTest { + + companion object { + private fun createWireTx(serviceHub: ServiceHub, notary: Party, key: PublicKey, schemeId: Int): WireTransaction { + val outputState = TransactionState( + data = DummyContract.DummyState(), + contract = DummyContract::class.java.name, + notary = notary, + constraint = AlwaysAcceptAttachmentConstraint + ) + val builder = TransactionBuilder() + .addOutputState(outputState) + .addCommand(DummyCommandData, key) + return builder.toWireTransaction(serviceHub, schemeId) + } + } + + @Test(timeout = 300_000) + fun `flow can send wire transaction serialized with custom kryo serializer`() { + driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val (alice, bob) = listOf( + startNode(NodeParameters(providedName = ALICE_NAME)), + startNode(NodeParameters(providedName = BOB_NAME)) + ).transpose().getOrThrow() + + val flow = alice.rpc.startFlow(::SendFlow, bob.nodeInfo.legalIdentities.single()) + assertTrue { flow.returnValue.getOrThrow() } + } + } + + @Test(timeout = 300_000) + fun `flow can write a wire transaction serialized with custom kryo serializer to the ledger`() { + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val (alice, bob) = listOf( + startNode(NodeParameters(providedName = ALICE_NAME)), + startNode(NodeParameters(providedName = BOB_NAME)) + ).transpose().getOrThrow() + + val flow = alice.rpc.startFlow(::WriteTxToLedgerFlow, bob.nodeInfo.legalIdentities.single(), defaultNotaryIdentity) + val txId = flow.returnValue.getOrThrow() + val transaction = alice.rpc.startFlow(::GetTxFromDBFlow, txId).returnValue.getOrThrow() + + for(group in transaction!!.tx.componentGroups) { + for (item in group.components) { + val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes()) + assertEquals( KryoScheme.SCHEME_ID, getSchemeIdIfCustomSerializationMagic(magic)) + } + } + } + } + + @Test(timeout = 300_000) + fun `Component groups are lazily serialized by the CustomSerializationScheme`() { + driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val alice = startNode(NodeParameters(providedName = ALICE_NAME)).getOrThrow() + //We don't need a real notary as we don't verify the transaction in this test. + val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20) + assertTrue { alice.rpc.startFlow(::CheckComponentGroupsFlow, dummyNotary.party).returnValue.getOrThrow() } + } + } + + @Test(timeout = 300_000) + fun `Map in the serialization context can be used by lazily component group serialization`() { + driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val alice = startNode(NodeParameters(providedName = ALICE_NAME)).getOrThrow() + //We don't need a real notary as we don't verify the transaction in this test. + val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20) + assertTrue { alice.rpc.startFlow(::CheckComponentGroupsWithMapFlow, dummyNotary.party).returnValue.getOrThrow() } + } + } + + @StartableByRPC + @InitiatingFlow + class WriteTxToLedgerFlow(val counterparty: Party, val notary: Party) : FlowLogic() { + @Suspendable + override fun call(): SecureHash { + val wireTx = createWireTx(serviceHub, notary, counterparty.owningKey, KryoScheme.SCHEME_ID) + val partSignedTx = signWireTx(wireTx) + val session = initiateFlow(counterparty) + val fullySignedTx = subFlow(CollectSignaturesFlow(partSignedTx, setOf(session))) + subFlow(FinalityFlow(fullySignedTx, setOf(session))) + return fullySignedTx.id + } + + fun signWireTx(wireTx: WireTransaction) : SignedTransaction { + val signatureMetadata = SignatureMetadata( + serviceHub.myInfo.platformVersion, + Crypto.findSignatureScheme(serviceHub.myInfo.legalIdentitiesAndCerts.first().owningKey).schemeNumberID + ) + val signableData = SignableData(wireTx.id, signatureMetadata) + val sig = serviceHub.keyManagementService.sign(signableData, serviceHub.myInfo.legalIdentitiesAndCerts.first().owningKey) + return SignedTransaction(wireTx, listOf(sig)) + } + } + + @InitiatedBy(WriteTxToLedgerFlow::class) + class SignWireTxFlow(private val session: FlowSession): FlowLogic() { + @Suspendable + override fun call(): SignedTransaction { + val signTransactionFlow = object : SignTransactionFlow(session) { + override fun checkTransaction(stx: SignedTransaction) { + return + } + } + val txId = subFlow(signTransactionFlow).id + return subFlow(ReceiveFinalityFlow(session, expectedTxId = txId)) + } + } + + @StartableByRPC + class GetTxFromDBFlow(private val txId: SecureHash): FlowLogic() { + override fun call(): SignedTransaction? { + return serviceHub.validatedTransactions.getTransaction(txId) + } + } + + @StartableByRPC + @InitiatingFlow + class CheckComponentGroupsFlow(val notary: Party) : FlowLogic() { + @Suspendable + override fun call(): Boolean { + val wtx = createWireTx(serviceHub, notary, notary.owningKey, KryoScheme.SCHEME_ID) + var success = true + for (group in wtx.componentGroups) { + //Component groups are lazily serialized as we iterate through. + for (item in group.components) { + val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes()) + success = success && (getSchemeIdIfCustomSerializationMagic(magic) == KryoScheme.SCHEME_ID) + } + } + return success + } + } + + @StartableByRPC + @InitiatingFlow + class CheckComponentGroupsWithMapFlow(val notary: Party) : FlowLogic() { + @Suspendable + override fun call(): Boolean { + val outputState = TransactionState( + data = DummyContract.DummyState(), + contract = DummyContract::class.java.name, + notary = notary, + constraint = AlwaysAcceptAttachmentConstraint + ) + val builder = TransactionBuilder() + .addOutputState(outputState) + .addCommand(DummyCommandData, notary.owningKey) + val mapToCheckWhenSerializing = mapOf(Pair(KryoSchemeWithMap.KEY, KryoSchemeWithMap.VALUE)) + val wtx = builder.toWireTransaction(serviceHub, KryoSchemeWithMap.SCHEME_ID, mapToCheckWhenSerializing) + var success = true + for (group in wtx.componentGroups) { + //Component groups are lazily serialized as we iterate through. + for (item in group.components) { + val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes()) + success = success && (getSchemeIdIfCustomSerializationMagic(magic) == KryoSchemeWithMap.SCHEME_ID) + } + } + return success + } + } + + @StartableByRPC + @InitiatingFlow + class SendFlow(val counterparty: Party) : FlowLogic() { + @Suspendable + override fun call(): Boolean { + val wtx = createWireTx(serviceHub, counterparty, counterparty.owningKey, KryoScheme.SCHEME_ID) + val session = initiateFlow(counterparty) + session.send(wtx) + return session.receive().unwrap {it} + } + } + + @StartableByRPC + class CreateWireTxFlow(val counterparty: Party) : FlowLogic() { + @Suspendable + override fun call(): WireTransaction { + return createWireTx(serviceHub, counterparty, counterparty.owningKey, KryoScheme.SCHEME_ID) + } + } + + @InitiatedBy(SendFlow::class) + class ReceiveFlow(private val session: FlowSession): FlowLogic() { + @Suspendable + override fun call() { + val message = session.receive().unwrap {it} + message.toLedgerTransaction(serviceHub) + session.send(true) + } + } + + class DummyContract: Contract { + @BelongsToContract(DummyContract::class) + class DummyState(override val participants: List = listOf()) : ContractState + override fun verify(tx: LedgerTransaction) { + return + } + } + + object DummyCommandData : TypeOnlyCommandData() + + open class KryoScheme : CustomSerializationScheme { + + companion object { + const val SCHEME_ID = 7 + } + + override fun getSchemeId(): Int { + return SCHEME_ID + } + + override fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T { + val kryo = Kryo() + customiseKryo(kryo, context.deserializationClassLoader) + + val obj = Input(bytes.open()).use { + kryo.readClassAndObject(it) + } + @Suppress("UNCHECKED_CAST") + return obj as T + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + val kryo = Kryo() + customiseKryo(kryo, context.deserializationClassLoader) + + val outputStream = ByteArrayOutputStream() + Output(outputStream).use { + kryo.writeClassAndObject(it, obj) + } + return ByteSequence.of(outputStream.toByteArray()) + } + + private fun customiseKryo(kryo: Kryo, classLoader: ClassLoader) { + kryo.instantiatorStrategy = CustomInstantiatorStrategy() + kryo.classLoader = classLoader + kryo.register(Arrays.asList("").javaClass, ArraysAsListSerializer()) + } + + //Stolen from DefaultKryoCustomizer.kt + private class CustomInstantiatorStrategy : InstantiatorStrategy { + private val fallbackStrategy = StdInstantiatorStrategy() + + // Use this to allow construction of objects using a JVM backdoor that skips invoking the constructors, if there + // is no no-arg constructor available. + private val defaultStrategy = Kryo.DefaultInstantiatorStrategy(fallbackStrategy) + + override fun newInstantiatorOf(type: Class): ObjectInstantiator { + // However this doesn't work for non-public classes in the java. namespace + val strat = if (type.name.startsWith("java.") && !Modifier.isPublic(type.modifiers)) fallbackStrategy else defaultStrategy + return strat.newInstantiatorOf(type) + } + } + } + + class KryoSchemeWithMap : KryoScheme() { + + companion object { + const val SCHEME_ID = 8 + const val KEY = "Key" + const val VALUE = "Value" + } + + override fun getSchemeId(): Int { + return SCHEME_ID + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + assertEquals(VALUE, context.properties[KEY]) + return super.serialize(obj, context) + } + + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/CustomSerializationSchemeMockNetworkTest.kt b/node/src/integration-test/kotlin/net/corda/node/CustomSerializationSchemeMockNetworkTest.kt new file mode 100644 index 0000000000..6da9350ec7 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/CustomSerializationSchemeMockNetworkTest.kt @@ -0,0 +1,62 @@ +package net.corda.node + +import net.corda.core.crypto.SecureHash +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize +import net.corda.node.CustomSerializationSchemeDriverTest.CreateWireTxFlow +import net.corda.node.CustomSerializationSchemeDriverTest.WriteTxToLedgerFlow +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.node.internal.CustomCordapp +import net.corda.testing.node.internal.InternalMockNetwork +import net.corda.testing.node.internal.InternalMockNodeParameters +import net.corda.testing.node.internal.enclosedCordapp +import net.corda.testing.node.internal.startFlow +import org.junit.After +import org.junit.Before +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +class CustomSerializationSchemeMockNetworkTest { + + private lateinit var mockNetwork : InternalMockNetwork + + val customSchemeCordapp: CustomCordapp = CustomSerializationSchemeDriverTest().enclosedCordapp() + + @Before + fun setup() { + mockNetwork = InternalMockNetwork(cordappsForAllNodes = listOf(customSchemeCordapp)) + } + + @After + fun shutdown() { + mockNetwork.stopNodes() + } + + @Test(timeout = 300_000) + fun `transactions network parameter hash is correct`() { + val alice = mockNetwork.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) + val bob = mockNetwork.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) + val flow = alice.services.startFlow (CreateWireTxFlow(bob.info.legalIdentities.single())) + mockNetwork.runNetwork() + val wireTx = flow.resultFuture.get() + /** The NetworkParmeters is the last component in the list of component groups. If we ever change this this + * in [net.corda.core.internal.createComponentGroups] this test will need to be updated.*/ + val serializedHash = SerializedBytes(wireTx.componentGroups.last().components.single().bytes) + assertEquals(alice.internals.networkParametersStorage.defaultHash, serializedHash.deserialize()) + } + + @Test(timeout = 300_000) + fun `transaction can be written to the ledger`() { + val alice = mockNetwork.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) + val bob = mockNetwork.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) + val flow = alice.services.startFlow (WriteTxToLedgerFlow(bob.info.legalIdentities.single(), + mockNetwork.notaryNodes.single().info.legalIdentities.single())) + mockNetwork.runNetwork() + val txId = flow.resultFuture.get() + val getTxFlow = bob.services.startFlow(CustomSerializationSchemeDriverTest.GetTxFromDBFlow(txId)) + mockNetwork.runNetwork() + assertNotNull(getTxFlow.resultFuture.get()) + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/EvilContractCannotModifyStatesTest.kt b/node/src/integration-test/kotlin/net/corda/node/EvilContractCannotModifyStatesTest.kt new file mode 100644 index 0000000000..c365552553 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/EvilContractCannotModifyStatesTest.kt @@ -0,0 +1,60 @@ +package net.corda.node + +import net.corda.client.rpc.CordaRPCClient +import net.corda.contracts.multiple.vulnerable.MutableDataObject +import net.corda.core.contracts.TransactionVerificationException.ContractRejection +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.flows.multiple.evil.EvilFlow +import net.corda.node.services.Permissions +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.User +import net.corda.testing.node.internal.cordappWithPackages +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test +import kotlin.test.assertFailsWith + +class EvilContractCannotModifyStatesTest { + companion object { + private val user = User("u", "p", setOf(Permissions.all())) + private val evilFlowCorDapp = cordappWithPackages("net.corda.flows.multiple.evil").signed() + private val evilContractCorDapp = cordappWithPackages("net.corda.contracts.multiple.evil").signed() + private val vulnerableContractCorDapp = cordappWithPackages("net.corda.contracts.multiple.vulnerable").signed() + + private val NOTHING = MutableDataObject(0) + + fun driverParameters(runInProcess: Boolean): DriverParameters { + return DriverParameters( + portAllocation = incrementalPortAllocation(), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + cordappsForAllNodes = listOf( + vulnerableContractCorDapp, + evilContractCorDapp, + evilFlowCorDapp + ) + ) + } + } + + @Test(timeout = 300_000) + fun testContractThatTriesToModifyStates() { + val evilData = MutableDataObject(5000) + driver(driverParameters(runInProcess = false)) { + val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val ex = assertFailsWith { + CordaRPCClient(hostAndPort = alice.rpcAddress) + .start(user.username, user.password) + .use { client -> + client.proxy.startFlow(::EvilFlow, evilData).returnValue.getOrThrow() + } + } + assertThat(ex).hasMessageContaining("Purchase payment of $NOTHING should be at least ") + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt index 72d2c1bd47..a6d03bc337 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt @@ -204,7 +204,7 @@ class AMQPBridgeTest { doReturn(null).whenever(it).jmxMonitoringHttpPort } artemisConfig.configureWithDevSSLCertificate() - val artemisServer = ArtemisMessagingServer(artemisConfig, artemisAddress.copy(host = "0.0.0.0"), MAX_MESSAGE_SIZE, null) + val artemisServer = ArtemisMessagingServer(artemisConfig, artemisAddress.copy(host = "0.0.0.0"), MAX_MESSAGE_SIZE) val artemisClient = ArtemisMessagingClient(artemisConfig.p2pSslOptions, artemisAddress, MAX_MESSAGE_SIZE) artemisServer.start() artemisClient.start() diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt index b1bf4b99f8..050142097e 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt @@ -14,12 +14,15 @@ import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration -import net.corda.nodeapi.internal.protonwrapper.netty.init -import net.corda.nodeapi.internal.protonwrapper.netty.initialiseTrustStoreAndEnableCrlChecking +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.internal.fixedCrlSource import org.junit.Assume.assumeFalse import org.junit.Before import org.junit.Rule @@ -27,6 +30,7 @@ import org.junit.Test import org.junit.rules.TemporaryFolder import org.junit.runner.RunWith import org.junit.runners.Parameterized +import java.time.Duration import javax.net.ssl.KeyManagerFactory import javax.net.ssl.TrustManagerFactory import kotlin.test.assertFalse @@ -95,11 +99,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) { override val maxMessageSize: Int = MAX_MESSAGE_SIZE } - serverKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - serverTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + serverKeyManagerFactory = keyManagerFactory(keyStore) - serverKeyManagerFactory.init(keyStore) - serverTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(serverAmqpConfig.trustStore, serverAmqpConfig.revocationConfig)) + serverTrustManagerFactory = trustManagerFactoryWithRevocation( + serverAmqpConfig.trustStore, + RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL), + fixedCrlSource(emptySet()) + ) } private fun setupClientCertificates() { @@ -123,14 +129,16 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) { override val keyStore = keyStore override val trustStore = clientConfig.p2pSslOptions.trustStore.get() override val maxMessageSize: Int = MAX_MESSAGE_SIZE - override val sslHandshakeTimeout: Long = 3000 + override val sslHandshakeTimeout: Duration = 3.seconds } - clientKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - clientTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + clientKeyManagerFactory = keyManagerFactory(keyStore) - clientKeyManagerFactory.init(keyStore) - clientTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(clientAmqpConfig.trustStore, clientAmqpConfig.revocationConfig)) + clientTrustManagerFactory = trustManagerFactoryWithRevocation( + clientAmqpConfig.trustStore, + RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL), + fixedCrlSource(emptySet()) + ) } @Test(timeout = 300_000) diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt index e941a78aea..ddffb79506 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt @@ -1,3 +1,5 @@ +@file:Suppress("LongParameterList") + package net.corda.node.amqp import com.nhaarman.mockito_kotlin.doReturn @@ -5,112 +7,80 @@ import com.nhaarman.mockito_kotlin.whenever import net.corda.core.crypto.Crypto import net.corda.core.identity.CordaX500Name import net.corda.core.internal.div -import net.corda.core.toFuture +import net.corda.core.internal.times import net.corda.core.utilities.NetworkHostAndPort -import net.corda.core.utilities.days import net.corda.core.utilities.minutes import net.corda.core.utilities.seconds +import net.corda.coretesting.internal.rigorousMock +import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate +import net.corda.node.services.messaging.ArtemisMessagingServer +import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.config.CertificateStoreSupplier import net.corda.nodeapi.internal.config.MutualSslConfiguration -import net.corda.nodeapi.internal.crypto.* +import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_CA +import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer +import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionChange +import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import net.corda.testing.core.ALICE_NAME -import net.corda.testing.core.BOB_NAME import net.corda.testing.core.CHARLIE_NAME import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.driver.internal.incrementalPortAllocation -import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA -import net.corda.coretesting.internal.DEV_ROOT_CA -import net.corda.coretesting.internal.rigorousMock -import net.corda.coretesting.internal.stubs.CertificateStoreStubs -import net.corda.node.services.messaging.ArtemisMessagingServer -import net.corda.nodeapi.internal.ArtemisMessagingClient -import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.testing.node.internal.network.CrlServer +import net.corda.testing.node.internal.network.CrlServer.Companion.EMPTY_CRL +import net.corda.testing.node.internal.network.CrlServer.Companion.NODE_CRL +import net.corda.testing.node.internal.network.CrlServer.Companion.withCrlDistPoint import org.apache.activemq.artemis.api.core.RoutingType -import org.assertj.core.api.Assertions.assertThatIllegalArgumentException -import org.bouncycastle.asn1.x500.X500Name -import org.bouncycastle.asn1.x509.* -import org.bouncycastle.cert.jcajce.JcaX509CRLConverter -import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils -import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder +import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.jce.provider.BouncyCastleProvider -import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.server.ServerConnector -import org.eclipse.jetty.server.handler.HandlerCollection -import org.eclipse.jetty.servlet.ServletContextHandler -import org.eclipse.jetty.servlet.ServletHolder -import org.glassfish.jersey.server.ResourceConfig -import org.glassfish.jersey.servlet.ServletContainer import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder import java.io.Closeable -import java.math.BigInteger -import java.net.InetSocketAddress -import java.security.KeyPair -import java.security.PrivateKey -import java.security.cert.X509CRL import java.security.cert.X509Certificate -import java.util.* -import javax.ws.rs.GET -import javax.ws.rs.Path -import javax.ws.rs.Produces -import javax.ws.rs.core.Response -import kotlin.test.assertEquals +import java.time.Duration +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger +import java.util.stream.IntStream -class CertificateRevocationListNodeTests { +abstract class AbstractServerRevocationTest { @Rule @JvmField val temporaryFolder = TemporaryFolder() - private val ROOT_CA = DEV_ROOT_CA - private lateinit var INTERMEDIATE_CA: CertificateAndKeyPair - private val portAllocation = incrementalPortAllocation() - private val serverPort = portAllocation.nextPort() + protected val serverPort = portAllocation.nextPort() - private lateinit var server: CrlServer + protected lateinit var crlServer: CrlServer + private val amqpClients = ArrayList() - private val revokedNodeCerts: MutableList = mutableListOf() - private val revokedIntermediateCerts: MutableList = mutableListOf() + protected lateinit var defaultCrlDistPoints: CrlDistPoints - private abstract class AbstractNodeConfiguration : NodeConfiguration + protected abstract class AbstractNodeConfiguration : NodeConfiguration companion object { + private val unreachableIpCounter = AtomicInteger(1) - const val FORBIDDEN_CRL = "forbidden.crl" + val crlConnectTimeout = 2.seconds - fun createRevocationList(clrServer: CrlServer, signatureAlgorithm: String, caCertificate: X509Certificate, - caPrivateKey: PrivateKey, - endpoint: String, - indirect: Boolean, - vararg serialNumbers: BigInteger): X509CRL { - println("Generating CRL for $endpoint") - val builder = JcaX509v2CRLBuilder(caCertificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis())) - val extensionUtils = JcaX509ExtensionUtils() - builder.addExtension(Extension.authorityKeyIdentifier, - false, extensionUtils.createAuthorityKeyIdentifier(caCertificate)) - val issuingDistPointName = GeneralName( - GeneralName.uniformResourceIdentifier, - "http://${clrServer.hostAndPort.host}:${clrServer.hostAndPort.port}/crl/$endpoint") - // This is required and needs to match the certificate settings with respect to being indirect - val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false) - builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint) - builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis())) - serialNumbers.forEach { - builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold) - } - val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(Crypto.findProvider("BC")).build(caPrivateKey) - return JcaX509CRLConverter().setProvider(Crypto.findProvider("BC")).getCRL(builder.build(signer)) + /** + * Use this method to get a unqiue unreachable IP address. Subsequent uses of the same IP for connection timeout testing purposes + * may not work as the OS process may cache the timeout result. + */ + private fun newUnreachableIpAddress(): NetworkHostAndPort { + check(unreachableIpCounter.get() != 255) + return NetworkHostAndPort("10.255.255", unreachableIpCounter.getAndIncrement()) } } @@ -118,602 +88,478 @@ class CertificateRevocationListNodeTests { fun setUp() { // Do not use Security.addProvider(BouncyCastleProvider()) to avoid EdDSA signature disruption in other tests. Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME) - revokedNodeCerts.clear() - server = CrlServer(NetworkHostAndPort("localhost", 0)) - server.start() - INTERMEDIATE_CA = CertificateAndKeyPair(replaceCrlDistPointCaCertificate( - DEV_INTERMEDIATE_CA.certificate, - CertificateType.INTERMEDIATE_CA, - ROOT_CA.keyPair, - "http://${server.hostAndPort}/crl/intermediate.crl"), DEV_INTERMEDIATE_CA.keyPair) + crlServer = CrlServer(NetworkHostAndPort("localhost", 0)) + crlServer.start() + defaultCrlDistPoints = CrlDistPoints(crlServer.hostAndPort) } @After fun tearDown() { - server.close() - revokedNodeCerts.clear() - } - - @Test(timeout=300_000) - fun `Simple AMPQ Client to Server connection works and soft fail is enabled`() { - val crlCheckSoftFail = true - val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) - amqpServer.use { - amqpServer.start() - val receiveSubs = amqpServer.onReceive.subscribe { - assertEquals(BOB_NAME.toString(), it.sourceLegalName) - assertEquals(P2P_PREFIX + "Test", it.topic) - assertEquals("Test", String(it.payload)) - it.complete(true) - } - val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail) - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - val clientConnected = amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(true, serverConnect.connected) - val clientConnect = clientConnected.get() - assertEquals(true, clientConnect.connected) - val msg = amqpClient.createMessage("Test".toByteArray(), - P2P_PREFIX + "Test", - ALICE_NAME.toString(), - emptyMap()) - amqpClient.write(msg) - assertEquals(MessageStatus.Acknowledged, msg.onComplete.get()) - receiveSubs.unsubscribe() - } + amqpClients.parallelStream().forEach(AMQPClient::close) + if (::crlServer.isInitialized) { + crlServer.close() } } @Test(timeout=300_000) - fun `Simple AMPQ Client to Server connection works and soft fail is disabled`() { - val crlCheckSoftFail = false - val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) - amqpServer.use { - amqpServer.start() - val receiveSubs = amqpServer.onReceive.subscribe { - assertEquals(BOB_NAME.toString(), it.sourceLegalName) - assertEquals(P2P_PREFIX + "Test", it.topic) - assertEquals("Test", String(it.payload)) - it.complete(true) - } - val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail) - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - val clientConnected = amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(true, serverConnect.connected) - val clientConnect = clientConnected.get() - assertEquals(true, clientConnect.connected) - val msg = amqpClient.createMessage("Test".toByteArray(), - P2P_PREFIX + "Test", - ALICE_NAME.toString(), - emptyMap()) - amqpClient.write(msg) - assertEquals(MessageStatus.Acknowledged, msg.onComplete.get()) - receiveSubs.unsubscribe() - } - } + fun `connection succeeds when soft fail is enabled`() { + verifyConnection( + crlCheckSoftFail = true, + expectedConnectedStatus = true + ) } @Test(timeout=300_000) - fun `AMPQ Client to Server connection fails when client's certificate is revoked and soft fail is enabled`() { - val crlCheckSoftFail = true - val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) - amqpServer.use { - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val (amqpClient, clientCert) = createClient(serverPort, crlCheckSoftFail) - revokedNodeCerts.add(clientCert.serialNumber) - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(false, serverConnect.connected) - } - } + fun `connection succeeds when soft fail is disabled`() { + verifyConnection( + crlCheckSoftFail = false, + expectedConnectedStatus = true + ) } @Test(timeout=300_000) - fun `AMPQ Client to Server connection fails when client's certificate is revoked and soft fail is disabled`() { - val crlCheckSoftFail = false - val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) - amqpServer.use { - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val (amqpClient, clientCert) = createClient(serverPort, crlCheckSoftFail) - revokedNodeCerts.add(clientCert.serialNumber) - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(false, serverConnect.connected) - } - } + fun `connection fails when client's certificate is revoked and soft fail is enabled`() { + verifyConnection( + crlCheckSoftFail = true, + revokeClientCert = true, + expectedConnectedStatus = false + ) } @Test(timeout=300_000) - fun `AMPQ Client to Server connection fails when servers's certificate is revoked`() { - val crlCheckSoftFail = true - val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) - revokedNodeCerts.add(serverCert.serialNumber) - amqpServer.use { - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail) - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(false, serverConnect.connected) - } - } + fun `connection fails when client's certificate is revoked and soft fail is disabled`() { + verifyConnection( + crlCheckSoftFail = false, + revokeClientCert = true, + expectedConnectedStatus = false + ) } @Test(timeout=300_000) - fun `AMPQ Client to Server connection fails when servers's certificate is revoked and soft fail is enabled`() { - val crlCheckSoftFail = true - val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) - revokedNodeCerts.add(serverCert.serialNumber) - amqpServer.use { - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail) - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(false, serverConnect.connected) - } - } + fun `connection fails when server's certificate is revoked and soft fail is enabled`() { + verifyConnection( + crlCheckSoftFail = true, + revokeServerCert = true, + expectedConnectedStatus = false + ) } @Test(timeout=300_000) - fun `AMPQ Client to Server connection succeeds when CRL cannot be obtained and soft fail is enabled`() { - val crlCheckSoftFail = true - val (amqpServer, _) = createServer( - serverPort, - crlCheckSoftFail = crlCheckSoftFail, - nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl") - amqpServer.use { - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val (amqpClient, _) = createClient( + fun `connection fails when server's certificate is revoked and soft fail is disabled`() { + verifyConnection( + crlCheckSoftFail = false, + revokeServerCert = true, + expectedConnectedStatus = false + ) + } + + @Test(timeout=300_000) + fun `connection succeeds when CRL cannot be obtained and soft fail is enabled`() { + verifyConnection( + crlCheckSoftFail = true, + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"), + expectedConnectedStatus = true + ) + } + + @Test(timeout=300_000) + fun `connection fails when CRL cannot be obtained and soft fail is disabled`() { + verifyConnection( + crlCheckSoftFail = false, + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"), + expectedConnectedStatus = false + ) + } + + @Test(timeout=300_000) + fun `connection succeeds when CRL is not defined for node CA cert and soft fail is enabled`() { + verifyConnection( + crlCheckSoftFail = true, + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null), + expectedConnectedStatus = true + ) + } + + @Test(timeout=300_000) + fun `connection fails when CRL is not defined for node CA cert and soft fail is disabled`() { + verifyConnection( + crlCheckSoftFail = false, + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null), + expectedConnectedStatus = false + ) + } + + @Test(timeout=300_000) + fun `connection succeeds when CRL is not defined for TLS cert and soft fail is enabled`() { + verifyConnection( + crlCheckSoftFail = true, + clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null), + expectedConnectedStatus = true + ) + } + + @Test(timeout=300_000) + fun `connection fails when CRL is not defined for TLS cert and soft fail is disabled`() { + verifyConnection( + crlCheckSoftFail = false, + clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null), + expectedConnectedStatus = false + ) + } + + @Test(timeout=300_000) + fun `connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() { + verifyConnection( + crlCheckSoftFail = true, + sslHandshakeTimeout = crlConnectTimeout * 4, + clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()), + expectedConnectedStatus = true + ) + } + + @Test(timeout=300_000) + fun `connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() { + verifyConnection( + crlCheckSoftFail = true, + sslHandshakeTimeout = crlConnectTimeout / 2, + clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()), + expectedConnectedStatus = false + ) + } + + @Test(timeout = 300_000) + fun `influx of new clients during CRL endpoint downtime does not cause existing connections to drop`() { + val serverCrlSource = CertDistPointCrlSource() + // Start the server and verify the first client has connected + val firstClientConnectionChangeStatus = verifyConnection( + crlCheckSoftFail = true, + crlSource = serverCrlSource, + // In general, N remoting threads will naturally support N-1 new handshaking clients plus one thread for heartbeating with + // existing clients. The trick is to make sure at least N new clients are also supported. + remotingThreads = 2, + expectedConnectedStatus = true + ) + + // Now simulate the CRL endpoint becoming very slow/unreachable + crlServer.delay = 10.minutes + // And pretend enough time has elapsed that the cached CRLs have expired and need downloading again + serverCrlSource.clearCache() + + // Now a bunch of new clients have arrived and want to handshake with the server, which will potentially cause the server's Netty + // threads to be tied up in trying to download the CRLs. + IntStream.range(0, 2).parallel().forEach { clientIndex -> + val (newClient, _) = createAMQPClient( serverPort, - crlCheckSoftFail, - nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl") - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(true, serverConnect.connected) - } + crlCheckSoftFail = true, + legalName = CordaX500Name("NewClient$clientIndex", "London", "GB"), + crlDistPoints = defaultCrlDistPoints + ) + newClient.start() } + + // Make sure there are no further connection change updates, i.e. the first client stays connected throughout this whole saga + assertThat(firstClientConnectionChangeStatus.poll(30, TimeUnit.SECONDS)).isNull() } - @Test(timeout=300_000) - fun `Revocation status check fails when the CRL distribution point is not set and soft fail is disabled`() { - val crlCheckSoftFail = false - val (amqpServer, _) = createServer( - serverPort, - crlCheckSoftFail = crlCheckSoftFail, - tlsCrlDistPoint = null) - amqpServer.use { - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val (amqpClient, _) = createClient( - serverPort, - crlCheckSoftFail, - tlsCrlDistPoint = null) - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(false, serverConnect.connected) - } - } - } + protected abstract fun verifyConnection(crlCheckSoftFail: Boolean, + crlSource: CertDistPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout), + sslHandshakeTimeout: Duration? = null, + remotingThreads: Int? = null, + clientCrlDistPoints: CrlDistPoints = defaultCrlDistPoints, + revokeClientCert: Boolean = false, + revokeServerCert: Boolean = false, + expectedConnectedStatus: Boolean): BlockingQueue - @Test(timeout=300_000) - fun `Revocation status chceck succeds when the CRL distribution point is not set and soft fail is enabled`() { - val crlCheckSoftFail = true - val (amqpServer, _) = createServer( - serverPort, - crlCheckSoftFail = crlCheckSoftFail, - tlsCrlDistPoint = null) - amqpServer.use { - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val (amqpClient, _) = createClient( - serverPort, - crlCheckSoftFail, - tlsCrlDistPoint = null) - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(true, serverConnect.connected) - } - } - } - - private fun createClient(targetPort: Int, - crlCheckSoftFail: Boolean, - nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl", - tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl", - maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair { - val baseDirectory = temporaryFolder.root.toPath() / "client" + protected fun createAMQPClient(targetPort: Int, + crlCheckSoftFail: Boolean, + legalName: CordaX500Name, + crlDistPoints: CrlDistPoints): Pair { + val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val certificatesDirectory = baseDirectory / "certificates" val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) val clientConfig = rigorousMock().also { doReturn(baseDirectory).whenever(it).baseDirectory doReturn(certificatesDirectory).whenever(it).certificatesDirectory - doReturn(BOB_NAME).whenever(it).myLegalName + doReturn(legalName).whenever(it).myLegalName doReturn(p2pSslConfiguration).whenever(it).p2pSslOptions doReturn(signingCertificateStore).whenever(it).signingCertificateStore doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail } clientConfig.configureWithDevSSLCertificate() - val nodeCert = (signingCertificateStore to p2pSslConfiguration).recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint) + val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer) val keyStore = clientConfig.p2pSslOptions.keyStore.get() val amqpConfig = object : AMQPConfiguration { override val keyStore = keyStore override val trustStore = clientConfig.p2pSslOptions.trustStore.get() - override val maxMessageSize: Int = maxMessageSize + override val maxMessageSize: Int = MAX_MESSAGE_SIZE + override val trace: Boolean = true } - return Pair(AMQPClient( + val amqpClient = AMQPClient( listOf(NetworkHostAndPort("localhost", targetPort)), - setOf(ALICE_NAME, CHARLIE_NAME), - amqpConfig), nodeCert) + setOf(CHARLIE_NAME), + amqpConfig, + nettyThreading = AMQPClient.NettyThreading.NonShared(legalName.organisation), + distPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout) + ) + amqpClients += amqpClient + return Pair(amqpClient, nodeCert) } - private fun createServer(port: Int, name: CordaX500Name = ALICE_NAME, - crlCheckSoftFail: Boolean, - nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl", - tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl", - maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair { - val baseDirectory = temporaryFolder.root.toPath() / "server" + protected fun AMQPClient.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus: Boolean): BlockingQueue { + val connectionChangeStatus = LinkedBlockingQueue() + onConnection.subscribe { connectionChangeStatus.add(it) } + start() + assertThat(connectionChangeStatus.take().connected).isEqualTo(expectedConnectedStatus) + return connectionChangeStatus + } + + protected data class CrlDistPoints(val crlServerAddress: NetworkHostAndPort, + val nodeCa: String? = NODE_CRL, + val tls: String? = EMPTY_CRL) { + private val nodeCaCertCrlDistPoint: String? get() = nodeCa?.let { "http://$crlServerAddress/crl/$it" } + private val tlsCertCrlDistPoint: String? get() = tls?.let { "http://$crlServerAddress/crl/$it" } + + fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier, + p2pSslConfiguration: MutualSslConfiguration, + crlServer: CrlServer): X509Certificate { + val nodeKeyStore = signingCertificateStore.get() + val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_CA, nodeKeyStore.entryPassword) } + val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCertCrlDistPoint) + val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) + + nodeKeyStore.query { getCertificateChain(CORDA_CLIENT_CA) }.drop(2) + + nodeKeyStore.update { + internal.deleteEntry(CORDA_CLIENT_CA) + } + nodeKeyStore.update { + setPrivateKey(CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword) + } + + val sslKeyStore = p2pSslConfiguration.keyStore.get() + val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_TLS, sslKeyStore.entryPassword) } + val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCertCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal) + val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) + + sslKeyStore.query { getCertificateChain(CORDA_CLIENT_TLS) }.drop(3) + + sslKeyStore.update { + internal.deleteEntry(CORDA_CLIENT_TLS) + } + sslKeyStore.update { + setPrivateKey(CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword) + } + return newNodeCert + } + } +} + + +class AMQPServerRevocationTest : AbstractServerRevocationTest() { + private lateinit var amqpServer: AMQPServer + + @After + fun shutDown() { + if (::amqpServer.isInitialized) { + amqpServer.close() + } + } + + override fun verifyConnection(crlCheckSoftFail: Boolean, + crlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?, + clientCrlDistPoints: CrlDistPoints, + revokeClientCert: Boolean, + revokeServerCert: Boolean, + expectedConnectedStatus: Boolean): BlockingQueue { + val serverCert = createAMQPServer( + serverPort, + CHARLIE_NAME, + crlCheckSoftFail, + defaultCrlDistPoints, + crlSource, + sslHandshakeTimeout, + remotingThreads + ) + if (revokeServerCert) { + crlServer.revokedNodeCerts.add(serverCert) + } + amqpServer.start() + amqpServer.onReceive.subscribe { + it.complete(true) + } + val (client, clientCert) = createAMQPClient( + serverPort, + crlCheckSoftFail = crlCheckSoftFail, + legalName = ALICE_NAME, + crlDistPoints = clientCrlDistPoints + ) + if (revokeClientCert) { + crlServer.revokedNodeCerts.add(clientCert) + } + + return client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus) + } + + private fun createAMQPServer(port: Int, + legalName: CordaX500Name, + crlCheckSoftFail: Boolean, + crlDistPoints: CrlDistPoints, + distPointCrlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?): X509Certificate { + check(!::amqpServer.isInitialized) + val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val certificatesDirectory = baseDirectory / "certificates" val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) val serverConfig = rigorousMock().also { doReturn(baseDirectory).whenever(it).baseDirectory doReturn(certificatesDirectory).whenever(it).certificatesDirectory - doReturn(name).whenever(it).myLegalName + doReturn(legalName).whenever(it).myLegalName doReturn(p2pSslConfiguration).whenever(it).p2pSslOptions doReturn(signingCertificateStore).whenever(it).signingCertificateStore } serverConfig.configureWithDevSSLCertificate() - val nodeCert = (signingCertificateStore to p2pSslConfiguration).recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint) + val serverCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer) val keyStore = serverConfig.p2pSslOptions.keyStore.get() val amqpConfig = object : AMQPConfiguration { override val keyStore = keyStore override val trustStore = serverConfig.p2pSslOptions.trustStore.get() override val revocationConfig = crlCheckSoftFail.toRevocationConfig() - override val maxMessageSize: Int = maxMessageSize + override val maxMessageSize: Int = MAX_MESSAGE_SIZE + override val sslHandshakeTimeout: Duration = sslHandshakeTimeout ?: super.sslHandshakeTimeout } - return Pair(AMQPServer( + amqpServer = AMQPServer( "0.0.0.0", port, - amqpConfig), nodeCert) - } - - private fun Pair.recreateNodeCaAndTlsCertificates(nodeCaCrlDistPoint: String, tlsCrlDistPoint: String?): X509Certificate { - - val signingCertificateStore = first - val p2pSslConfiguration = second - val nodeKeyStore = signingCertificateStore.get() - val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA, nodeKeyStore.entryPassword) } - val newNodeCert = replaceCrlDistPointCaCertificate(nodeCert, CertificateType.NODE_CA, INTERMEDIATE_CA.keyPair, nodeCaCrlDistPoint) - val nodeCertChain = listOf(newNodeCert, INTERMEDIATE_CA.certificate, *nodeKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_CA) }.drop(2).toTypedArray()) - nodeKeyStore.update { - internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA) - } - nodeKeyStore.update { - setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword) - } - val sslKeyStore = p2pSslConfiguration.keyStore.get() - val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS, sslKeyStore.entryPassword) } - val newTlsCert = replaceCrlDistPointCaCertificate(tlsCert, CertificateType.TLS, nodeKeys, tlsCrlDistPoint, X500Name.getInstance(ROOT_CA.certificate.subjectX500Principal.encoded)) - val sslCertChain = listOf(newTlsCert, newNodeCert, INTERMEDIATE_CA.certificate, *sslKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS) }.drop(3).toTypedArray()) - - sslKeyStore.update { - internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS) - } - sslKeyStore.update { - setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword) - } - return newNodeCert - } - - private fun replaceCrlDistPointCaCertificate(currentCaCert: X509Certificate, certType: CertificateType, issuerKeyPair: KeyPair, crlDistPoint: String?, crlIssuer: X500Name? = null): X509Certificate { - val signatureScheme = Crypto.findSignatureScheme(issuerKeyPair.private) - val provider = Crypto.findProvider(signatureScheme.providerName) - val issuerSigner = ContentSignerBuilder.build(signatureScheme, issuerKeyPair.private, provider) - val builder = X509Utilities.createPartialCertificate( - certType, - currentCaCert.issuerX500Principal, - issuerKeyPair.public, - currentCaCert.subjectX500Principal, - currentCaCert.publicKey, - Pair(Date(System.currentTimeMillis() - 5.minutes.toMillis()), Date(System.currentTimeMillis() + 10.days.toMillis())), - null + amqpConfig, + threadPoolName = legalName.organisation, + distPointCrlSource = distPointCrlSource, + remotingThreads = remotingThreads ) - crlDistPoint?.let { - val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, it))) - val crlIssuerGeneralNames = crlIssuer?.let { - GeneralNames(GeneralName(crlIssuer)) - } - val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames) - builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint))) - } - return builder.build(issuerSigner).toJca() + return serverCert } +} - @Path("crl") - inner class CrlServlet(private val server: CrlServer) { - private val SIGNATURE_ALGORITHM = "SHA256withECDSA" - private val NODE_CRL = "node.crl" - private val INTEMEDIATE_CRL = "intermediate.crl" - private val EMPTY_CRL = "empty.crl" +class ArtemisServerRevocationTest : AbstractServerRevocationTest() { + private lateinit var artemisNode: ArtemisNode + private var crlCheckArtemisServer = true - @GET - @Path("node.crl") - @Produces("application/pkcs7-crl") - fun getNodeCRL(): Response { - return Response.ok(CertificateRevocationListNodeTests.createRevocationList( - server, - SIGNATURE_ALGORITHM, - INTERMEDIATE_CA.certificate, - INTERMEDIATE_CA.keyPair.private, - NODE_CRL, - false, - *revokedNodeCerts.toTypedArray()).encoded) - .build() - } - - @GET - @Path(FORBIDDEN_CRL) - @Produces("application/pkcs7-crl") - fun getNodeSlowCRL(): Response { - return Response.status(Response.Status.FORBIDDEN).build() - } - - @GET - @Path("intermediate.crl") - @Produces("application/pkcs7-crl") - fun getIntermediateCRL(): Response { - return Response.ok(createRevocationList( - server, - SIGNATURE_ALGORITHM, - ROOT_CA.certificate, - ROOT_CA.keyPair.private, - INTEMEDIATE_CRL, - false, - *revokedIntermediateCerts.toTypedArray()).encoded) - .build() - } - - @GET - @Path("empty.crl") - @Produces("application/pkcs7-crl") - fun getEmptyCRL(): Response { - return Response.ok(createRevocationList( - server, - SIGNATURE_ALGORITHM, - ROOT_CA.certificate, - ROOT_CA.keyPair.private, - EMPTY_CRL, - true).encoded) - .build() + @After + fun shutDown() { + if (::artemisNode.isInitialized) { + artemisNode.close() } } - inner class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { - - private val server: Server = Server(InetSocketAddress(hostAndPort.host, hostAndPort.port)).apply { - handler = HandlerCollection().apply { - addHandler(buildServletContextHandler()) - } - } - - val hostAndPort: NetworkHostAndPort - get() = server.connectors.mapNotNull { it as? ServerConnector } - .map { NetworkHostAndPort(it.host, it.localPort) } - .first() - - override fun close() { - println("Shutting down network management web services...") - server.stop() - server.join() - } - - fun start() { - server.start() - println("Network management web services started on $hostAndPort") - } - - private fun buildServletContextHandler(): ServletContextHandler { - val crlServer = this - return ServletContextHandler().apply { - contextPath = "/" - val resourceConfig = ResourceConfig().apply { - register(CrlServlet(crlServer)) - } - val jerseyServlet = ServletHolder(ServletContainer(resourceConfig)).apply { initOrder = 0 } - addServlet(jerseyServlet, "/*") - } - } + @Test(timeout = 300_000) + fun `connection succeeds with disabled CRL check on revoked node certificate`() { + crlCheckArtemisServer = false + verifyConnection( + crlCheckSoftFail = false, + revokeClientCert = true, + expectedConnectedStatus = true + ) } - @Test(timeout=300_000) - fun `verify CRL algorithms`() { - val ECDSA_ALGORITHM = "SHA256withECDSA" - val EC_ALGORITHM = "EC" - val EMPTY_CRL = "empty.crl" - - val crl = createRevocationList( - server, - ECDSA_ALGORITHM, - ROOT_CA.certificate, - ROOT_CA.keyPair.private, - EMPTY_CRL, - true) - // This should pass. - crl.verify(ROOT_CA.keyPair.public) - - // Try changing the algorithm to EC will fail. - assertThatIllegalArgumentException().isThrownBy { - createRevocationList( - server, - EC_ALGORITHM, - ROOT_CA.certificate, - ROOT_CA.keyPair.private, - EMPTY_CRL, - true - ) - }.withMessage("Unknown signature type requested: EC") - } - - @Test(timeout=300_000) - fun `AMPQ Client to Server connection succeeds when CRL retrieval is forbidden and soft fail is enabled`() { - val crlCheckSoftFail = true - val forbiddenUrl = "http://${server.hostAndPort}/crl/$FORBIDDEN_CRL" - val (amqpServer, _) = createServer( + override fun verifyConnection(crlCheckSoftFail: Boolean, + crlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?, + clientCrlDistPoints: CrlDistPoints, + revokeClientCert: Boolean, + revokeServerCert: Boolean, + expectedConnectedStatus: Boolean): BlockingQueue { + val (client, clientCert) = createAMQPClient( serverPort, - crlCheckSoftFail = crlCheckSoftFail, - nodeCrlDistPoint = forbiddenUrl, - tlsCrlDistPoint = forbiddenUrl) - amqpServer.use { - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val (amqpClient, _) = createClient( - serverPort, - crlCheckSoftFail, - nodeCrlDistPoint = forbiddenUrl, - tlsCrlDistPoint = forbiddenUrl) - amqpClient.use { - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertEquals(true, serverConnect.connected) - } + crlCheckSoftFail = true, + legalName = ALICE_NAME, + crlDistPoints = clientCrlDistPoints + ) + if (revokeClientCert) { + crlServer.revokedNodeCerts.add(clientCert) } + + val nodeCert = startArtemisNode( + CHARLIE_NAME, + crlCheckSoftFail, + defaultCrlDistPoints, + crlSource, + sslHandshakeTimeout, + remotingThreads + ) + if (revokeServerCert) { + crlServer.revokedNodeCerts.add(nodeCert) + } + + val queueName = "${P2P_PREFIX}Test" + artemisNode.client.started!!.session.createQueue(queueName, RoutingType.ANYCAST, queueName, true) + + val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus) + + if (expectedConnectedStatus) { + val msg = client.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap()) + client.write(msg) + assertThat(msg.onComplete.get()).isEqualTo(MessageStatus.Acknowledged) + } + + return clientConnectionChangeStatus } - private fun createArtemisServerAndClient(port: Int, crlCheckSoftFail: Boolean, crlCheckArtemisServer: Boolean): - Pair { - val baseDirectory = temporaryFolder.root.toPath() / "artemis" + private fun startArtemisNode(legalName: CordaX500Name, + crlCheckSoftFail: Boolean, + crlDistPoints: CrlDistPoints, + distPointCrlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?): X509Certificate { + check(!::artemisNode.isInitialized) + val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val certificatesDirectory = baseDirectory / "certificates" val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) - val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) + val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, sslHandshakeTimeout = sslHandshakeTimeout) val artemisConfig = rigorousMock().also { doReturn(baseDirectory).whenever(it).baseDirectory doReturn(certificatesDirectory).whenever(it).certificatesDirectory - doReturn(CHARLIE_NAME).whenever(it).myLegalName + doReturn(legalName).whenever(it).myLegalName doReturn(signingCertificateStore).whenever(it).signingCertificateStore doReturn(p2pSslConfiguration).whenever(it).p2pSslOptions - doReturn(NetworkHostAndPort("0.0.0.0", port)).whenever(it).p2pAddress + doReturn(NetworkHostAndPort("0.0.0.0", serverPort)).whenever(it).p2pAddress doReturn(null).whenever(it).jmxMonitoringHttpPort doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail doReturn(crlCheckArtemisServer).whenever(it).crlCheckArtemisServer } artemisConfig.configureWithDevSSLCertificate() - val server = ArtemisMessagingServer(artemisConfig, artemisConfig.p2pAddress, MAX_MESSAGE_SIZE, null) - val client = ArtemisMessagingClient(artemisConfig.p2pSslOptions, artemisConfig.p2pAddress, MAX_MESSAGE_SIZE) + val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer) + + val server = ArtemisMessagingServer( + artemisConfig, + artemisConfig.p2pAddress, + MAX_MESSAGE_SIZE, + threadPoolName = "${legalName.organisation}-server", + trace = true, + distPointCrlSource = distPointCrlSource, + remotingThreads = remotingThreads + ) + val client = ArtemisMessagingClient( + artemisConfig.p2pSslOptions, + artemisConfig.p2pAddress, + MAX_MESSAGE_SIZE, + threadPoolName = "${legalName.organisation}-client" + ) server.start() client.start() - return server to client + val artemisNode = ArtemisNode(server, client) + this.artemisNode = artemisNode + return nodeCert } - private fun verifyMessageToArtemis(crlCheckSoftFail: Boolean, - crlCheckArtemisServer: Boolean, - expectedStatus: MessageStatus, - revokedNodeCert: Boolean = false, - nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl") { - val queueName = P2P_PREFIX + "Test" - val (artemisServer, artemisClient) = createArtemisServerAndClient(serverPort, crlCheckSoftFail, crlCheckArtemisServer) - artemisServer.use { - artemisClient.started!!.session.createQueue(queueName, RoutingType.ANYCAST, queueName, true) - - val (amqpClient, nodeCert) = createClient(serverPort, true, nodeCrlDistPoint) - if (revokedNodeCert) { - revokedNodeCerts.add(nodeCert.serialNumber) - } - amqpClient.use { - val clientConnected = amqpClient.onConnection.toFuture() - amqpClient.start() - val clientConnect = clientConnected.get() - assertEquals(true, clientConnect.connected) - - val msg = amqpClient.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap()) - amqpClient.write(msg) - assertEquals(expectedStatus, msg.onComplete.get()) - } - artemisClient.stop() + private class ArtemisNode(val server: ArtemisMessagingServer, val client: ArtemisMessagingClient) : Closeable { + override fun close() { + client.stop() + server.close() } } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with soft fail CRL check`() { - verifyMessageToArtemis(crlCheckSoftFail = true, crlCheckArtemisServer = true, expectedStatus = MessageStatus.Acknowledged) - } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with hard fail CRL check`() { - verifyMessageToArtemis(crlCheckSoftFail = false, crlCheckArtemisServer = true, expectedStatus = MessageStatus.Acknowledged) - } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with soft fail CRL check on unavailable URL`() { - verifyMessageToArtemis(crlCheckSoftFail = true, crlCheckArtemisServer = true, expectedStatus = MessageStatus.Acknowledged, - nodeCrlDistPoint = "http://${server.hostAndPort}/crl/$FORBIDDEN_CRL") - } - - @Test(timeout = 300_000) - fun `Artemis server connection fails with hard fail CRL check on unavailable URL`() { - verifyMessageToArtemis(crlCheckSoftFail = false, crlCheckArtemisServer = true, expectedStatus = MessageStatus.Rejected, - nodeCrlDistPoint = "http://${server.hostAndPort}/crl/$FORBIDDEN_CRL") - } - - @Test(timeout = 300_000) - fun `Artemis server connection fails with soft fail CRL check on revoked node certificate`() { - verifyMessageToArtemis(crlCheckSoftFail = true, crlCheckArtemisServer = true, expectedStatus = MessageStatus.Rejected, - revokedNodeCert = true) - } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with disabled CRL check on revoked node certificate`() { - verifyMessageToArtemis(crlCheckSoftFail = false, crlCheckArtemisServer = false, expectedStatus = MessageStatus.Acknowledged, - revokedNodeCert = true) - } } diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt index 1fd59b9704..f6e4e1d4ed 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt @@ -4,12 +4,15 @@ import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.whenever import io.netty.channel.EventLoopGroup import io.netty.channel.nio.NioEventLoopGroup +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.core.crypto.newSecureRandom import net.corda.core.identity.CordaX500Name import net.corda.core.internal.div import net.corda.core.toFuture import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger +import net.corda.coretesting.internal.rigorousMock +import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.messaging.ArtemisMessagingServer @@ -23,7 +26,9 @@ import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME @@ -31,9 +36,6 @@ import net.corda.testing.core.CHARLIE_NAME import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.internal.createDevIntermediateCaCertPath -import net.corda.coretesting.internal.rigorousMock -import net.corda.coretesting.internal.stubs.CertificateStoreStubs -import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig import org.apache.activemq.artemis.api.core.RoutingType import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Assert.assertArrayEquals @@ -42,7 +44,11 @@ import org.junit.Test import org.junit.rules.TemporaryFolder import java.security.cert.X509Certificate import java.util.concurrent.TimeUnit -import javax.net.ssl.* +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLHandshakeException +import javax.net.ssl.SSLParameters +import javax.net.ssl.SSLServerSocket +import javax.net.ssl.SSLSocket import kotlin.concurrent.thread import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -144,15 +150,10 @@ class ProtonWrapperTests { sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) } sslConfig.createTrustStore(rootCa.certificate) - val keyStore = sslConfig.keyStore.get() - val trustStore = sslConfig.trustStore.get() - val context = SSLContext.getInstance("TLS") - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get()) val trustManagers = trustMgrFactory.trustManagers context.init(keyManagers, trustManagers, newSecureRandom()) @@ -339,7 +340,7 @@ class ProtonWrapperTests { amqpServer.use { val connectionEvents = amqpServer.onConnection.toBlocking().iterator amqpServer.start() - val sharedThreads = NioEventLoopGroup() + val sharedThreads = NioEventLoopGroup(DefaultThreadFactory("sharedThreads")) val amqpClient1 = createSharedThreadsClient(sharedThreads, 0) val amqpClient2 = createSharedThreadsClient(sharedThreads, 1) amqpClient1.start() @@ -437,7 +438,7 @@ class ProtonWrapperTests { } artemisConfig.configureWithDevSSLCertificate() - val server = ArtemisMessagingServer(artemisConfig, NetworkHostAndPort("0.0.0.0", artemisPort), maxMessageSize, null) + val server = ArtemisMessagingServer(artemisConfig, NetworkHostAndPort("0.0.0.0", artemisPort), maxMessageSize) val client = ArtemisMessagingClient(artemisConfig.p2pSslOptions, NetworkHostAndPort("localhost", artemisPort), maxMessageSize) server.start() client.start() @@ -502,7 +503,7 @@ class ProtonWrapperTests { listOf(NetworkHostAndPort("localhost", serverPort)), setOf(ALICE_NAME), amqpConfig, - sharedThreadPool = sharedEventGroup) + nettyThreading = AMQPClient.NettyThreading.Shared(sharedEventGroup)) } private fun createServer(port: Int, diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/FlowWithClientIdTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/FlowWithClientIdTest.kt index 129b54310c..89a4f99f95 100644 --- a/node/src/integration-test/kotlin/net/corda/node/flows/FlowWithClientIdTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/flows/FlowWithClientIdTest.kt @@ -1,6 +1,10 @@ package net.corda.node.flows import co.paralleluniverse.fibers.Suspendable +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.KryoSerializable +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.PermissionException import net.corda.core.CordaRuntimeException @@ -11,10 +15,14 @@ import net.corda.core.flows.ResultSerializationException import net.corda.core.flows.StartableByRPC import net.corda.core.flows.StateMachineRunId import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.doOnError import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.FlowHandleWithClientId import net.corda.core.messaging.startFlow import net.corda.core.messaging.startFlowWithClientId +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.deserialize import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.seconds import net.corda.node.services.Permissions @@ -23,9 +31,12 @@ import net.corda.nodeapi.exceptions.RejectedCommandException import net.corda.testing.core.ALICE_NAME import net.corda.testing.driver.DriverParameters import net.corda.testing.driver.NodeHandle +import net.corda.testing.driver.NodeParameters import net.corda.testing.driver.driver import net.corda.testing.node.User +import net.corda.testing.node.internal.enclosedCordapp import org.assertj.core.api.Assertions +import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.Before import org.junit.Test import rx.Observable @@ -33,6 +44,7 @@ import java.time.Duration import java.time.Instant import java.util.UUID import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException import kotlin.reflect.KClass import kotlin.test.assertEquals @@ -498,6 +510,124 @@ class FlowWithClientIdTest { } } + // This test is not very realistic because the scenario it happens under is also not very realistic. + @Test(timeout = 300_000) + fun `flow started with client id that fails before its first checkpoint that contains an unserializable argument will be persited as FAILED`() { + val clientId = UUID.randomUUID().toString() + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val nodeA = startNode(NodeParameters(ALICE_NAME)).getOrThrow() + val flowHandle = nodeA.rpc.startFlowWithClientId(clientId, ::QuickFailingFlow, LazyUnserializableObject()) + val reattachedFlowHandle = nodeA.rpc.reattachFlowWithClientId(clientId) + + assertThatExceptionOfType(CordaRuntimeException::class.java).isThrownBy { + flowHandle.returnValue.getOrThrow(20.seconds) + }.withMessage("I have failed quickly") + + assertThatExceptionOfType(CordaRuntimeException::class.java).isThrownBy { + reattachedFlowHandle?.returnValue?.getOrThrow() + }.withMessage("I have failed quickly") + + assertTrue(nodeA.hasStatus(flowHandle.id, Checkpoint.FlowStatus.FAILED)) + val arguments = nodeA.rpc.startFlow(::GetFlowInitialArgumentsFromMetadata, flowHandle.id).returnValue.getOrThrow(20.seconds) + assertEquals(arguments.size, 1) + assertTrue(arguments.single() is LazyUnserializableObject) + } + } + + // This test has been added to replicate the exact scenario a user experienced. + @Test(timeout = 300_000) + fun `flow started with client id that fails before its first checkpoint with subflow'd flow will be persited as FAILED`() { + val clientId = UUID.randomUUID().toString() + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val nodeA = startNode(NodeParameters(ALICE_NAME)).getOrThrow() + val flowHandle = nodeA.rpc.startFlowWithClientId(clientId, ::PassedInFailingFlow, SuperQuickFailingFlow()) + val reattachedFlowHandle = nodeA.rpc.reattachFlowWithClientId(clientId) + + assertThatExceptionOfType(CordaRuntimeException::class.java).isThrownBy { + flowHandle.returnValue.getOrThrow(20.seconds) + }.withMessage("I have failed quickly") + + assertThatExceptionOfType(CordaRuntimeException::class.java).isThrownBy { + reattachedFlowHandle?.returnValue?.getOrThrow() + }.withMessage("I have failed quickly") + + assertTrue(nodeA.hasStatus(flowHandle.id, Checkpoint.FlowStatus.FAILED)) + val arguments = nodeA.rpc.startFlow(::GetFlowInitialArgumentsFromMetadata, flowHandle.id).returnValue.getOrThrow(20.seconds) + assertEquals(arguments.size, 1) + assertTrue(arguments.single() is SuperQuickFailingFlow) + + } + } + + @Test(timeout = 300_000) + fun `flow started with client id that fails can use doOnError to process the exception`() { + val clientId = UUID.randomUUID().toString() + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val nodeA = startNode(NodeParameters(ALICE_NAME)).getOrThrow() + val flowHandle = nodeA.rpc.startFlowWithClientId(clientId, ::SuperQuickFailingFlow) + val reattachedFlowHandle = nodeA.rpc.reattachFlowWithClientId(clientId) + + val lock = CountDownLatch(1) + val reattachedLock = CountDownLatch(1) + + flowHandle.returnValue.doOnError { + lock.countDown() + } + + reattachedFlowHandle?.returnValue?.doOnError { + reattachedLock.countDown() + } + + assertTrue(lock.await(20, TimeUnit.SECONDS)) + assertTrue(reattachedLock.await(20, TimeUnit.SECONDS)) + assertTrue(flowHandle.returnValue.isDone) + assertTrue(reattachedFlowHandle!!.returnValue.isDone) + } + } + + @CordaSerializable + @StartableByRPC + internal class QuickFailingFlow(private val lazyUnserializableObject: LazyUnserializableObject) : FlowLogic() { + + @Suspendable + override fun call(): Int { + lazyUnserializableObject.prop = UnserializableObject() + throw CordaRuntimeException("I have failed quickly") + } + } + + @CordaSerializable + class LazyUnserializableObject(var prop: UnserializableObject? = null) + + @CordaSerializable + class UnserializableObject : KryoSerializable { + override fun write(kryo: Kryo?, output: Output?) { + throw IllegalStateException("Cannot be serialized") + } + + override fun read(kryo: Kryo?, input: Input?) { + throw IllegalStateException("Cannot be read") + } + } + + @StartableByRPC + internal class PassedInFailingFlow(private val flow: SuperQuickFailingFlow) : FlowLogic() { + @Suspendable + override fun call(): Int { + return subFlow(flow) + } + } + + @CordaSerializable + @StartableByRPC + internal class SuperQuickFailingFlow : FlowLogic() { + + @Suspendable + override fun call(): Int { + throw CordaRuntimeException("I have failed quickly") + } + } + @StartableByRPC internal class ResultFlow(private val result: A) : FlowLogic() { companion object { @@ -568,6 +698,24 @@ class FlowWithClientIdTest { } } + @StartableByRPC + internal class GetFlowInitialArgumentsFromMetadata(private val id: StateMachineRunId) : FlowLogic>() { + @Suspendable + override fun call(): List { + val argumentBytes = serviceHub.jdbcSession().prepareStatement("select flow_parameters from node_flow_metadata where flow_id = ?") + .apply { + setString(1, id.uuid.toString()) + } + .use { ps -> + ps.executeQuery().use { rs -> + rs.next() + rs.getBytes(1) + } + } + return argumentBytes.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + } + } + internal class UnserializableException( val unserializableObject: BrokenMap = BrokenMap() ) : CordaRuntimeException("123") diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicCashIssueAndPaymentTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicCashIssueAndPaymentTest.kt index 0de1960375..9b3bdf77ef 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicCashIssueAndPaymentTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicCashIssueAndPaymentTest.kt @@ -7,6 +7,7 @@ import net.corda.core.utilities.loggerFor import net.corda.finance.DOLLARS import net.corda.finance.flows.CashIssueAndPaymentFlow import net.corda.node.DeterministicSourcesRule +import net.corda.node.services.config.NodeConfiguration import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.DUMMY_NOTARY_NAME import net.corda.testing.core.singleIdentity @@ -22,20 +23,21 @@ import org.junit.jupiter.api.assertDoesNotThrow @Suppress("FunctionName") class DeterministicCashIssueAndPaymentTest { companion object { - val logger = loggerFor() + private val logger = loggerFor() + + private val configOverrides = mapOf(NodeConfiguration::reloadCheckpointAfterSuspend.name to true) + private val CASH_AMOUNT = 500.DOLLARS @ClassRule @JvmField val djvmSources = DeterministicSourcesRule() - @JvmField - val CASH_AMOUNT = 500.DOLLARS - - fun parametersFor(djvmSources: DeterministicSourcesRule): DriverParameters { + fun parametersFor(djvmSources: DeterministicSourcesRule, runInProcess: Boolean = false): DriverParameters { return DriverParameters( portAllocation = incrementalPortAllocation(), - startNodesInProcess = false, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + notaryCustomOverrides = configOverrides, cordappsForAllNodes = listOf( findCordapp("net.corda.finance.contracts"), findCordapp("net.corda.finance.workflows") @@ -50,7 +52,7 @@ class DeterministicCashIssueAndPaymentTest { fun `test DJVM can issue cash`() { val reference = OpaqueBytes.of(0x01) driver(parametersFor(djvmSources)) { - val alice = startNode(providedName = ALICE_NAME).getOrThrow() + val alice = startNode(providedName = ALICE_NAME, customOverrides = configOverrides).getOrThrow() val aliceParty = alice.nodeInfo.singleIdentity() val notaryParty = notaryHandles.single().identity val txId = assertDoesNotThrow { @@ -60,7 +62,9 @@ class DeterministicCashIssueAndPaymentTest { aliceParty, false, notaryParty - ).returnValue.getOrThrow() + ).use { flowHandle -> + flowHandle.returnValue.getOrThrow() + } } logger.info("TX-ID: {}", txId) } diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCannotMutateTransactionTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCannotMutateTransactionTest.kt new file mode 100644 index 0000000000..41c80ea7d9 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCannotMutateTransactionTest.kt @@ -0,0 +1,55 @@ +package net.corda.node.services + +import net.corda.client.rpc.CordaRPCClient +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.loggerFor +import net.corda.flows.mutator.MutatorFlow +import net.corda.node.DeterministicSourcesRule +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.User +import net.corda.testing.node.internal.cordappWithPackages +import org.junit.ClassRule +import org.junit.Test + +class DeterministicContractCannotMutateTransactionTest { + companion object { + private val logger = loggerFor() + private val user = User("u", "p", setOf(Permissions.all())) + private val mutatorFlowCorDapp = cordappWithPackages("net.corda.flows.mutator").signed() + private val mutatorContractCorDapp = cordappWithPackages("net.corda.contracts.mutator").signed() + + @ClassRule + @JvmField + val djvmSources = DeterministicSourcesRule() + + fun driverParameters(runInProcess: Boolean = false): DriverParameters { + return DriverParameters( + portAllocation = incrementalPortAllocation(), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + cordappsForAllNodes = listOf(mutatorContractCorDapp, mutatorFlowCorDapp), + djvmBootstrapSource = djvmSources.bootstrap, + djvmCordaSource = djvmSources.corda + ) + } + } + + @Test(timeout = 300_000) + fun testContractCannotModifyTransaction() { + driver(driverParameters()) { + val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val txID = CordaRPCClient(hostAndPort = alice.rpcAddress) + .start(user.username, user.password) + .use { client -> + client.proxy.startFlow(::MutatorFlow).returnValue.getOrThrow() + } + logger.info("TX-ID: {}", txID) + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCryptoTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCryptoTest.kt index 275693f4d7..5d28ae41b8 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCryptoTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCryptoTest.kt @@ -32,11 +32,11 @@ class DeterministicContractCryptoTest { @JvmField val djvmSources = DeterministicSourcesRule() - fun parametersFor(djvmSources: DeterministicSourcesRule): DriverParameters { + fun parametersFor(djvmSources: DeterministicSourcesRule, runInProcess: Boolean = false): DriverParameters { return DriverParameters( portAllocation = incrementalPortAllocation(), - startNodesInProcess = false, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = listOf( cordappWithPackages("net.corda.flows.djvm.crypto"), CustomCordapp( diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithCustomSerializerTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithCustomSerializerTest.kt index 3630fbcf3c..447cd2d6a6 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithCustomSerializerTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithCustomSerializerTest.kt @@ -41,11 +41,11 @@ class DeterministicContractWithCustomSerializerTest { @JvmField val contractCordapp = cordappWithPackages("net.corda.contracts.serialization.custom").signed() - fun parametersFor(djvmSources: DeterministicSourcesRule, vararg cordapps: TestCordapp): DriverParameters { + fun parametersFor(djvmSources: DeterministicSourcesRule, cordapps: List, runInProcess: Boolean = false): DriverParameters { return DriverParameters( portAllocation = incrementalPortAllocation(), - startNodesInProcess = false, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = cordapps.toList(), djvmBootstrapSource = djvmSources.bootstrap, djvmCordaSource = djvmSources.corda @@ -61,7 +61,7 @@ class DeterministicContractWithCustomSerializerTest { @Test(timeout=300_000) fun `test DJVM can verify using custom serializer`() { - driver(parametersFor(djvmSources, flowCordapp, contractCordapp)) { + driver(parametersFor(djvmSources, listOf(flowCordapp, contractCordapp))) { val alice = startNode(providedName = ALICE_NAME).getOrThrow() val txId = assertDoesNotThrow { alice.rpc.startFlow(::CustomSerializerFlow, Currantsy(GOOD_CURRANTS)) @@ -73,7 +73,7 @@ class DeterministicContractWithCustomSerializerTest { @Test(timeout=300_000) fun `test DJVM can fail verify using custom serializer`() { - driver(parametersFor(djvmSources, flowCordapp, contractCordapp)) { + driver(parametersFor(djvmSources, listOf(flowCordapp, contractCordapp))) { val alice = startNode(providedName = ALICE_NAME).getOrThrow() val currantsy = Currantsy(BAD_CURRANTS) val ex = assertThrows { diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithGenericTypeTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithGenericTypeTest.kt index c3c440eaf6..d2cae60136 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithGenericTypeTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithGenericTypeTest.kt @@ -36,11 +36,11 @@ class DeterministicContractWithGenericTypeTest { @JvmField val djvmSources = DeterministicSourcesRule() - fun parameters(): DriverParameters { + fun parameters(runInProcess: Boolean = false): DriverParameters { return DriverParameters( portAllocation = incrementalPortAllocation(), - startNodesInProcess = false, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = listOf( cordappWithPackages("net.corda.flows.serialization.generics").signed(), cordappWithPackages("net.corda.contracts.serialization.generics").signed() diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithSerializationWhitelistTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithSerializationWhitelistTest.kt index 97ecbf014a..9b0e057453 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithSerializationWhitelistTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractWithSerializationWhitelistTest.kt @@ -41,11 +41,11 @@ class DeterministicContractWithSerializationWhitelistTest { @JvmField val contractCordapp = cordappWithPackages("net.corda.contracts.djvm.whitelist").signed() - fun parametersFor(djvmSources: DeterministicSourcesRule, vararg cordapps: TestCordapp): DriverParameters { + fun parametersFor(djvmSources: DeterministicSourcesRule, cordapps: List, runInProcess: Boolean = false): DriverParameters { return DriverParameters( portAllocation = incrementalPortAllocation(), - startNodesInProcess = false, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = cordapps.toList(), djvmBootstrapSource = djvmSources.bootstrap, djvmCordaSource = djvmSources.corda @@ -61,7 +61,7 @@ class DeterministicContractWithSerializationWhitelistTest { @Test(timeout=300_000) fun `test DJVM can verify using whitelist`() { - driver(parametersFor(djvmSources, flowCordapp, contractCordapp)) { + driver(parametersFor(djvmSources, listOf(flowCordapp, contractCordapp))) { val alice = startNode(providedName = ALICE_NAME).getOrThrow() val txId = assertDoesNotThrow { alice.rpc.startFlow(::DeterministicWhitelistFlow, WhitelistData(GOOD_VALUE)) @@ -73,7 +73,7 @@ class DeterministicContractWithSerializationWhitelistTest { @Test(timeout=300_000) fun `test DJVM can fail verify using whitelist`() { - driver(parametersFor(djvmSources, flowCordapp, contractCordapp)) { + driver(parametersFor(djvmSources, listOf(flowCordapp, contractCordapp))) { val alice = startNode(providedName = ALICE_NAME).getOrThrow() val badData = WhitelistData(BAD_VALUE) val ex = assertThrows { diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicEvilContractCannotModifyStatesTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicEvilContractCannotModifyStatesTest.kt new file mode 100644 index 0000000000..5188f0b4fd --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicEvilContractCannotModifyStatesTest.kt @@ -0,0 +1,71 @@ +package net.corda.node.services + +import net.corda.client.rpc.CordaRPCClient +import net.corda.contracts.multiple.vulnerable.MutableDataObject +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.flows.multiple.evil.EvilFlow +import net.corda.node.DeterministicSourcesRule +import net.corda.node.internal.djvm.DeterministicVerificationException +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.User +import net.corda.testing.node.internal.cordappWithPackages +import org.assertj.core.api.Assertions.assertThat +import org.junit.ClassRule +import org.junit.Test +import kotlin.test.assertFailsWith + +class DeterministicEvilContractCannotModifyStatesTest { + companion object { + private val user = User("u", "p", setOf(Permissions.all())) + private val evilFlowCorDapp = cordappWithPackages("net.corda.flows.multiple.evil").signed() + private val evilContractCorDapp = cordappWithPackages("net.corda.contracts.multiple.evil").signed() + private val vulnerableContractCorDapp = cordappWithPackages("net.corda.contracts.multiple.vulnerable").signed() + + private val NOTHING = MutableDataObject(0) + + @ClassRule + @JvmField + val djvmSources = DeterministicSourcesRule() + + fun driverParameters(runInProcess: Boolean = false): DriverParameters { + return DriverParameters( + portAllocation = incrementalPortAllocation(), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + cordappsForAllNodes = listOf( + vulnerableContractCorDapp, + evilContractCorDapp, + evilFlowCorDapp + ), + djvmBootstrapSource = djvmSources.bootstrap, + djvmCordaSource = djvmSources.corda + ) + } + } + + @Test(timeout = 300_000) + fun testContractThatTriesToModifyStates() { + val evilData = MutableDataObject(5000) + driver(driverParameters()) { + val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val ex = assertFailsWith { + CordaRPCClient(hostAndPort = alice.rpcAddress) + .start(user.username, user.password) + .use { client -> + client.proxy.startFlow(::EvilFlow, evilData).returnValue.getOrThrow() + } + } + assertThat(ex) + .hasMessageStartingWith("sandbox.net.corda.core.contracts.TransactionVerificationException\$ContractRejection -> ") + .hasMessageContaining(" Contract verification failed: Failed requirement: Purchase payment of $NOTHING should be at least ") + .hasMessageContaining(", contract: sandbox.${VulnerablePaymentContract::class.java.name}, ") + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/services/NonDeterministicContractVerifyTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/NonDeterministicContractVerifyTest.kt index 264502e448..a200d5db41 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/NonDeterministicContractVerifyTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/NonDeterministicContractVerifyTest.kt @@ -35,11 +35,11 @@ class NonDeterministicContractVerifyTest { @JvmField val djvmSources = DeterministicSourcesRule() - fun parametersFor(djvmSources: DeterministicSourcesRule): DriverParameters { + fun parametersFor(djvmSources: DeterministicSourcesRule, runInProcess: Boolean = false): DriverParameters { return DriverParameters( portAllocation = incrementalPortAllocation(), - startNodesInProcess = false, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = listOf( cordappWithPackages("net.corda.flows.djvm.broken"), CustomCordapp( diff --git a/node/src/integration-test/kotlin/net/corda/node/services/SandboxAttachmentsTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/SandboxAttachmentsTest.kt index 85dac332a9..e868566f58 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/SandboxAttachmentsTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/SandboxAttachmentsTest.kt @@ -5,7 +5,6 @@ import net.corda.contracts.djvm.attachment.SandboxAttachmentContract.ExtractFile import net.corda.core.messaging.startFlow import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor -import net.corda.djvm.code.asResourcePath import net.corda.flows.djvm.attachment.SandboxAttachmentFlow import net.corda.node.DeterministicSourcesRule import net.corda.node.internal.djvm.DeterministicVerificationException @@ -32,11 +31,11 @@ class SandboxAttachmentsTest { @JvmField val djvmSources = DeterministicSourcesRule() - fun parametersFor(djvmSources: DeterministicSourcesRule): DriverParameters { + fun parametersFor(djvmSources: DeterministicSourcesRule, runInProcess: Boolean = false): DriverParameters { return DriverParameters( portAllocation = incrementalPortAllocation(), - startNodesInProcess = false, - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, validating = true)), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), cordappsForAllNodes = listOf( cordappWithPackages("net.corda.flows.djvm.attachment"), CustomCordapp( @@ -52,7 +51,7 @@ class SandboxAttachmentsTest { @Test(timeout=300_000) fun `test attachment accessible within sandbox`() { - val extractFile = ExtractFile(SandboxAttachmentContract::class.java.name.asResourcePath + ".class") + val extractFile = ExtractFile(SandboxAttachmentContract::class.java.name.replace('.', '/') + ".class") driver(parametersFor(djvmSources)) { val alice = startNode(providedName = ALICE_NAME).getOrThrow() val txId = assertDoesNotThrow { diff --git a/node/src/integration-test/kotlin/net/corda/node/services/events/ScheduledFlowIntegrationTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/events/ScheduledFlowIntegrationTests.kt index 09b2a5999a..dc6166e4ad 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/events/ScheduledFlowIntegrationTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/events/ScheduledFlowIntegrationTests.kt @@ -28,6 +28,7 @@ import net.corda.testing.node.User import net.corda.testing.node.internal.DUMMY_CONTRACTS_CORDAPP import net.corda.testing.node.internal.cordappWithPackages import net.corda.testing.node.internal.enclosedCordapp +import org.junit.Ignore import org.junit.Assume import org.junit.Test import java.time.Instant @@ -100,6 +101,7 @@ class ScheduledFlowIntegrationTests { } } + @Ignore("ENT-5891: Unstable test we're not addressing in Corda 4.x") @Test(timeout=300_000) fun `test that when states are being spent at the same time that schedules trigger everything is processed`() { Assume.assumeFalse(IS_S390X) diff --git a/node/src/integration-test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt index 5d650942a8..56af87c83f 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt @@ -7,6 +7,8 @@ import net.corda.core.crypto.generateKeyPair import net.corda.core.internal.div import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.seconds +import net.corda.coretesting.internal.rigorousMock +import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.node.services.config.FlowTimeoutConfiguration import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate @@ -22,8 +24,6 @@ import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.internal.LogHelper import net.corda.testing.internal.TestingNamedCacheFactory import net.corda.testing.internal.configureDatabase -import net.corda.coretesting.internal.rigorousMock -import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.internal.MOCK_VERSION_INFO import org.apache.activemq.artemis.api.core.ActiveMQConnectionTimedOutException @@ -57,7 +57,6 @@ class ArtemisMessagingTest { @JvmField val temporaryFolder = TemporaryFolder() - // THe private val portAllocation = incrementalPortAllocation() private val serverPort = portAllocation.nextPort() private val identity = generateKeyPair() @@ -200,7 +199,9 @@ class ArtemisMessagingTest { messagingClient!!.start(identity.public, null, maxMessageSize) } - private fun createAndStartClientAndServer(platformVersion: Int = 1, serverMaxMessageSize: Int = MAX_MESSAGE_SIZE, clientMaxMessageSize: Int = MAX_MESSAGE_SIZE): Pair> { + private fun createAndStartClientAndServer(platformVersion: Int = 1, + serverMaxMessageSize: Int = MAX_MESSAGE_SIZE, + clientMaxMessageSize: Int = MAX_MESSAGE_SIZE): Pair> { val receivedMessages = LinkedBlockingQueue() createMessagingServer(maxMessageSize = serverMaxMessageSize).start() @@ -239,7 +240,7 @@ class ArtemisMessagingTest { } private fun createMessagingServer(local: Int = serverPort, maxMessageSize: Int = MAX_MESSAGE_SIZE): ArtemisMessagingServer { - return ArtemisMessagingServer(config, NetworkHostAndPort("0.0.0.0", local), maxMessageSize, null).apply { + return ArtemisMessagingServer(config, NetworkHostAndPort("0.0.0.0", local), maxMessageSize, trace = true).apply { config.configureWithDevSSLCertificate() messagingServer = this } diff --git a/node/src/integration-test/kotlin/net/corda/node/services/rpc/DumpCheckpointsTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/rpc/DumpCheckpointsTest.kt index 21ac40a96d..a4582f6740 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/rpc/DumpCheckpointsTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/rpc/DumpCheckpointsTest.kt @@ -14,8 +14,11 @@ import net.corda.core.internal.list import net.corda.core.internal.readFully import net.corda.core.messaging.startFlow import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.minutes +import net.corda.core.utilities.seconds import net.corda.node.internal.NodeStartup import net.corda.node.services.Permissions +import net.corda.node.services.statemachine.Checkpoint import net.corda.node.services.statemachine.CountUpDownLatch import net.corda.testing.core.ALICE_NAME import net.corda.testing.driver.DriverParameters @@ -36,8 +39,8 @@ class DumpCheckpointsTest { private val flowProceedLatch = CountUpDownLatch(1) } - @Test(timeout=300_000) - fun `verify checkpoint dump via RPC`() { + @Test(timeout = 300_000) + fun `verify checkpoint dump via RPC`() { val user = User("mark", "dadada", setOf(Permissions.all())) driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { @@ -55,20 +58,44 @@ class DumpCheckpointsTest { flowProceedLatch.countDown() assertEquals(1, checkPointCountFuture.get()) - checkDumpFile(logDirPath) + checkDumpFile(logDirPath, GetNumberOfCheckpointsFlow::class.java, Checkpoint.FlowStatus.RUNNABLE) } } } - private fun checkDumpFile(dir: Path) { + @Test(timeout = 300_000) + fun `paused flows included in checkpoint dump output`() { + val user = User("mark", "dadada", setOf(Permissions.all())) + driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + + it.proxy.startFlow(::EasyFlow) + + // Hack to get the flow to show as paused + it.proxy.startFlow(::SetAllFlowsToPausedFlow).returnValue.getOrThrow(10.seconds) + + val logDirPath = nodeAHandle.baseDirectory / NodeStartup.LOGS_DIRECTORY_NAME + logDirPath.createDirectories() + nodeAHandle.checkpointsRpc.use { checkpointRPCOps -> checkpointRPCOps.dumpCheckpoints() } + + checkDumpFile(logDirPath, EasyFlow::class.java, Checkpoint.FlowStatus.PAUSED) + } + } + } + + private fun checkDumpFile(dir: Path, containsClass: Class>, flowStatus: Checkpoint.FlowStatus) { // The directory supposed to contain a single ZIP file val file = dir.list().single { it.isRegularFile() } ZipInputStream(file.inputStream()).use { zip -> val entry = zip.nextEntry assertThat(entry.name, containsSubstring("json")) - val content = zip.readFully() - assertThat(String(content), containsSubstring(GetNumberOfCheckpointsFlow::class.java.name)) + val content = String(zip.readFully()) + assertThat(content, containsSubstring(containsClass.name)) + assertThat(content, containsSubstring(flowStatus.name)) } } @@ -94,4 +121,24 @@ class DumpCheckpointsTest { flowProceedLatch.await() } } + + @StartableByRPC + class EasyFlow : FlowLogic() { + @Suspendable + override fun call(): Int { + sleep(2.minutes) + return 1 + } + } + + @StartableByRPC + class SetAllFlowsToPausedFlow : FlowLogic() { + @Suspendable + override fun call(): Int { + return serviceHub + .jdbcSession() + .prepareStatement("UPDATE node_checkpoints SET status = '${Checkpoint.FlowStatus.PAUSED.ordinal}'") + .use { ps -> ps.executeUpdate() } + } + } } \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleAMQPClient.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleAMQPClient.kt index bb3c86e9de..da3e831bda 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleAMQPClient.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleAMQPClient.kt @@ -3,6 +3,8 @@ package net.corda.services.messaging import net.corda.core.internal.concurrent.openFuture import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.internal.config.MutualSslConfiguration +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.apache.qpid.jms.JmsConnectionFactory import org.apache.qpid.jms.meta.JmsConnectionInfo import org.apache.qpid.jms.provider.Provider @@ -24,9 +26,7 @@ import javax.jms.Connection import javax.jms.Message import javax.jms.MessageProducer import javax.jms.Session -import javax.net.ssl.KeyManagerFactory import javax.net.ssl.SSLContext -import javax.net.ssl.TrustManagerFactory /** * Simple AMQP client connecting to broker using JMS. @@ -59,12 +59,8 @@ class SimpleAMQPClient(private val target: NetworkHostAndPort, private val confi private lateinit var connection: Connection private fun sslContext(): SSLContext { - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()).apply { - init(config.keyStore.get().value.internal, config.keyStore.entryPassword.toCharArray()) - } - val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply { - init(config.trustStore.get().value.internal) - } + val keyManagerFactory = keyManagerFactory(config.keyStore.get()) + val trustManagerFactory = trustManagerFactory(config.trustStore.get()) val sslContext = SSLContext.getInstance("TLS") val keyManagers = keyManagerFactory.keyManagers val trustManagers = trustManagerFactory.trustManagers diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleMQClient.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleMQClient.kt index fa5fc09d53..b52422fff0 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleMQClient.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleMQClient.kt @@ -22,7 +22,7 @@ class SimpleMQClient(val target: NetworkHostAndPort, lateinit var producer: ClientProducer fun start(username: String? = null, password: String? = null, enableSSL: Boolean = true) { - val tcpTransport = p2pConnectorTcpTransport(target, config, enableSSL = enableSSL) + val tcpTransport = p2pConnectorTcpTransport(target, config, enableSSL = enableSSL, threadPoolName = "SimpleMQClient") val locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply { isBlockOnNonDurableSend = true threadPoolMaxSize = 1 diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index ef8fd825a6..ff47272d64 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -4,8 +4,8 @@ import co.paralleluniverse.fibers.instrument.Retransform import com.codahale.metrics.MetricRegistry import com.google.common.collect.MutableClassToInstanceMap import com.google.common.util.concurrent.MoreExecutors -import com.google.common.util.concurrent.ThreadFactoryBuilder import com.zaxxer.hikari.pool.HikariPool +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.common.logging.errorReporting.NodeDatabaseErrors import net.corda.confidential.SwapIdentitiesFlow import net.corda.core.CordaException @@ -67,6 +67,7 @@ import net.corda.core.toFuture import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.days +import net.corda.core.utilities.millis import net.corda.core.utilities.minutes import net.corda.djvm.source.ApiSource import net.corda.djvm.source.EmptyApi @@ -166,6 +167,7 @@ import net.corda.nodeapi.internal.persistence.RestrictedEntityManager import net.corda.nodeapi.internal.persistence.SchemaMigration import net.corda.nodeapi.internal.persistence.contextDatabase import net.corda.nodeapi.internal.persistence.withoutDatabaseAccess +import net.corda.nodeapi.internal.namedThreadPoolExecutor import net.corda.tools.shell.InteractiveShell import org.apache.activemq.artemis.utils.ReusableLatch import org.jolokia.jvmagent.JolokiaServer @@ -178,18 +180,14 @@ import java.sql.Savepoint import java.time.Clock import java.time.Duration import java.time.format.DateTimeParseException -import java.util.* +import java.util.Properties import java.util.concurrent.ExecutorService import java.util.concurrent.Executors -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.ThreadPoolExecutor -import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit.MINUTES import java.util.concurrent.TimeUnit.SECONDS import java.util.function.Consumer import javax.persistence.EntityManager import javax.sql.DataSource -import kotlin.collections.ArrayList /** * A base node implementation that can be customised either for production (with real implementations that do real @@ -337,7 +335,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, private val schedulerService = makeNodeSchedulerService() private val cordappServices = MutableClassToInstanceMap.create() - private val shutdownExecutor = Executors.newSingleThreadExecutor() + private val shutdownExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("Shutdown")) protected abstract val transactionVerifierWorkerCount: Int /** @@ -773,7 +771,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } else { 1.days } - val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("Network Map Updater")) + val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("NetworkMapPublisher")) executor.submit(object : Runnable { override fun run() { val republishInterval = try { @@ -882,13 +880,12 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } // Start with 1 thread and scale up to the configured thread pool size if needed // Parameters of [ThreadPoolExecutor] based on [Executors.newFixedThreadPool] - return ThreadPoolExecutor( - 1, - numberOfThreads, - 0L, - TimeUnit.MILLISECONDS, - LinkedBlockingQueue(), - ThreadFactoryBuilder().setNameFormat("flow-external-operation-thread").setDaemon(true).build() + return namedThreadPoolExecutor( + corePoolSize = 1, + maxPoolSize = numberOfThreads, + idleKeepAlive = 0.millis, + poolName = "flow-external-operation-thread", + daemonThreads = true ) } @@ -1026,13 +1023,23 @@ abstract class AbstractNode(val configuration: NodeConfiguration, service.run { tokenize() runOnStop += ::stop - flowManager.registerInitiatedCoreFlowFactory(NotaryFlow.Client::class, ::createServiceFlow) + registerInitiatingFlows() start() } return service } } + private fun NotaryService.registerInitiatingFlows() { + if (configuration.notary?.enableOverridableFlows == true) { + initiatingFlows.forEach { (flow, factory) -> + flowManager.registerInitiatedCoreFlowFactory(flow, factory) + } + } else { + flowManager.registerInitiatedCoreFlowFactory(NotaryFlow.Client::class, ::createServiceFlow) + } + } + protected open fun makeKeyManagementService(identityService: PersistentIdentityService): KeyManagementServiceInternal { // Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because // the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with @@ -1071,7 +1078,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, networkParameters: NetworkParameters) protected open fun makeVaultService(keyManagementService: KeyManagementService, - services: ServicesForResolution, + services: NodeServicesForResolution, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal { return NodeVaultService(platformClock, keyManagementService, services, database, schemaService, cordappLoader.appClassLoader) diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 504627dad2..bd3b6cb744 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -73,6 +73,7 @@ import net.corda.node.utilities.DemoClock import net.corda.node.utilities.errorAndTerminate import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.common.logging.errorReporting.NodeDatabaseErrors +import net.corda.node.internal.classloading.scanForCustomSerializationScheme import net.corda.nodeapi.internal.ShutdownHook import net.corda.nodeapi.internal.addShutdownHook import net.corda.nodeapi.internal.bridging.BridgeControlListener @@ -414,12 +415,13 @@ open class Node(configuration: NodeConfiguration, } private fun makeBridgeControlListener(serverAddress: NetworkHostAndPort, networkParameters: NetworkParameters): BridgeControlListener { - val artemisMessagingClientFactory = { + val artemisMessagingClientFactory = { threadPoolName: String -> ArtemisMessagingClient( configuration.p2pSslOptions, serverAddress, networkParameters.maxMessageSize, - failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) } + failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) }, + threadPoolName = threadPoolName ) } return BridgeControlListener( @@ -430,7 +432,8 @@ open class Node(configuration: NodeConfiguration, networkParameters.maxMessageSize, configuration.crlCheckSoftFail.toRevocationConfig(), false, - artemisMessagingClientFactory) + artemisMessagingClientFactory + ) } private fun startLocalRpcBroker(securityManager: RPCSecurityManager): BrokerAddresses? { @@ -647,10 +650,14 @@ open class Node(configuration: NodeConfiguration, private fun initialiseSerialization() { if (!initialiseSerialization) return val classloader = cordappLoader.appClassLoader + val customScheme = System.getProperty("experimental.corda.customSerializationScheme")?.let { + scanForCustomSerializationScheme(it, classloader) + } nodeSerializationEnv = SerializationEnvironment.with( SerializationFactoryImpl().apply { registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps, Caffeine.newBuilder().maximumSize(128).build().asMap())) registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps, Caffeine.newBuilder().maximumSize(128).build().asMap())) + customScheme?.let{ registerScheme(it) } }, p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader), diff --git a/node/src/main/kotlin/net/corda/node/internal/NodeServicesForResolution.kt b/node/src/main/kotlin/net/corda/node/internal/NodeServicesForResolution.kt new file mode 100644 index 0000000000..5baa528297 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/internal/NodeServicesForResolution.kt @@ -0,0 +1,15 @@ +package net.corda.node.internal + +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionResolutionException +import net.corda.core.node.ServicesForResolution +import java.util.LinkedHashSet + +interface NodeServicesForResolution : ServicesForResolution { + @Throws(TransactionResolutionException::class) + override fun loadStates(stateRefs: Set): Set> = loadStates(stateRefs, LinkedHashSet()) + + fun >> loadStates(input: Iterable, output: C): C +} diff --git a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt index f5836c0cc5..ffb21894c1 100644 --- a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt @@ -1,16 +1,26 @@ package net.corda.node.internal -import net.corda.core.contracts.* +import net.corda.core.contracts.Attachment +import net.corda.core.contracts.AttachmentResolutionException +import net.corda.core.contracts.ContractAttachment +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionResolutionException +import net.corda.core.contracts.TransactionState import net.corda.core.cordapp.CordappProvider +import net.corda.core.crypto.SecureHash import net.corda.core.internal.SerializedStateAndRef +import net.corda.core.internal.uncheckedCast import net.corda.core.node.NetworkParameters -import net.corda.core.node.ServicesForResolution import net.corda.core.node.services.AttachmentStorage import net.corda.core.node.services.IdentityService import net.corda.core.node.services.NetworkParametersService import net.corda.core.node.services.TransactionStorage +import net.corda.core.transactions.BaseTransaction import net.corda.core.transactions.ContractUpgradeWireTransaction import net.corda.core.transactions.NotaryChangeWireTransaction +import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent @@ -20,31 +30,28 @@ data class ServicesForResolutionImpl( override val cordappProvider: CordappProvider, override val networkParametersService: NetworkParametersService, private val validatedTransactions: TransactionStorage -) : ServicesForResolution { +) : NodeServicesForResolution { override val networkParameters: NetworkParameters get() = networkParametersService.lookup(networkParametersService.currentHash) ?: throw IllegalArgumentException("No current parameters in network parameters storage") @Throws(TransactionResolutionException::class) override fun loadState(stateRef: StateRef): TransactionState<*> { - val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash) - return stx.resolveBaseTransaction(this).outputs[stateRef.index] + return toBaseTransaction(stateRef.txhash).outputs[stateRef.index] } - @Throws(TransactionResolutionException::class) - override fun loadStates(stateRefs: Set): Set> { - return stateRefs.groupBy { it.txhash }.flatMap { - val stx = validatedTransactions.getTransaction(it.key) ?: throw TransactionResolutionException(it.key) - val baseTx = stx.resolveBaseTransaction(this) - it.value.map { ref -> StateAndRef(baseTx.outputs[ref.index], ref) } - }.toSet() + override fun >> loadStates(input: Iterable, output: C): C { + val baseTxs = HashMap() + return input.mapTo(output) { stateRef -> + val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction) + StateAndRef(uncheckedCast(baseTx.outputs[stateRef.index]), stateRef) + } } @Throws(TransactionResolutionException::class, AttachmentResolutionException::class) override fun loadContractAttachment(stateRef: StateRef): Attachment { // We may need to recursively chase transactions if there are notary changes. fun inner(stateRef: StateRef, forContractClassName: String?): Attachment { - val ctx = validatedTransactions.getTransaction(stateRef.txhash)?.coreTransaction - ?: throw TransactionResolutionException(stateRef.txhash) + val ctx = getSignedTransaction(stateRef.txhash).coreTransaction when (ctx) { is WireTransaction -> { val transactionState = ctx.outRef(stateRef.index).state @@ -69,4 +76,10 @@ data class ServicesForResolutionImpl( } return inner(stateRef, null) } + + private fun toBaseTransaction(txhash: SecureHash): BaseTransaction = getSignedTransaction(txhash).resolveBaseTransaction(this) + + private fun getSignedTransaction(txhash: SecureHash): SignedTransaction { + return validatedTransactions.getTransaction(txhash) ?: throw TransactionResolutionException(txhash) + } } diff --git a/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt b/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt index c146629364..9658fe2e53 100644 --- a/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt +++ b/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt @@ -135,12 +135,12 @@ class BrokerJaasLoginModule : BaseBrokerJaasLoginModule() { Pair(ArtemisMessagingComponent.NODE_RPC_USER, listOf(RolePrincipal(NODE_RPC_ROLE))) } ArtemisMessagingComponent.PEER_USER -> { - requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." } + val p2pJaasConfig = requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." } requireTls(certificates) // This check is redundant as it was performed already during the SSL handshake - CertificateChainCheckPolicy.RootMustMatch.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates!!) - CertificateChainCheckPolicy.RevocationCheck(p2pJaasConfig!!.revocationMode) - .createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates) + CertificateChainCheckPolicy.RootMustMatch + .createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore) + .checkCertificateChain(certificates!!) Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE))) } else -> { diff --git a/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt b/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt index 90a44f9c55..de1ac38bc8 100644 --- a/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt +++ b/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt @@ -2,18 +2,9 @@ package net.corda.node.internal.artemis import net.corda.core.identity.CordaX500Name import net.corda.core.utilities.contextLogger -import net.corda.nodeapi.internal.crypto.X509CertificateFactory import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig -import net.corda.nodeapi.internal.protonwrapper.netty.certPathToString import java.security.KeyStore -import java.security.cert.CertPathValidator -import java.security.cert.CertPathValidatorException import java.security.cert.CertificateException -import java.security.cert.PKIXBuilderParameters -import java.security.cert.PKIXRevocationChecker -import java.security.cert.X509CertSelector -import java.util.EnumSet sealed class CertificateChainCheckPolicy { companion object { @@ -22,7 +13,6 @@ sealed class CertificateChainCheckPolicy { @FunctionalInterface interface Check { - @Suppress("DEPRECATION") // should use java.security.cert.X509Certificate fun checkCertificateChain(theirChain: Array) } @@ -31,7 +21,6 @@ sealed class CertificateChainCheckPolicy { object Any : CertificateChainCheckPolicy() { override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check { return object : Check { - @Suppress("DEPRECATION") // should use java.security.cert.X509Certificate override fun checkCertificateChain(theirChain: Array) { // nothing to do here } @@ -44,7 +33,6 @@ sealed class CertificateChainCheckPolicy { val rootAliases = trustStore.aliases().asSequence().filter { it.startsWith(X509Utilities.CORDA_ROOT_CA) } val rootPublicKeys = rootAliases.map { trustStore.getCertificate(it).publicKey }.toSet() return object : Check { - @Suppress("DEPRECATION") // should use java.security.cert.X509Certificate override fun checkCertificateChain(theirChain: Array) { val theirRoot = theirChain.last().publicKey if (theirRoot !in rootPublicKeys) { @@ -59,7 +47,6 @@ sealed class CertificateChainCheckPolicy { override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check { val ourPublicKey = keyStore.getCertificate(X509Utilities.CORDA_CLIENT_TLS).publicKey return object : Check { - @Suppress("DEPRECATION") // should use java.security.cert.X509Certificate override fun checkCertificateChain(theirChain: Array) { val theirLeaf = theirChain.first().publicKey if (ourPublicKey != theirLeaf) { @@ -74,7 +61,6 @@ sealed class CertificateChainCheckPolicy { override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check { val trustedPublicKeys = trustedAliases.map { trustStore.getCertificate(it).publicKey }.toSet() return object : Check { - @Suppress("DEPRECATION") // should use java.security.cert.X509Certificate override fun checkCertificateChain(theirChain: Array) { if (!theirChain.any { it.publicKey in trustedPublicKeys }) { throw CertificateException("Their certificate chain contained none of the trusted ones") @@ -92,52 +78,10 @@ sealed class CertificateChainCheckPolicy { class UsernameMustMatchCommonNameCheck : Check { lateinit var username: String - @Suppress("DEPRECATION") // should use java.security.cert.X509Certificate override fun checkCertificateChain(theirChain: Array) { if (!theirChain.any { certificate -> CordaX500Name.parse(certificate.subjectDN.name).commonName == username }) { throw CertificateException("Client certificate does not match login username.") } } } - - class RevocationCheck(val revocationMode: RevocationConfig.Mode) : CertificateChainCheckPolicy() { - override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check { - return object : Check { - @Suppress("DEPRECATION") // should use java.security.cert.X509Certificate - override fun checkCertificateChain(theirChain: Array) { - if (revocationMode == RevocationConfig.Mode.OFF) { - return - } - // Convert javax.security.cert.X509Certificate to java.security.cert.X509Certificate. - val chain = theirChain.map { X509CertificateFactory().generateCertificate(it.encoded.inputStream()) } - log.info("Check Client Certpath:\r\n${certPathToString(chain.toTypedArray())}") - - // Drop the last certificate which must be a trusted root (validated by RootMustMatch). - // Assume that there is no more trusted roots (or corresponding public keys) in the remaining chain. - // See PKIXValidator.engineValidate() for reference implementation. - val certPath = X509Utilities.buildCertPath(chain.dropLast(1)) - val certPathValidator = CertPathValidator.getInstance("PKIX") - val pkixRevocationChecker = certPathValidator.revocationChecker as PKIXRevocationChecker - pkixRevocationChecker.options = EnumSet.of( - // Prefer CRL over OCSP - PKIXRevocationChecker.Option.PREFER_CRLS, - // Don't fall back to OCSP checking - PKIXRevocationChecker.Option.NO_FALLBACK) - if (revocationMode == RevocationConfig.Mode.SOFT_FAIL) { - // Allow revocation check to succeed if the revocation status cannot be determined for one of - // the following reasons: The CRL or OCSP response cannot be obtained because of a network error. - pkixRevocationChecker.options = pkixRevocationChecker.options + PKIXRevocationChecker.Option.SOFT_FAIL - } - val params = PKIXBuilderParameters(trustStore, X509CertSelector()) - params.addCertPathChecker(pkixRevocationChecker) - try { - certPathValidator.validate(certPath, params) - } catch (ex: CertPathValidatorException) { - log.error("Bad certificate path", ex) - throw ex - } - } - } - } - } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt b/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt index d89dc3a455..bb49eeb179 100644 --- a/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt +++ b/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt @@ -2,6 +2,31 @@ package net.corda.node.internal.classloading +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.node.internal.ConfigurationException +import net.corda.nodeapi.internal.serialization.CustomSerializationSchemeAdapter +import net.corda.serialization.internal.SerializationScheme +import java.lang.reflect.Constructor + inline fun Class<*>.requireAnnotation(): A { return requireNotNull(getDeclaredAnnotation(A::class.java)) { "$name needs to be annotated with ${A::class.java.name}" } +} + +fun scanForCustomSerializationScheme(className: String, classLoader: ClassLoader) : SerializationScheme { + val schemaClass = try { + Class.forName(className, false, classLoader) + } catch (exception: ClassNotFoundException) { + throw ConfigurationException("$className was declared as a custom serialization scheme but could not be found.") + } + val constructor = validateScheme(schemaClass, className) + return CustomSerializationSchemeAdapter(constructor.newInstance() as CustomSerializationScheme) +} + +private fun validateScheme(clazz: Class<*>, className: String): Constructor<*> { + if (!clazz.interfaces.contains(CustomSerializationScheme::class.java)) { + throw ConfigurationException("$className was declared as a custom serialization scheme but does not implement" + + " ${CustomSerializationScheme::class.java.canonicalName}") + } + return clazz.constructors.singleOrNull { it.parameters.isEmpty() } ?: throw ConfigurationException("$className was declared as a " + + "custom serialization scheme but does not have a no argument constructor.") } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/internal/djvm/AttachmentFactory.kt b/node/src/main/kotlin/net/corda/node/internal/djvm/AttachmentFactory.kt index 9a616aa813..d272a7428e 100644 --- a/node/src/main/kotlin/net/corda/node/internal/djvm/AttachmentFactory.kt +++ b/node/src/main/kotlin/net/corda/node/internal/djvm/AttachmentFactory.kt @@ -1,6 +1,7 @@ package net.corda.node.internal.djvm import net.corda.core.contracts.Attachment +import net.corda.core.contracts.ContractAttachment import net.corda.core.serialization.serialize import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.node.djvm.AttachmentBuilder @@ -19,14 +20,30 @@ class AttachmentFactory( fun toSandbox(attachments: List): Any? { val builder = taskFactory.apply(AttachmentBuilder::class.java) for (attachment in attachments) { - builder.apply(arrayOf( - serializer.deserialize(attachment.signerKeys.serialize()), - sandboxBasicInput.apply(attachment.size), - serializer.deserialize(attachment.id.serialize()), - attachment, - sandboxOpenAttachment - )) + builder.apply(generateArgsFor(attachment)) } return builder.apply(null) } + + private fun generateArgsFor(attachment: Attachment): Array { + val signerKeys = serializer.deserialize(attachment.signerKeys.serialize()) + val id = serializer.deserialize(attachment.id.serialize()) + val size = sandboxBasicInput.apply(attachment.size) + return if (attachment is ContractAttachment) { + val underlyingAttachment = attachment.attachment + arrayOf( + serializer.deserialize(underlyingAttachment.signerKeys.serialize()), + size, id, + underlyingAttachment, + sandboxOpenAttachment, + sandboxBasicInput.apply(attachment.contract), + sandboxBasicInput.apply(attachment.additionalContracts.toTypedArray()), + sandboxBasicInput.apply(attachment.uploader), + signerKeys, + sandboxBasicInput.apply(attachment.version) + ) + } else { + arrayOf(signerKeys, size, id, attachment, sandboxOpenAttachment) + } + } } diff --git a/node/src/main/kotlin/net/corda/node/internal/djvm/DeterministicVerifier.kt b/node/src/main/kotlin/net/corda/node/internal/djvm/DeterministicVerifier.kt index 654b218aee..3263868aa8 100644 --- a/node/src/main/kotlin/net/corda/node/internal/djvm/DeterministicVerifier.kt +++ b/node/src/main/kotlin/net/corda/node/internal/djvm/DeterministicVerifier.kt @@ -7,13 +7,14 @@ import net.corda.core.contracts.ComponentGroupEnum.SIGNERS_GROUP import net.corda.core.contracts.TransactionState import net.corda.core.contracts.TransactionVerificationException import net.corda.core.crypto.SecureHash -import net.corda.core.internal.ContractVerifier +import net.corda.core.internal.TransactionVerifier import net.corda.core.internal.Verifier import net.corda.core.internal.getNamesOfClassesImplementing import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.serialize import net.corda.core.transactions.LedgerTransaction +import net.corda.core.utilities.contextLogger import net.corda.djvm.SandboxConfiguration import net.corda.djvm.execution.ExecutionSummary import net.corda.djvm.execution.IsolatedTask @@ -21,15 +22,19 @@ import net.corda.djvm.execution.SandboxException import net.corda.djvm.messages.Message import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.djvm.source.ClassSource -import net.corda.node.djvm.LtxFactory +import net.corda.node.djvm.LtxSupplierFactory import java.util.function.Function import kotlin.collections.LinkedHashSet class DeterministicVerifier( - ltx: LedgerTransaction, - transactionClassLoader: ClassLoader, + private val ltx: LedgerTransaction, + private val transactionClassLoader: ClassLoader, private val sandboxConfiguration: SandboxConfiguration -) : Verifier(ltx, transactionClassLoader) { +) : Verifier { + private companion object { + private val logger = contextLogger() + } + /** * Read the whitelisted classes without using the [java.util.ServiceLoader] mechanism * because the whitelists themselves are untrusted. @@ -47,7 +52,7 @@ class DeterministicVerifier( } } - override fun verifyContracts() { + override fun verify() { val customSerializerNames = getNamesOfClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java) val serializationWhitelistNames = getSerializationWhitelistNames(transactionClassLoader) val result = IsolatedTask(ltx.id.toString(), sandboxConfiguration).run(Function { classLoader -> @@ -93,14 +98,14 @@ class DeterministicVerifier( val networkingParametersData = ltx.networkParameters?.serialize() val digestServiceData = ltx.digestService.serialize() - val createSandboxTx = taskFactory.apply(LtxFactory::class.java) + val createSandboxTx = taskFactory.apply(LtxSupplierFactory::class.java) createSandboxTx.apply(arrayOf( - serializer.deserialize(serializedInputs), + classLoader.createForImport(Function { serializer.deserialize(serializedInputs) }), componentFactory.toSandbox(OUTPUTS_GROUP, TransactionState::class.java), CommandFactory(taskFactory).toSandbox( componentFactory.toSandbox(SIGNERS_GROUP, List::class.java), componentFactory.toSandbox(COMMANDS_GROUP, CommandData::class.java), - componentFactory.calculateLeafIndicesFor(COMMANDS_GROUP, digestService = ltx.digestService) + componentFactory.calculateLeafIndicesFor(COMMANDS_GROUP, ltx.digestService) ), attachmentFactory.toSandbox(ltx.attachments), serializer.deserialize(idData), @@ -108,12 +113,12 @@ class DeterministicVerifier( serializer.deserialize(timeWindowData), serializer.deserialize(privacySaltData), serializer.deserialize(networkingParametersData), - serializer.deserialize(serializedReferences), + classLoader.createForImport(Function { serializer.deserialize(serializedReferences) }), serializer.deserialize(digestServiceData) )) } - val verifier = taskFactory.apply(ContractVerifier::class.java) + val verifier = taskFactory.apply(TransactionVerifier::class.java) // Now execute the contract verifier task within the sandbox... verifier.apply(sandboxTx) @@ -128,7 +133,7 @@ class DeterministicVerifier( val sandboxEx = SandboxException( Message.getMessageFromException(this), result.identifier, - ClassSource.fromClassName(ContractVerifier::class.java.name), + ClassSource.fromClassName(TransactionVerifier::class.java.name), ExecutionSummary(result.costs), this ) diff --git a/node/src/main/kotlin/net/corda/node/internal/djvm/Serializer.kt b/node/src/main/kotlin/net/corda/node/internal/djvm/Serializer.kt index 40a5522a28..32fedf52fa 100644 --- a/node/src/main/kotlin/net/corda/node/internal/djvm/Serializer.kt +++ b/node/src/main/kotlin/net/corda/node/internal/djvm/Serializer.kt @@ -1,6 +1,9 @@ package net.corda.node.internal.djvm import net.corda.core.internal.SerializedStateAndRef +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_PROPERTY +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.SerializedBytes @@ -22,14 +25,20 @@ class Serializer( init { val env = createSandboxSerializationEnv(classLoader, customSerializerNames, serializationWhitelists) factory = env.serializationFactory - context = env.p2pContext + context = env.p2pContext.withProperties(mapOf( + // Duplicate the P2P SerializationContext and give it + // these extra properties, just for this transaction. + AMQP_ENVELOPE_CACHE_PROPERTY to HashMap(AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY), + DESERIALIZATION_CACHE_PROPERTY to HashMap() + )) } /** * Convert a list of [SerializedStateAndRef] objects into arrays * of deserialized sandbox objects. We will pass this array into - * [net.corda.node.djvm.LtxFactory] to be transformed finally to - * a list of [net.corda.core.contracts.StateAndRef] objects, + * [LtxSupplierFactory][net.corda.node.djvm.LtxSupplierFactory] + * to be transformed finally to a list of + * [StateAndRef][net.corda.core.contracts.StateAndRef] objects, */ fun deserialize(stateRefs: List): Array> { return stateRefs.map { diff --git a/node/src/main/kotlin/net/corda/node/migration/VaultStateMigration.kt b/node/src/main/kotlin/net/corda/node/migration/VaultStateMigration.kt index f33418c28e..b765685910 100644 --- a/node/src/main/kotlin/net/corda/node/migration/VaultStateMigration.kt +++ b/node/src/main/kotlin/net/corda/node/migration/VaultStateMigration.kt @@ -2,7 +2,6 @@ package net.corda.node.migration import liquibase.database.Database import net.corda.core.contracts.* -import net.corda.core.crypto.SecureHash import net.corda.core.identity.CordaX500Name import net.corda.core.node.services.Vault import net.corda.core.schemas.MappedSchema @@ -18,6 +17,7 @@ import net.corda.node.services.persistence.DBTransactionStorage import net.corda.node.services.persistence.NodeAttachmentService import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.VaultSchemaV1 +import net.corda.node.services.vault.toStateRef import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.SchemaMigration @@ -61,8 +61,7 @@ class VaultStateMigration : CordaMigration() { private fun getStateAndRef(persistentState: VaultSchemaV1.VaultStates): StateAndRef { val persistentStateRef = persistentState.stateRef ?: throw VaultStateMigrationException("Persistent state ref missing from state") - val txHash = SecureHash.create(persistentStateRef.txId) - val stateRef = StateRef(txHash, persistentStateRef.index) + val stateRef = persistentStateRef.toStateRef() val state = try { servicesForResolution.loadState(stateRef) } catch (e: Exception) { diff --git a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt index 1111e9faf9..bdda67f19a 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt @@ -60,7 +60,6 @@ interface NodeConfiguration : ConfigurationWithOptionsContainer { val noLocalShell: Boolean get() = false val transactionCacheSizeBytes: Long get() = defaultTransactionCacheSize val attachmentContentCacheSizeBytes: Long get() = defaultAttachmentContentCacheSize - val attachmentCacheBound: Long get() = defaultAttachmentCacheBound // do not change this value without syncing it with ScheduledFlowsDrainingModeTest val drainingModePollPeriod: Duration get() = Duration.ofSeconds(5) val extraNetworkMapKeys: List @@ -110,7 +109,6 @@ interface NodeConfiguration : ConfigurationWithOptionsContainer { } internal val defaultAttachmentContentCacheSize: Long = 10.MB - internal const val defaultAttachmentCacheBound = 1024L const val cordappDirectoriesKey = "cordappDirectories" @@ -168,7 +166,8 @@ data class NotaryConfig( /** Notary implementation-specific configuration parameters. */ val extraConfig: Config? = null, val raft: RaftConfig? = null, - val bftSMaRt: BFTSmartConfig? = null + val bftSMaRt: BFTSmartConfig? = null, + val enableOverridableFlows: Boolean? = null ) /** diff --git a/node/src/main/kotlin/net/corda/node/services/config/NodeConfigurationImpl.kt b/node/src/main/kotlin/net/corda/node/services/config/NodeConfigurationImpl.kt index 44b77a8264..49390958bf 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/NodeConfigurationImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/NodeConfigurationImpl.kt @@ -65,7 +65,6 @@ data class NodeConfigurationImpl( override val database: DatabaseConfig = Defaults.database(devMode), private val transactionCacheSizeMegaBytes: Int? = Defaults.transactionCacheSizeMegaBytes, private val attachmentContentCacheSizeMegaBytes: Int? = Defaults.attachmentContentCacheSizeMegaBytes, - override val attachmentCacheBound: Long = Defaults.attachmentCacheBound, override val extraNetworkMapKeys: List = Defaults.extraNetworkMapKeys, // do not use or remove (breaks DemoBench together with rejection of unknown configuration keys during parsing) private val h2port: Int? = Defaults.h2port, @@ -112,7 +111,6 @@ data class NodeConfigurationImpl( const val localShellUnsafe: Boolean = false val transactionCacheSizeMegaBytes: Int? = null val attachmentContentCacheSizeMegaBytes: Int? = null - const val attachmentCacheBound: Long = NodeConfiguration.defaultAttachmentCacheBound val extraNetworkMapKeys: List = emptyList() val h2port: Int? = null val h2Settings: NodeH2Settings? = null diff --git a/node/src/main/kotlin/net/corda/node/services/config/schema/v1/ConfigSections.kt b/node/src/main/kotlin/net/corda/node/services/config/schema/v1/ConfigSections.kt index 6b33dcf882..eb06334901 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/schema/v1/ConfigSections.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/schema/v1/ConfigSections.kt @@ -205,13 +205,25 @@ internal object NotaryConfigSpec : Configuration.Specification("No private val serviceLegalName by string().mapValid(::toCordaX500Name).optional() private val className by string().optional() private val etaMessageThresholdSeconds by int().optional().withDefaultValue(NotaryServiceFlow.defaultEstimatedWaitTime.seconds.toInt()) - private val extraConfig by nestedObject().map(ConfigObject::toConfig).optional() + private val extraConfig by nestedObject(sensitive = true).map(ConfigObject::toConfig).optional() private val raft by nested(RaftConfigSpec).optional() private val bftSMaRt by nested(BFTSmartConfigSpec).optional() + private val enableOverridableFlows by boolean().optional() override fun parseValid(configuration: Config, options: Configuration.Options): Valid { val config = configuration.withOptions(options) - return valid(NotaryConfig(config[validating], config[serviceLegalName], config[className], config[etaMessageThresholdSeconds], config[extraConfig], config[raft], config[bftSMaRt])) + return valid( + NotaryConfig( + config[validating], + config[serviceLegalName], + config[className], + config[etaMessageThresholdSeconds], + config[extraConfig], + config[raft], + config[bftSMaRt], + config[enableOverridableFlows] + ) + ) } } diff --git a/node/src/main/kotlin/net/corda/node/services/config/schema/v1/V1NodeConfigurationSpec.kt b/node/src/main/kotlin/net/corda/node/services/config/schema/v1/V1NodeConfigurationSpec.kt index ab1f36f417..3808597620 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/schema/v1/V1NodeConfigurationSpec.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/schema/v1/V1NodeConfigurationSpec.kt @@ -40,7 +40,6 @@ internal object V1NodeConfigurationSpec : Configuration.Specification() + private val startingStateRefs: MutableSet = ConcurrentHashMap.newKeySet() private val mutex = ThreadBox(InnerState()) - private val schedulerTimerExecutor = Executors.newSingleThreadExecutor() + private val schedulerTimerExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("SchedulerService")) // if there's nothing to do, check every minute if something fell through the cracks. // any new state should trigger a reschedule immediately if nothing is scheduled, so I would not expect diff --git a/node/src/main/kotlin/net/corda/node/services/events/PersistentScheduledFlowRepository.kt b/node/src/main/kotlin/net/corda/node/services/events/PersistentScheduledFlowRepository.kt index 2208eef88f..f62db2eee4 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/PersistentScheduledFlowRepository.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/PersistentScheduledFlowRepository.kt @@ -2,8 +2,8 @@ package net.corda.node.services.events import net.corda.core.contracts.ScheduledStateRef import net.corda.core.contracts.StateRef -import net.corda.core.crypto.SecureHash import net.corda.core.schemas.PersistentStateRef +import net.corda.node.services.vault.toStateRef import net.corda.nodeapi.internal.persistence.CordaPersistence interface ScheduledFlowRepository { @@ -25,9 +25,8 @@ class PersistentScheduledFlowRepository(val database: CordaPersistence) : Schedu } private fun fromPersistentEntity(scheduledStateRecord: NodeSchedulerService.PersistentScheduledState): Pair { - val txId = scheduledStateRecord.output.txId - val index = scheduledStateRecord.output.index - return Pair(StateRef(SecureHash.create(txId), index), ScheduledStateRef(StateRef(SecureHash.create(txId), index), scheduledStateRecord.scheduledAt)) + val stateRef = scheduledStateRecord.output.toStateRef() + return Pair(stateRef, ScheduledStateRef(stateRef, scheduledStateRecord.scheduledAt)) } override fun delete(key: StateRef): Boolean { diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt index 277d51742b..ed7a5b4cb6 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt @@ -7,9 +7,16 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug -import net.corda.node.internal.artemis.* +import net.corda.node.internal.artemis.ArtemisBroker +import net.corda.node.internal.artemis.BrokerAddresses +import net.corda.node.internal.artemis.BrokerJaasLoginModule import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.NODE_P2P_ROLE import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.PEER_ROLE +import net.corda.node.internal.artemis.NodeJaasConfig +import net.corda.node.internal.artemis.P2PJaasConfig +import net.corda.node.internal.artemis.SecureArtemisConfiguration +import net.corda.node.internal.artemis.UserValidationPlugin +import net.corda.node.internal.artemis.isBindingError import net.corda.node.services.config.NodeConfiguration import net.corda.nodeapi.internal.AmqpMessageSizeChecksInterceptor import net.corda.nodeapi.internal.ArtemisMessageSizeChecksInterceptor @@ -20,7 +27,10 @@ import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.SECURITY_INVALIDATION_INTERVAL import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pAcceptorTcpTransport import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation import net.corda.nodeapi.internal.requireOnDefaultFileSystem +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl @@ -33,7 +43,6 @@ import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager import java.io.IOException import java.lang.Long.max -import java.security.KeyStoreException import javax.annotation.concurrent.ThreadSafe import javax.security.auth.login.AppConfigurationEntry import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED @@ -55,7 +64,11 @@ import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.RE class ArtemisMessagingServer(private val config: NodeConfiguration, private val messagingServerAddress: NetworkHostAndPort, private val maxMessageSize: Int, - private val journalBufferTimeout : Int?) : ArtemisBroker, SingletonSerializeAsToken() { + private val journalBufferTimeout : Int? = null, + private val threadPoolName: String = "P2PServer", + private val trace: Boolean = false, + private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON, + private val remotingThreads: Int? = null) : ArtemisBroker, SingletonSerializeAsToken() { companion object { private val log = contextLogger() } @@ -89,9 +102,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, override val started: Boolean get() = activeMQServer.isStarted - // TODO: Maybe wrap [IOException] on a key store load error so that it's clearly splitting key store loading from - // Artemis IO errors - @Throws(IOException::class, AddressBindingException::class, KeyStoreException::class) + @Suppress("ThrowsCount") private fun configureAndStartServer() { val artemisConfig = createArtemisConfig() val securityManager = createArtemisSecurityManager() @@ -131,7 +142,24 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, // The transaction cache is configurable, and drives other cache sizes. globalMaxSize = max(config.transactionCacheSizeBytes, 10L * maxMessageSize) - acceptorConfigurations = mutableSetOf(p2pAcceptorTcpTransport(NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port), config.p2pSslOptions)) + val revocationMode = if (config.crlCheckArtemisServer) { + if (config.crlCheckSoftFail) RevocationConfig.Mode.SOFT_FAIL else RevocationConfig.Mode.HARD_FAIL + } else { + RevocationConfig.Mode.OFF + } + val trustManagerFactory = trustManagerFactoryWithRevocation( + config.p2pSslOptions.trustStore.get(), + RevocationConfigImpl(revocationMode), + distPointCrlSource + ) + addAcceptorConfiguration(p2pAcceptorTcpTransport( + NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port), + config.p2pSslOptions, + trustManagerFactory, + threadPoolName = threadPoolName, + trace = trace, + remotingThreads = remotingThreads + )) // Enable built in message deduplication. Note we still have to do our own as the delayed commits // and our own definition of commit mean that the built in deduplication cannot remove all duplicates. idCacheSize = 2000 // Artemis Default duplicate cache size i.e. a guess @@ -176,7 +204,6 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, deleteNonDurableQueue, manage, browse, createDurableQueue || createNonDurableQueue, deleteDurableQueue || deleteNonDurableQueue) } - @Throws(IOException::class, KeyStoreException::class) private fun createArtemisSecurityManager(): ActiveMQJAASSecurityManager { val keyStore = config.p2pSslOptions.keyStore.get().value.internal val trustStore = config.p2pSslOptions.trustStore.get().value.internal diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt new file mode 100644 index 0000000000..1709f896dd --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt @@ -0,0 +1,297 @@ +package net.corda.node.services.messaging + +import io.netty.buffer.ByteBufAllocator +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.group.ChannelGroup +import io.netty.handler.logging.LogLevel +import io.netty.handler.logging.LoggingHandler +import io.netty.handler.ssl.SslContext +import io.netty.handler.ssl.SslContextBuilder +import io.netty.handler.ssl.SslHandler +import io.netty.handler.ssl.SslHandshakeTimeoutException +import io.netty.handler.ssl.SslProvider +import net.corda.core.internal.declaredField +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.trace +import net.corda.nodeapi.internal.ArtemisTcpTransport +import net.corda.nodeapi.internal.config.CertificateStore +import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor +import net.corda.nodeapi.internal.setThreadPoolName +import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration +import org.apache.activemq.artemis.api.core.BaseInterceptor +import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor +import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants +import org.apache.activemq.artemis.core.remoting.impl.ssl.SSLSupport +import org.apache.activemq.artemis.core.server.ActiveMQServerLogger +import org.apache.activemq.artemis.core.server.cluster.ClusterConnection +import org.apache.activemq.artemis.spi.core.protocol.ProtocolManager +import org.apache.activemq.artemis.spi.core.remoting.Acceptor +import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory +import org.apache.activemq.artemis.spi.core.remoting.BufferHandler +import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener +import org.apache.activemq.artemis.utils.ConfigurationHelper +import org.apache.activemq.artemis.utils.actors.OrderedExecutor +import java.net.SocketAddress +import java.nio.channels.ClosedChannelException +import java.nio.file.Paths +import java.security.PrivilegedExceptionAction +import java.time.Duration +import java.util.concurrent.Executor +import java.util.concurrent.ScheduledExecutorService +import java.util.regex.Pattern +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLEngine +import javax.net.ssl.SSLPeerUnverifiedException +import javax.net.ssl.TrustManagerFactory +import javax.security.auth.Subject + +@Suppress("unused", "TooGenericExceptionCaught", "ComplexMethod", "MagicNumber", "TooManyFunctions") +class NodeNettyAcceptorFactory : AcceptorFactory { + override fun createAcceptor(name: String?, + clusterConnection: ClusterConnection?, + configuration: Map, + handler: BufferHandler?, + listener: ServerConnectionLifeCycleListener?, + threadPool: Executor, + scheduledThreadPool: ScheduledExecutorService, + protocolMap: Map>>?): Acceptor { + val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Acceptor", configuration) + threadPool.setThreadPoolName("$threadPoolName-artemis") + scheduledThreadPool.setThreadPoolName("$threadPoolName-artemis-scheduler") + val failureExecutor = OrderedExecutor(threadPool) + return NodeNettyAcceptor( + name, + clusterConnection, + configuration, + handler, + listener, + scheduledThreadPool, + failureExecutor, + protocolMap, + "$threadPoolName-netty" + ) + } + + + private class NodeNettyAcceptor(name: String?, + clusterConnection: ClusterConnection?, + configuration: Map, + handler: BufferHandler?, + listener: ServerConnectionLifeCycleListener?, + scheduledThreadPool: ScheduledExecutorService?, + failureExecutor: Executor, + protocolMap: Map>>?, + private val threadPoolName: String) : + NettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap) + { + companion object { + private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""") + } + + private val sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) + private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) + + @Synchronized + override fun start() { + super.start() + if (trace) { + // Unfortunately we have to resort to reflection to be able to get access to the server channel(s) + declaredField("serverChannelGroup").value.forEach { channel -> + channel.pipeline().addLast("logger", LoggingHandler(LogLevel.INFO)) + } + } + } + + @Synchronized + override fun stop() { + super.stop() + sslDelegatedTaskExecutor.shutdown() + } + + @Synchronized + override fun getSslHandler(alloc: ByteBufAllocator?): SslHandler { + applyThreadPoolName() + val engine = getSSLEngine(alloc) + val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace) + val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration? + if (handshakeTimeout != null) { + sslHandler.handshakeTimeoutMillis = handshakeTimeout.toMillis() + } + return sslHandler + } + + /** + * [NettyAcceptor.start] has hardcoded the thread pool name and does not provide a way to configure it. This is a workaround. + */ + private fun applyThreadPoolName() { + val matcher = defaultThreadPoolNamePattern.matcher(Thread.currentThread().name) + if (matcher.matches()) { + Thread.currentThread().name = "$threadPoolName-${matcher.group(1)}" // Preserve the pool thread number + } + } + + /** + * This is a copy of [NettyAcceptor.getSslHandler] so that we can provide different implementations for [loadOpenSslEngine] and + * [loadJdkSslEngine]. [NodeNettyAcceptor], instead of creating a default [TrustManagerFactory], will simply use the provided one in + * the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] configuration. + */ + private fun getSSLEngine(alloc: ByteBufAllocator?): SSLEngine { + val engine = if (sslProvider == TransportConstants.OPENSSL_PROVIDER) { + loadOpenSslEngine(alloc) + } else { + loadJdkSslEngine() + } + engine.useClientMode = false + if (needClientAuth) { + engine.needClientAuth = true + } + + // setting the enabled cipher suites resets the enabled protocols so we need + // to save the enabled protocols so that after the customer cipher suite is enabled + // we can reset the enabled protocols if a customer protocol isn't specified + val originalProtocols = engine.enabledProtocols + if (enabledCipherSuites != null) { + try { + engine.enabledCipherSuites = SSLSupport.parseCommaSeparatedListIntoArray(enabledCipherSuites) + } catch (e: IllegalArgumentException) { + ActiveMQServerLogger.LOGGER.invalidCipherSuite(SSLSupport.parseArrayIntoCommandSeparatedList(engine.supportedCipherSuites)) + throw e + } + } + if (enabledProtocols != null) { + try { + engine.enabledProtocols = SSLSupport.parseCommaSeparatedListIntoArray(enabledProtocols) + } catch (e: IllegalArgumentException) { + ActiveMQServerLogger.LOGGER.invalidProtocol(SSLSupport.parseArrayIntoCommandSeparatedList(engine.supportedProtocols)) + throw e + } + } else { + engine.enabledProtocols = originalProtocols + } + return engine + } + + /** + * Copy of [NettyAcceptor.loadOpenSslEngine] which invokes our custom [createOpenSslContext]. + */ + private fun loadOpenSslEngine(alloc: ByteBufAllocator?): SSLEngine { + val context = try { + // We copied all this code just so we could replace the SSLSupport.createNettyContext method call with our own one. + createOpenSslContext() + } catch (e: Exception) { + throw IllegalStateException("Unable to create NodeNettyAcceptor", e) + } + return Subject.doAs(null, PrivilegedExceptionAction { + context.newEngine(alloc) + }) + } + + /** + * Copy of [NettyAcceptor.loadJdkSslEngine] which invokes our custom [createJdkSSLContext]. + */ + private fun loadJdkSslEngine(): SSLEngine { + val context = try { + // We copied all this code just so we could replace the SSLHelper.createContext method call with our own one. + createJdkSSLContext() + } catch (e: Exception) { + throw IllegalStateException("Unable to create NodeNettyAcceptor", e) + } + return Subject.doAs(null, PrivilegedExceptionAction { + context.createSSLEngine() + }) + } + + /** + * Create an [SSLContext] using the [TrustManagerFactory] provided on the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] config. + */ + private fun createJdkSSLContext(): SSLContext { + return createAndInitSslContext( + createKeyManagerFactory(), + configuration[ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory? + ) + } + + /** + * Create an [SslContext] using the the [TrustManagerFactory] provided on the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] config. + */ + private fun createOpenSslContext(): SslContext { + return SslContextBuilder + .forServer(createKeyManagerFactory()) + .sslProvider(SslProvider.OPENSSL) + .trustManager(configuration[ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?) + .build() + } + + private fun createKeyManagerFactory(): KeyManagerFactory { + return keyManagerFactory(CertificateStore.fromFile(Paths.get(keyStorePath), keyStorePassword, keyStorePassword, false)) + } + + // Replicate the fields which are private in NettyAcceptor + private val sslProvider = ConfigurationHelper.getStringProperty(TransportConstants.SSL_PROVIDER, TransportConstants.DEFAULT_SSL_PROVIDER, configuration) + private val needClientAuth = ConfigurationHelper.getBooleanProperty(TransportConstants.NEED_CLIENT_AUTH_PROP_NAME, TransportConstants.DEFAULT_NEED_CLIENT_AUTH, configuration) + private val enabledCipherSuites = ConfigurationHelper.getStringProperty(TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME, TransportConstants.DEFAULT_ENABLED_CIPHER_SUITES, configuration) + private val enabledProtocols = ConfigurationHelper.getStringProperty(TransportConstants.ENABLED_PROTOCOLS_PROP_NAME, TransportConstants.DEFAULT_ENABLED_PROTOCOLS, configuration) + private val keyStorePath = ConfigurationHelper.getStringProperty(TransportConstants.KEYSTORE_PATH_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PATH, configuration) + private val keyStoreProvider = ConfigurationHelper.getStringProperty(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PROVIDER, configuration) + private val keyStorePassword = ConfigurationHelper.getPasswordProperty(TransportConstants.KEYSTORE_PASSWORD_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PASSWORD, configuration, ActiveMQDefaultConfiguration.getPropMaskPassword(), ActiveMQDefaultConfiguration.getPropPasswordCodec()) + } + + + private class NodeAcceptorSslHandler(engine: SSLEngine, + delegatedTaskExecutor: Executor, + private val trace: Boolean) : SslHandler(engine, delegatedTaskExecutor) { + companion object { + private val nettyLogHandshake = System.getProperty("net.corda.node.services.messaging.nettyLogHandshake")?.toBoolean() ?: false + private val logger = contextLogger() + } + + override fun handlerAdded(ctx: ChannelHandlerContext) { + logHandshake(ctx.channel().remoteAddress()) + super.handlerAdded(ctx) + // Unfortunately NettyAcceptor does not let us add extra child handlers, so we have to add our logger this way. + if (trace) { + ctx.pipeline().addLast("logger", LoggingHandler(LogLevel.INFO)) + } + } + + private fun logHandshake(remoteAddress: SocketAddress) { + val start = System.currentTimeMillis() + handshakeFuture().addListener { + val duration = System.currentTimeMillis() - start + val peer = try { + engine().session.peerPrincipal + } catch (e: SSLPeerUnverifiedException) { + remoteAddress + } + when { + it.isSuccess -> loggerInfo { "SSL handshake completed in ${duration}ms with $peer" } + it.isCancelled -> loggerWarn { "SSL handshake cancelled after ${duration}ms with $peer" } + else -> when (it.cause()) { + is ClosedChannelException -> loggerWarn { "SSL handshake closed early after ${duration}ms with $peer" } + is SslHandshakeTimeoutException -> loggerWarn { "SSL handshake timed out after ${duration}ms with $peer" } + else -> loggerWarn(it.cause()) {"SSL handshake failed after ${duration}ms with $peer" } + } + } + } + } + private fun loggerInfo(msgFn: () -> String) { + if (nettyLogHandshake && logger.isInfoEnabled) { + logger.info(msgFn()) + } + else { + logger.trace { msgFn() } + } + } + private fun loggerWarn(t: Throwable? = null, msgFn: () -> String) { + if (nettyLogHandshake && logger.isWarnEnabled) { + logger.warn(msgFn(), t) + } + else if (logger.isTraceEnabled) { + logger.trace(msgFn(), t) + } + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt index 10a22f5aed..b9caae56aa 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt @@ -93,6 +93,7 @@ class P2PMessagingClient(val config: NodeConfiguration, cacheFactory: NamedCacheFactory, private val isDrainingModeOn: () -> Boolean, private val drainingModeWasChangedEvents: Observable>, + private val threadPoolName: String = "P2PClient", private val stateHelper: ServiceStateHelper = ServiceStateHelper(log) ) : SingletonSerializeAsToken(), MessagingService, AddressToArtemisQueueResolver, ServiceStateSupport by stateHelper { companion object { @@ -164,10 +165,8 @@ class P2PMessagingClient(val config: NodeConfiguration, started = true log.info("Connecting to message broker: $serverAddress") // TODO Add broker CN to config for host verification in case the embedded broker isn't used - val tcpTransport = p2pConnectorTcpTransport(serverAddress, config.p2pSslOptions) + val tcpTransport = p2pConnectorTcpTransport(serverAddress, config.p2pSslOptions, threadPoolName = threadPoolName) locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply { - // Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this - // would be the default and the two lines below can be deleted. connectionTTL = 60000 clientFailureCheckPeriod = 30000 minLargeMessageSize = maxMessageSize + JOURNAL_HEADER_SIZE diff --git a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt index fac12a9343..584b050425 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt @@ -74,7 +74,7 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal, } private val parametersUpdatesTrack = PublishSubject.create() - private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("Network Map Updater Thread")).apply { + private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("NetworkMapUpdater")).apply { executeExistingDelayedTasksAfterShutdownPolicy = false } private var newNetworkParameters: Pair? = null @@ -261,9 +261,12 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal, //as HTTP GET is mostly IO bound, use more threads than CPU's //maximum threads to use = 24, as if we did not limit this on large machines it could result in 100's of concurrent requests val threadsToUseForNetworkMapDownload = min(Runtime.getRuntime().availableProcessors() * 4, 24) - val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool(threadsToUseForNetworkMapDownload, NamedThreadFactory("NetworkMapUpdaterNodeInfoDownloadThread")) + val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool( + threadsToUseForNetworkMapDownload, + NamedThreadFactory("NetworkMapUpdaterNodeInfoDownload") + ) //DB insert is single threaded - use a single threaded executor for it. - val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsertThread")) + val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsert")) val hashesToFetch = (allHashesFromNetworkMap - allNodeHashes) val networkMapDownloadStartTime = System.currentTimeMillis() if (hashesToFetch.isNotEmpty()) { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt index dbf36b9bae..e6aee94e7f 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt @@ -16,6 +16,7 @@ import net.corda.core.internal.* import net.corda.core.internal.Version import net.corda.core.internal.cordapp.CordappImpl.Companion.CORDAPP_CONTRACT_VERSION import net.corda.core.internal.cordapp.CordappImpl.Companion.DEFAULT_CORDAPP_VERSION +import net.corda.core.internal.utilities.ZipBombDetector import net.corda.core.node.ServicesForResolution import net.corda.core.node.services.AttachmentId import net.corda.core.node.services.vault.AttachmentQueryCriteria @@ -26,7 +27,6 @@ import net.corda.core.serialization.* import net.corda.core.utilities.contextLogger import net.corda.node.services.vault.HibernateAttachmentQueryCriteriaParser import net.corda.node.utilities.InfrequentlyMutatedCache -import net.corda.node.utilities.NonInvalidatingCache import net.corda.node.utilities.NonInvalidatingWeightBasedCache import net.corda.nodeapi.exceptions.DuplicateAttachmentException import net.corda.nodeapi.internal.persistence.CordaPersistence @@ -34,6 +34,7 @@ import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.currentDBSession import net.corda.nodeapi.internal.withContractsInJar import org.hibernate.query.Query +import java.io.ByteArrayInputStream import java.io.FilterInputStream import java.io.IOException import java.io.InputStream @@ -259,18 +260,6 @@ class NodeAttachmentService @JvmOverloads constructor( Token(id, checkOnLoad, uploader, signerKeys) } - // slightly complex 2 level approach to attachment caching: - // On the first level we cache attachment contents loaded from the DB by their key. This is a weight based - // cache (we don't want to waste too much memory on this) and could be evicted quite aggressively. If we fail - // to load an attachment from the db, the loader will insert a non present optional - we invalidate this - // immediately as we definitely want to retry whether the attachment was just delayed. - // On the second level, we cache Attachment implementations that use the first cache to load their content - // when required. As these are fairly small, we can cache quite a lot of them, this will make checking - // repeatedly whether an attachment exists fairly cheap. Here as well, we evict non-existent entries immediately - // to force a recheck if required. - // If repeatedly looking for non-existing attachments becomes a performance issue, this is either indicating a - // a problem somewhere else or this needs to be revisited. - private val attachmentContentCache = NonInvalidatingWeightBasedCache( cacheFactory = cacheFactory, name = "NodeAttachmentService_attachmentContent", @@ -309,27 +298,13 @@ class NodeAttachmentService @JvmOverloads constructor( } } - private val attachmentCache = NonInvalidatingCache>( - cacheFactory = cacheFactory, - name = "NodeAttachmentService_attachmentPresence", - loadFunction = { key -> Optional.ofNullable(createAttachment(key)) }) - - private fun createAttachment(key: SecureHash): Attachment? { - val content = attachmentContentCache.get(key)!! + override fun openAttachment(id: SecureHash): Attachment? { + val content = attachmentContentCache.get(id)!! if (content.isPresent) { return content.get().first } // If no attachment has been found, we don't want to cache that - it might arrive later. - attachmentContentCache.invalidate(key) - return null - } - - override fun openAttachment(id: SecureHash): Attachment? { - val attachment = attachmentCache.get(id)!! - if (attachment.isPresent) { - return attachment.get() - } - attachmentCache.invalidate(id) + attachmentContentCache.invalidate(id) return null } @@ -394,6 +369,9 @@ class NodeAttachmentService @JvmOverloads constructor( // set the hash field of the new attachment record. val bytes = inputStream.readFully() + require(!ZipBombDetector.scanZip(ByteArrayInputStream(bytes), servicesForResolution.networkParameters.maxTransactionSize.toLong())) { + "The attachment is too large and exceeds both max transaction size and the maximum allowed compression ratio" + } val id = bytes.sha256() if (!hasAttachment(id)) { checkIsAValidJAR(bytes.inputStream()) @@ -426,7 +404,6 @@ class NodeAttachmentService @JvmOverloads constructor( loadAttachmentContent(id)?.let { attachmentAndContent -> // TODO: this is racey. ENT-2870 attachmentContentCache.put(id, Optional.of(attachmentAndContent)) - attachmentCache.put(id, Optional.of(attachmentAndContent.first)) } return@withContractsInJar id } diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt b/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt index 8cfad7e660..a9aceab60b 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt @@ -221,7 +221,10 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri instrumentCheckpointAgent(runId) val (bytes, fileName) = try { - val checkpoint = serialisedCheckpoint.deserialize(checkpointSerializationContext) + val checkpoint = serialisedCheckpoint.deserialize( + checkpointSerializationContext, + alwaysDeserializeFlowState = true + ) val json = checkpoint.toJson(runId.uuid, now) val jsonBytes = writer.writeValueAsBytes(json) jsonBytes to "${json.topLevelFlowClass.simpleName}-${runId.uuid}.json" @@ -259,7 +262,12 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri //Dump checkpoints in "fibers" folder for((runId, serializedCheckpoint) in stream) { - val flowState = serializedCheckpoint.deserialize(checkpointSerializationContext).flowState + val flowState = serializedCheckpoint.deserialize( + checkpointSerializationContext, + alwaysDeserializeFlowState = true + ).flowState + // This includes paused flows because we have forced the deserialization of the checkpoint's flow state + // which will show as started. if(flowState is FlowState.Started) writeFiber2Zip(zip, checkpointSerializationContext, runId, flowState) } @@ -315,7 +323,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri * the checkpoint agent source code */ private fun checkpointAgentRunning() = try { - javaClass.classLoader.loadClass("net.corda.tools.CheckpointAgent").kotlin.companionObject + Class.forName("net.corda.tools.CheckpointAgent", false, javaClass.classLoader).kotlin.companionObject } catch (e: ClassNotFoundException) { null }?.let { cls -> @@ -354,6 +362,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri topLevelFlowLogic = flowLogic, flowCallStackSummary = flowCallStack.toSummary(), flowCallStack = flowCallStack, + status = status, suspendedOn = (flowState as? FlowState.Started)?.flowIORequest?.toSuspendedOn( timestamp, now @@ -436,6 +445,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri val topLevelFlowClass: Class>, val topLevelFlowLogic: FlowLogic<*>, val flowCallStackSummary: List, + val status: Checkpoint.FlowStatus, val suspendedOn: SuspendedOn?, val flowCallStack: List, val origin: Origin, diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt index 8ed025549c..e48cdf16c0 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt @@ -22,8 +22,7 @@ class InternalRPCMessagingClient(val sslConfig: MutualSslConfiguration, val serv private var rpcServer: RPCServer? = null fun init(rpcOps: List, securityManager: RPCSecurityManager, cacheFactory: NamedCacheFactory) = synchronized(this) { - - val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig) + val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig, threadPoolName = "RPCClient") locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply { // Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this // would be the default and the two lines below can be deleted. diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt index d8c320cab4..11ecd7e2c1 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt @@ -30,10 +30,10 @@ internal class RpcBrokerConfiguration(baseDirectory: Path, maxMessageSize: Int, setDirectories(baseDirectory) val acceptorConfigurationsSet = mutableSetOf( - rpcAcceptorTcpTransport(address, sslOptions, useSsl) + rpcAcceptorTcpTransport(address, sslOptions, enableSSL = useSsl, threadPoolName = "RPCServer") ) adminAddress?.let { - acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration) + acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration, threadPoolName = "RPCServerAdmin") } acceptorConfigurations = acceptorConfigurationsSet diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt index 9aab0183b7..ad57528218 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt @@ -92,6 +92,7 @@ class FlowCreator( lock: Semaphore = Semaphore(1), resultFuture: OpenFuture = openFuture(), firstRestore: Boolean = true, + isKilled: Boolean = false, progressTracker: ProgressTracker? = null ): Flow<*>? { val fiber = oldCheckpoint.getFiberFromCheckpoint(runId, firstRestore) @@ -116,7 +117,8 @@ class FlowCreator( reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount ?: if (reloadCheckpointAfterSuspend) checkpoint.checkpointState.numberOfSuspends else null, numberOfCommits = checkpoint.checkpointState.numberOfCommits, - lock = lock + lock = lock, + isKilled = isKilled ) injectOldProgressTracker(progressTracker, fiber.logic) return Flow(fiber, resultFuture) @@ -248,7 +250,8 @@ class FlowCreator( numberOfCommits: Int, lock: Semaphore, deduplicationHandler: DeduplicationHandler? = null, - senderUUID: String? = null + senderUUID: String? = null, + isKilled: Boolean = false ): StateMachineState { return StateMachineState( checkpoint = checkpoint, @@ -259,7 +262,8 @@ class FlowCreator( isAnyCheckpointPersisted = anyCheckpointPersisted, isStartIdempotent = false, isRemoved = false, - isKilled = false, + isKilled = isKilled, + isDead = false, flowLogic = fiber.logic, senderUUID = senderUUID, reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount, diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowDefaultUncaughtExceptionHandler.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowDefaultUncaughtExceptionHandler.kt index 44a3c8876b..a77967258d 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowDefaultUncaughtExceptionHandler.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowDefaultUncaughtExceptionHandler.kt @@ -10,7 +10,9 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.TimeUnit -class FlowDefaultUncaughtExceptionHandler( +internal class FlowDefaultUncaughtExceptionHandler( + private val smm: StateMachineManagerInternal, + private val innerState: StateMachineInnerState, private val flowHospital: StaffedFlowHospital, private val checkpointStorage: CheckpointStorage, private val database: CordaPersistence, @@ -31,7 +33,19 @@ class FlowDefaultUncaughtExceptionHandler( ) } else { fiber.logger.warn("Caught exception from flow $id", throwable) - setFlowToHospitalized(fiber, throwable) + if (fiber.isKilled) { + // If the flow was already killed and it has reached this exception handler then the flow must be killed forcibly to + // ensure it terminates. This could lead to sessions related to the flow not terminating as errors might not have been + // propagated to them. + smm.killFlowForcibly(id) + } else { + innerState.withLock { + setFlowToHospitalized(fiber, throwable) + // This flow has died and cannot continue to run as normal. Mark is as dead so that it can be handled directly by + // retry, kill and shutdown operations. + fiber.transientState = fiber.transientState.copy(isDead = true) + } + } } } @@ -52,9 +66,13 @@ class FlowDefaultUncaughtExceptionHandler( @Suppress("TooGenericExceptionCaught") private fun setFlowToHospitalizedRescheduleOnFailure(id: StateMachineRunId) { try { - log.debug { "Updating the status of flow $id to hospitalized after uncaught exception" } - database.transaction { checkpointStorage.updateStatus(id, Checkpoint.FlowStatus.HOSPITALIZED) } - log.debug { "Updated the status of flow $id to hospitalized after uncaught exception" } + innerState.withLock { + if (flows[id]?.fiber?.transientState?.isDead == true) { + log.debug { "Updating the status of flow $id to hospitalized after uncaught exception" } + database.transaction { checkpointStorage.updateStatus(id, Checkpoint.FlowStatus.HOSPITALIZED) } + log.debug { "Updated the status of flow $id to hospitalized after uncaught exception" } + } + } } catch (e: Exception) { log.info("Failed to update the status of flow $id to hospitalized after uncaught exception, rescheduling", e) scheduledExecutor.schedule({ setFlowToHospitalizedRescheduleOnFailure(id) }, RESCHEDULE_DELAY, TimeUnit.SECONDS) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt index c08515ab0e..734b2b8234 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt @@ -1,5 +1,6 @@ package net.corda.node.services.statemachine +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.core.flows.FlowSession import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowStateMachine @@ -22,10 +23,6 @@ internal class FlowMonitor( ) : LifecycleSupport { private companion object { - private fun defaultScheduler(): ScheduledExecutorService { - return Executors.newSingleThreadScheduledExecutor() - } - private val logger = loggerFor() } @@ -36,7 +33,7 @@ internal class FlowMonitor( override fun start() { synchronized(this) { if (scheduler == null) { - scheduler = defaultScheduler() + scheduler = Executors.newSingleThreadScheduledExecutor(DefaultThreadFactory("FlowMonitor")) shutdownScheduler = true } scheduler!!.scheduleAtFixedRate({ logFlowsWaitingForParty() }, 0, monitoringPeriod.toMillis(), TimeUnit.MILLISECONDS) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 267ff59aa6..db697148cd 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -333,6 +333,9 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, logger.debug { "Calling flow: $logic" } val startTime = System.nanoTime() + serviceHub.monitoringService.metrics + .timer("Flows.StartupQueueTime") + .update(System.currentTimeMillis() - creationTime, TimeUnit.MILLISECONDS) var initialised = false val resultOrError = try { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index 4633636c9e..6a86ec5dbd 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -33,6 +33,7 @@ import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug +import net.corda.core.utilities.minutes import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.ServiceHubInternal @@ -242,6 +243,8 @@ internal class SingleThreadedStateMachineManager( private fun setFlowDefaultUncaughtExceptionHandler() { Fiber.setDefaultUncaughtExceptionHandler( FlowDefaultUncaughtExceptionHandler( + this, + innerState, flowHospital, checkpointStorage, database, @@ -272,17 +275,40 @@ internal class SingleThreadedStateMachineManager( if (stopping) throw IllegalStateException("Already stopping!") stopping = true for ((_, flow) in flows) { - flow.fiber.scheduleEvent(Event.SoftShutdown) + if (!flow.fiber.transientState.isDead) { + flow.fiber.scheduleEvent(Event.SoftShutdown) + } } } // Account for any expected Fibers in a test scenario. liveFibers.countDown(allowedUnsuspendedFiberCount) - liveFibers.await() + awaitShutdownOfFlows() flowHospital.close() scheduledFutureExecutor.shutdown() scheduler.shutdown() } + private fun awaitShutdownOfFlows() { + val shutdownLogger = StateMachineShutdownLogger(innerState) + var shutdown: Boolean + do { + // Manually shutdown dead flows as they can no longer process scheduled events. + // This needs to be repeated in this loop to prevent flows that die after shutdown is triggered from being forgotten. + // The mutex is not enough protection to stop race-conditions here, the removal of dead flows has to be repeated. + innerState.withMutex { + for ((id, flow) in flows) { + if (flow.fiber.transientState.isDead) { + removeFlow(id, FlowRemovalReason.SoftShutdown, flow.fiber.transientState) + } + } + } + shutdown = liveFibers.await(1.minutes.toMillis()) + if (!shutdown) { + shutdownLogger.log() + } + } while (!shutdown) + } + /** * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and * calls to [allStateMachines] @@ -365,31 +391,29 @@ internal class SingleThreadedStateMachineManager( override fun killFlow(id: StateMachineRunId): Boolean { val flow = innerState.withLock { flows[id] } - val killFlowResult = flow?.let { killInMemoryFlow(it) } ?: killOutOfMemoryFlow(id) + val killFlowResult = flow?.let { + if (flow.fiber.transientState.isDead) { + // We cannot rely on fiber event processing in dead flows. + killInMemoryDeadFlow(it) + } else { + // Healthy flows need an event in case they they are suspended. + killInMemoryFlow(it) + } + } ?: killOutOfMemoryFlow(id) return killFlowResult || flowHospital.dropSessionInit(id) } private fun killInMemoryFlow(flow: Flow<*>): Boolean { val id = flow.fiber.id return flow.withFlowLock(VALID_KILL_FLOW_STATUSES) { - if (!flow.fiber.transientState.isKilled) { - flow.fiber.transientState = flow.fiber.transientState.copy(isKilled = true) + if (!transientState.isKilled) { + transientState = transientState.copy(isKilled = true) logger.info("Killing flow $id known to this node.") - // The checkpoint and soft locks are handled here as well as in a flow's transition. This means that we do not need to rely - // on the processing of the next event after setting the killed flag. This is to ensure a flow can be updated/removed from - // the database, even if it is stuck in a infinite loop. - if (flow.fiber.transientState.isAnyCheckpointPersisted) { - database.transaction { - if (flow.fiber.clientId != null) { - checkpointStorage.updateStatus(id, Checkpoint.FlowStatus.KILLED) - checkpointStorage.removeFlowException(id) - checkpointStorage.addFlowException(id, KilledFlowException(id)) - } else { - checkpointStorage.removeCheckpoint(id, mayHavePersistentResults = true) - } - serviceHub.vaultService.softLockRelease(id.uuid) - } - } + updateCheckpointWhenKillingFlow( + id = id, + clientId = transientState.checkpoint.checkpointState.invocationContext.clientId, + isAnyCheckpointPersisted = transientState.isAnyCheckpointPersisted + ) unfinishedFibers.countDown() scheduleEvent(Event.DoRemainingWork) @@ -401,6 +425,67 @@ internal class SingleThreadedStateMachineManager( } } + private fun killInMemoryDeadFlow(flow: Flow<*>): Boolean { + val id = flow.fiber.id + return flow.withFlowLock(VALID_KILL_FLOW_STATUSES) { + if (!transientState.isKilled) { + transientState = transientState.copy(isKilled = true) + logger.info("Killing dead flow $id known to this node.") + + val (flowForRetry, _) = createNewFlowForRetry(transientState) ?: return false + + updateCheckpointWhenKillingFlow( + id = id, + clientId = transientState.checkpoint.checkpointState.invocationContext.clientId, + isAnyCheckpointPersisted = transientState.isAnyCheckpointPersisted + ) + + unfinishedFibers.countDown() + + innerState.withLock { + if (stopping) { + return true + } + // Remove any sessions the old flow has. + for (sessionId in getFlowSessionIds(transientState.checkpoint)) { + sessionToFlow.remove(sessionId) + } + if (flowForRetry != null) { + addAndStartFlow(id, flowForRetry) + } + } + + true + } else { + logger.info("A repeated request to kill flow $id has been made, ignoring...") + false + } + } + } + + private fun updateCheckpointWhenKillingFlow( + id: StateMachineRunId, + clientId: String?, + isAnyCheckpointPersisted: Boolean, + exception: KilledFlowException = KilledFlowException(id) + ) { + // The checkpoint and soft locks are handled here as well as in a flow's transition. This means that we do not need to rely + // on the processing of the next event after setting the killed flag. This is to ensure a flow can be updated/removed from + // the database, even if it is stuck in a infinite loop or cannot be run (checkpoint cannot be deserialized from database). + if (isAnyCheckpointPersisted) { + database.transaction { + if (clientId != null) { + checkpointStorage.updateStatus(id, Checkpoint.FlowStatus.KILLED) + checkpointStorage.removeFlowException(id) + checkpointStorage.addFlowException(id, exception) + } else { + checkpointStorage.removeCheckpoint(id, mayHavePersistentResults = true) + } + serviceHub.vaultService.softLockRelease(id.uuid) + } + } + } + private fun killOutOfMemoryFlow(id: StateMachineRunId): Boolean { return database.transaction { val checkpoint = checkpointStorage.getCheckpoint(id) @@ -423,6 +508,25 @@ internal class SingleThreadedStateMachineManager( } } + override fun killFlowForcibly(flowId: StateMachineRunId): Boolean { + val flow = innerState.withLock { flows[flowId] } + flow?.withFlowLock(VALID_KILL_FLOW_STATUSES) { + logger.info("Forcibly killing flow $flowId, errors will not be propagated to the flow's sessions") + updateCheckpointWhenKillingFlow( + id = flowId, + clientId = transientState.checkpoint.checkpointState.invocationContext.clientId, + isAnyCheckpointPersisted = transientState.isAnyCheckpointPersisted + ) + removeFlow( + flowId, + FlowRemovalReason.ErrorFinish(listOf(FlowError(secureRandom.nextLong(), KilledFlowException(flowId)))), + transientState + ) + return true + } + return false + } + private fun markAllFlowsAsPaused() { return checkpointStorage.markAllPaused() } @@ -540,48 +644,8 @@ internal class SingleThreadedStateMachineManager( logger.error("Unable to find flow for flow $flowId. Something is very wrong. The flow will not retry.") return } - // We intentionally grab the checkpoint from storage rather than relying on the one referenced by currentState. This is so that - // we mirror exactly what happens when restarting the node. - // Ignore [isAnyCheckpointPersisted] as the checkpoint could be committed but the flag remains un-updated - val checkpointLoadingStatus = database.transaction { - val serializedCheckpoint = checkpointStorage.getCheckpoint(flowId) ?: return@transaction CheckpointLoadingStatus.NotFound - val checkpoint = serializedCheckpoint.let { - tryDeserializeCheckpoint(serializedCheckpoint, flowId)?.also { - if (it.status == Checkpoint.FlowStatus.HOSPITALIZED) { - checkpointStorage.removeFlowException(flowId) - checkpointStorage.updateStatus(flowId, Checkpoint.FlowStatus.RUNNABLE) - } - } ?: return@transaction CheckpointLoadingStatus.CouldNotDeserialize - } - - CheckpointLoadingStatus.Success(checkpoint) - } - - val (flow, numberOfCommitsFromCheckpoint) = when { - // Resurrect flow - checkpointLoadingStatus is CheckpointLoadingStatus.Success -> { - val numberOfCommitsFromCheckpoint = checkpointLoadingStatus.checkpoint.checkpointState.numberOfCommits - val flow = flowCreator.createFlowFromCheckpoint( - flowId, - checkpointLoadingStatus.checkpoint, - currentState.reloadCheckpointAfterSuspendCount, - currentState.lock, - firstRestore = false, - progressTracker = currentState.flowLogic.progressTracker - ) ?: return - flow to numberOfCommitsFromCheckpoint - } - checkpointLoadingStatus is CheckpointLoadingStatus.NotFound && currentState.isAnyCheckpointPersisted -> { - logger.error("Unable to find database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.") - return - } - checkpointLoadingStatus is CheckpointLoadingStatus.CouldNotDeserialize -> return - else -> { - // Just flow initiation message - null to -1 - } - } + val (flow, numberOfCommitsFromCheckpoint) = createNewFlowForRetry(currentState) ?: return innerState.withLock { if (stopping) { @@ -599,6 +663,53 @@ internal class SingleThreadedStateMachineManager( } } + private fun createNewFlowForRetry(currentState: StateMachineState): Pair?, Int>? { + val id = currentState.flowLogic.runId + // We intentionally grab the checkpoint from storage rather than relying on the one referenced by currentState. This is so that + // we mirror exactly what happens when restarting the node. + // Ignore [isAnyCheckpointPersisted] as the checkpoint could be committed but the flag remains un-updated + val checkpointLoadingStatus = database.transaction { + val serializedCheckpoint = checkpointStorage.getCheckpoint(id) ?: return@transaction CheckpointLoadingStatus.NotFound + + val checkpoint = serializedCheckpoint.let { + tryDeserializeCheckpoint(serializedCheckpoint, id)?.also { + if (it.status == Checkpoint.FlowStatus.HOSPITALIZED) { + checkpointStorage.removeFlowException(id) + checkpointStorage.updateStatus(id, Checkpoint.FlowStatus.RUNNABLE) + } + } ?: return@transaction CheckpointLoadingStatus.CouldNotDeserialize + } + + CheckpointLoadingStatus.Success(checkpoint) + } + + return when { + // Resurrect flow + checkpointLoadingStatus is CheckpointLoadingStatus.Success -> { + val numberOfCommitsFromCheckpoint = checkpointLoadingStatus.checkpoint.checkpointState.numberOfCommits + val flow = flowCreator.createFlowFromCheckpoint( + id, + checkpointLoadingStatus.checkpoint, + currentState.reloadCheckpointAfterSuspendCount, + currentState.lock, + firstRestore = false, + isKilled = currentState.isKilled, + progressTracker = currentState.flowLogic.progressTracker + ) ?: return null + flow to numberOfCommitsFromCheckpoint + } + checkpointLoadingStatus is CheckpointLoadingStatus.NotFound && currentState.isAnyCheckpointPersisted -> { + logger.error("Unable to find database checkpoint for flow $id. Something is very wrong. The flow will not retry.") + null + } + checkpointLoadingStatus is CheckpointLoadingStatus.CouldNotDeserialize -> return null + else -> { + // Just flow initiation message + null to -1 + } + } + } + /** * Extract all the [ExternalEvent] from this flows event queue and queue them (in the correct order) in the PausedFlow. * This differs from [extractAndScheduleEventsForRetry] which also extracts (and schedules) [Event.Pause]. This means that if there are diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index f4de47ced0..1a0fafaa73 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -142,6 +142,7 @@ internal interface StateMachineManagerInternal { fun retryFlowFromSafePoint(currentState: StateMachineState) fun scheduleFlowTimeout(flowId: StateMachineRunId) fun cancelFlowTimeout(flowId: StateMachineRunId) + fun killFlowForcibly(flowId: StateMachineRunId): Boolean } /** diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineShutdownLogger.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineShutdownLogger.kt new file mode 100644 index 0000000000..64934537a8 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineShutdownLogger.kt @@ -0,0 +1,54 @@ +package net.corda.node.services.statemachine + +import net.corda.core.utilities.contextLogger + +internal class StateMachineShutdownLogger(private val innerState: StateMachineInnerState) { + + private companion object { + val log = contextLogger() + } + + fun log() { + innerState.withLock { + val message = StringBuilder("Shutdown of the state machine is blocked.\n") + val deadFlowMessage = StringBuilder() + if (flows.isNotEmpty()) { + message.append("The following live flows have not shutdown:\n") + for ((id, flow) in flows) { + val state = flow.fiber.transientState + val line = " - $id with properties " + + "[Status: ${state.checkpoint.status}, " + + "IO request: ${state.checkpoint.flowIoRequest ?: "Unstarted"}, " + + "Suspended: ${!state.isFlowResumed}, " + + "Last checkpoint timestamp: ${state.checkpoint.timestamp}, " + + "Killed: ${state.isKilled}]\n" + if (!state.isDead) { + message.append(line) + } else { + deadFlowMessage.append(line) + } + } + } + if (pausedFlows.isNotEmpty()) { + message.append("The following paused flows have not shutdown:\n") + for ((id, flow) in pausedFlows) { + message.append( + " - $id with properties " + + "[Status: ${flow.checkpoint.status}, " + + "IO request: ${flow.checkpoint.flowIoRequest ?: "Unstarted"}, " + + "Last checkpoint timestamp: ${flow.checkpoint.timestamp}, " + + "Resumable: ${flow.resumable}, " + + "Hospitalized: ${flow.hospitalized}]\n" + ) + } + } + if (deadFlowMessage.isNotEmpty()) { + deadFlowMessage.insert(0, "The following dead (crashed) flows have not shutdown:\n") + message.append(deadFlowMessage) + } + message.append("Manual intervention maybe be required for state machine shutdown due to these flows.\n") + message.append("Continuing state machine shutdown loop...") + log.info(message.toString()) + } + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt index df17288549..a2bc214675 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt @@ -47,6 +47,8 @@ import java.util.concurrent.Semaphore * work. * @param isKilled true if the flow has been marked as killed. This is used to cause a flow to move to a killed flow transition no matter * what event it is set to process next. + * @param isDead true if the flow has been marked as dead. This happens when a flow experiences an unexpected error and escapes its event loop + * which prevents it from processing events. * @param senderUUID the identifier of the sending state machine or null if this flow is resumed from a checkpoint so that it does not participate in de-duplication high-water-marking. * @param reloadCheckpointAfterSuspendCount The number of times a flow has been reloaded (not retried). This is [null] when * [NodeConfiguration.reloadCheckpointAfterSuspendCount] is not enabled. @@ -68,6 +70,7 @@ data class StateMachineState( val isStartIdempotent: Boolean, val isRemoved: Boolean, val isKilled: Boolean, + val isDead: Boolean, val senderUUID: String?, val reloadCheckpointAfterSuspendCount: Int?, var numberOfCommits: Int, @@ -212,17 +215,28 @@ data class Checkpoint( /** * Deserializes the serialized fields contained in [Checkpoint.Serialized]. * - * @return A [Checkpoint] with all its fields filled in from [Checkpoint.Serialized] + * Depending on the [FlowStatus] of the [Checkpoint.Serialized], the deserialized [Checkpoint] may or may not have its [flowState] + * properly deserialized. This is to optimise the process's memory footprint by not holding the checkpoints of flows that are not + * running in-memory. + * + * The [flowState] will not be deserialized when the [FlowStatus] is: + * + * - [FlowStatus.PAUSED] + * - [FlowStatus.COMPLETED] + * - [FlowStatus.FAILED] + * + * Any other status returns a [FlowState.Unstarted] or [FlowState.Started] depending on the content of [serializedFlowState]. + * + * @param checkpointSerializationContext The [CheckpointSerializationContext] to deserialize the checkpoint's serialized content with. + * @param alwaysDeserializeFlowState A flag to specify if [flowState] should be deserialized, disregarding the [FlowStatus] of the + * checkpoint and ignoring the memory optimisation. + * + * @return A [Checkpoint] with all its fields filled in from [Checkpoint.Serialized]. */ - fun deserialize(checkpointSerializationContext: CheckpointSerializationContext): Checkpoint { - val flowState = when(status) { - FlowStatus.PAUSED -> FlowState.Paused - FlowStatus.COMPLETED, FlowStatus.FAILED -> FlowState.Finished - else -> serializedFlowState!!.checkpointDeserialize(checkpointSerializationContext) - } + fun deserialize(checkpointSerializationContext: CheckpointSerializationContext, alwaysDeserializeFlowState: Boolean = false): Checkpoint { return Checkpoint( checkpointState = serializedCheckpointState.checkpointDeserialize(checkpointSerializationContext), - flowState = flowState, + flowState = getFlowState(checkpointSerializationContext, alwaysDeserializeFlowState), errorState = errorState, result = result?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT), status = status, @@ -231,6 +245,23 @@ data class Checkpoint( compatible = compatible ) } + + private fun getFlowState( + checkpointSerializationContext: CheckpointSerializationContext, + alwaysDeserializeFlowState: Boolean + ): FlowState { + return when { + alwaysDeserializeFlowState -> deserializeFlowState(checkpointSerializationContext) + status == FlowStatus.PAUSED -> FlowState.Paused + status == FlowStatus.COMPLETED -> FlowState.Finished + status == FlowStatus.FAILED -> FlowState.Finished + else -> deserializeFlowState(checkpointSerializationContext) + } + } + + private fun deserializeFlowState(checkpointSerializationContext: CheckpointSerializationContext): FlowState { + return serializedFlowState!!.checkpointDeserialize(checkpointSerializationContext) + } } } @@ -278,7 +309,7 @@ sealed class SessionState { * @property rejectionError if non-null the initiation failed. */ data class Initiating( - val bufferedMessages: List>, + val bufferedMessages: ArrayList>, val rejectionError: FlowError?, override val deduplicationSeed: String ) : SessionState() @@ -295,7 +326,7 @@ sealed class SessionState { data class Initiated( val peerParty: Party, val peerFlowInfo: FlowInfo, - val receivedMessages: List, + val receivedMessages: ArrayList, val otherSideErrored: Boolean, val peerSinkSessionId: SessionId, override val deduplicationSeed: String diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt index 5719139095..7afa6e09bb 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt @@ -87,7 +87,7 @@ class DeliverSessionMessageTransition( val initiatedSession = SessionState.Initiated( peerParty = event.sender, peerFlowInfo = message.initiatedFlowInfo, - receivedMessages = emptyList(), + receivedMessages = arrayListOf(), peerSinkSessionId = message.initiatedSessionId, deduplicationSeed = sessionState.deduplicationSeed, otherSideErrored = false diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt index 39044de821..a53b277757 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt @@ -63,7 +63,12 @@ class ErrorFlowTransition( status = Checkpoint.FlowStatus.FAILED, flowState = FlowState.Finished, checkpointState = startingState.checkpoint.checkpointState.copy( - numberOfCommits = startingState.checkpoint.checkpointState.numberOfCommits + 1 + numberOfCommits = startingState.checkpoint.checkpointState.numberOfCommits + 1, + invocationContext = if (startingState.checkpoint.checkpointState.invocationContext.arguments!!.isNotEmpty()) { + startingState.checkpoint.checkpointState.invocationContext.copy(arguments = emptyList()) + } else { + startingState.checkpoint.checkpointState.invocationContext + } ) ) currentState = currentState.copy( @@ -121,9 +126,9 @@ class ErrorFlowTransition( if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) { // *prepend* the error messages in order to error the other sessions ASAP. The other messages will // be delivered all the same, they just won't trigger flow resumption because of dirtiness. - val errorMessagesWithDeduplication = errorMessages.map { + val errorMessagesWithDeduplication: ArrayList> = errorMessages.map { DeduplicationId.createForError(it.errorId, sourceSessionId) to it - } + }.toArrayList() sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages) } else { sessionState diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt index 41e6c2f0a6..fb1204a6b2 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt @@ -7,12 +7,14 @@ import net.corda.node.services.statemachine.Checkpoint import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.services.statemachine.ErrorSessionMessage import net.corda.node.services.statemachine.Event +import net.corda.node.services.statemachine.ExistingSessionMessagePayload import net.corda.node.services.statemachine.FlowError import net.corda.node.services.statemachine.FlowRemovalReason import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.SessionId import net.corda.node.services.statemachine.SessionState import net.corda.node.services.statemachine.StateMachineState +import java.util.ArrayList class KilledFlowTransition( override val context: TransitionContext, @@ -101,9 +103,9 @@ class KilledFlowTransition( if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) { // *prepend* the error messages in order to error the other sessions ASAP. The other messages will // be delivered all the same, they just won't trigger flow resumption because of dirtiness. - val errorMessagesWithDeduplication = errorMessages.map { + val errorMessagesWithDeduplication: ArrayList> = errorMessages.map { DeduplicationId.createForError(it.errorId, sourceSessionId) to it - } + }.toArrayList() sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages) } else { sessionState diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt index a6a3495466..2f2b8a167b 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt @@ -250,7 +250,7 @@ class StartedFlowTransition( if (messages.isEmpty()) { someNotFound = true } else { - newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList()) + newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toArrayList()) // at this point, we've already checked for errors and session ends, so it's guaranteed that the first message will be a data message. resultMessages[sessionId] = if (messages[0] is EndSessionMessage) { throw UnexpectedFlowEndException("Received session end message instead of a data session message. Mismatched send and receive?") @@ -285,7 +285,7 @@ class StartedFlowTransition( } val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, sessionState.additionalEntropy, null) val newSessionState = SessionState.Initiating( - bufferedMessages = emptyList(), + bufferedMessages = arrayListOf(), rejectionError = null, deduplicationSeed = sessionState.deduplicationSeed ) @@ -324,7 +324,7 @@ class StartedFlowTransition( val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, sessionState) val initialMessage = createInitialSessionMessage(uninitiatedSessionState.initiatingSubFlow, sourceSessionId, uninitiatedSessionState.additionalEntropy, message) newSessions[sourceSessionId] = SessionState.Initiating( - bufferedMessages = emptyList(), + bufferedMessages = arrayListOf(), rejectionError = null, deduplicationSeed = uninitiatedSessionState.deduplicationSeed ) @@ -375,7 +375,10 @@ class StartedFlowTransition( if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is ErrorSessionMessage) { val errorMessage = sessionState.receivedMessages.first() as ErrorSessionMessage val exception = convertErrorMessageToException(errorMessage, sessionState.peerParty) - val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages.subList(1, sessionState.receivedMessages.size), otherSideErrored = true) + val newSessionState = sessionState.copy( + receivedMessages = sessionState.receivedMessages.subList(1, sessionState.receivedMessages.size).toArrayList(), + otherSideErrored = true + ) val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState) newState = startingState.copy(checkpoint = newCheckpoint) listOf(exception) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt index 5a1cbb6797..f9e29ca5dc 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/Transition.kt @@ -24,6 +24,37 @@ interface Transition { val continuation = build(builder) return TransitionResult(builder.currentState, builder.actions, continuation) } + + /** + * Add [element] to the [ArrayList] and return the list. + * + * Copy of [List.plus] that returns an [ArrayList] instead. + */ + operator fun ArrayList.plus(element: T) : ArrayList { + val result = ArrayList(size + 1) + result.addAll(this) + result.add(element) + return result + } + + /** + * Add [elements] to the [ArrayList] and return the list. + * + * Copy of [List.plus] that returns an [ArrayList] instead. + */ + operator fun ArrayList.plus(elements: Collection) : ArrayList { + val result = ArrayList(this.size + elements.size) + result.addAll(this) + result.addAll(elements) + return result + } + + /** + * Convert the [List] into an [ArrayList]. + */ + fun List.toArrayList() : ArrayList { + return ArrayList(this) + } } class TransitionContext( diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt index b250b42232..70ea271505 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt @@ -6,6 +6,7 @@ import net.corda.node.services.statemachine.ConfirmSessionMessage import net.corda.node.services.statemachine.DataSessionMessage import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.services.statemachine.ExistingSessionMessage +import net.corda.node.services.statemachine.ExistingSessionMessagePayload import net.corda.node.services.statemachine.FlowStart import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.SenderDeduplicationId @@ -50,9 +51,9 @@ class UnstartedFlowTransition( appName = initiatingMessage.appName ), receivedMessages = if (initiatingMessage.firstPayload == null) { - emptyList() + arrayListOf() } else { - listOf(DataSessionMessage(initiatingMessage.firstPayload)) + arrayListOf(DataSessionMessage(initiatingMessage.firstPayload)) }, deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.toLong}-${initiatingMessage.initiationEntropy}", otherSideErrored = false diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt index d514335c92..5b016d5734 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt @@ -1,6 +1,5 @@ package net.corda.node.services.transactions -import net.corda.core.internal.BasicVerifier import net.corda.core.internal.Verifier import net.corda.core.serialization.ConstructorForDeserialization import net.corda.core.serialization.CordaSerializable @@ -9,6 +8,7 @@ import net.corda.core.serialization.CordaSerializationTransformEnumDefaults import net.corda.core.serialization.CordaSerializationTransformRename import net.corda.core.serialization.CordaSerializationTransformRenames import net.corda.core.serialization.DeprecatedConstructorForDeserialization +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.transactions.LedgerTransaction import net.corda.djvm.SandboxConfiguration @@ -80,13 +80,13 @@ class DeterministicVerifierFactoryService( override fun apply(ledgerTransaction: LedgerTransaction): LedgerTransaction { // Specialise the LedgerTransaction here so that // contracts are verified inside the DJVM! - return ledgerTransaction.specialise(::specialise) + return ledgerTransaction.specialise(::createDeterministicVerifier) } - private fun specialise(ltx: LedgerTransaction, classLoader: ClassLoader): Verifier { - return (classLoader as? URLClassLoader)?.run { + private fun createDeterministicVerifier(ltx: LedgerTransaction, serializationContext: SerializationContext): Verifier { + return (serializationContext.deserializationClassLoader as? URLClassLoader)?.let { classLoader -> DeterministicVerifier(ltx, classLoader, createSandbox(classLoader.urLs)) - } ?: BasicVerifier(ltx, classLoader) + } ?: throw IllegalStateException("Unsupported deserialization classloader type") } private fun createSandbox(userSource: Array): SandboxConfiguration { diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt b/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt index 66ec2007fa..aa69d50db3 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt @@ -25,6 +25,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.serialize import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug +import net.corda.node.services.vault.toStateRef import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX @@ -157,13 +158,7 @@ class PersistentUniquenessProvider(val clock: Clock, val database: CordaPersiste toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) }, fromPersistentEntity = { //TODO null check will become obsolete after making DB/JPA columns not nullable - val txId = it.id.txId - val index = it.id.index - Pair( - StateRef(txhash = SecureHash.create(txId), index = index), - SecureHash.create(it.consumingTxHash) - ) - + Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash)) }, toPersistentEntity = { (txHash, index): StateRef, id: SecureHash -> CommittedState( diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index e7846b2821..9c90c36dd3 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -3,28 +3,65 @@ package net.corda.node.services.vault import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand import net.corda.core.CordaRuntimeException -import net.corda.core.contracts.* +import net.corda.core.contracts.Amount +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.FungibleAsset +import net.corda.core.contracts.FungibleState +import net.corda.core.contracts.Issued +import net.corda.core.contracts.OwnableState +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionState import net.corda.core.crypto.SecureHash import net.corda.core.crypto.containsAny import net.corda.core.flows.HospitalizeFlowException -import net.corda.core.internal.* +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.TransactionDeserialisationException +import net.corda.core.internal.VisibleForTesting +import net.corda.core.internal.bufferUntilSubscribed +import net.corda.core.internal.tee +import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.DataFeed -import net.corda.core.node.ServicesForResolution import net.corda.core.node.StatesToRecord -import net.corda.core.node.services.* -import net.corda.core.node.services.Vault.ConstraintInfo.Companion.constraintInfo -import net.corda.core.node.services.vault.* +import net.corda.core.node.services.KeyManagementService +import net.corda.core.node.services.StatesNotAvailableException +import net.corda.core.node.services.Vault +import net.corda.core.node.services.VaultQueryException +import net.corda.core.node.services.VaultService +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.DEFAULT_PAGE_NUM +import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE +import net.corda.core.node.services.vault.PageSpecification +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.node.services.vault.Sort +import net.corda.core.node.services.vault.SortAttribute +import net.corda.core.node.services.vault.builder import net.corda.core.observable.internal.OnResilientSubscribe import net.corda.core.schemas.PersistentStateRef import net.corda.core.serialization.SingletonSerializeAsToken -import net.corda.core.transactions.* -import net.corda.core.utilities.* +import net.corda.core.transactions.ContractUpgradeWireTransaction +import net.corda.core.transactions.CoreTransaction +import net.corda.core.transactions.FullTransaction +import net.corda.core.transactions.LedgerTransaction +import net.corda.core.transactions.NotaryChangeWireTransaction +import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.NonEmptySet +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.debug +import net.corda.core.utilities.toNonEmptySet +import net.corda.core.utilities.trace +import net.corda.node.internal.NodeServicesForResolution import net.corda.node.services.api.SchemaService import net.corda.node.services.api.VaultServiceInternal import net.corda.node.services.schema.PersistentStateService import net.corda.node.services.statemachine.FlowStateMachineImpl -import net.corda.nodeapi.internal.persistence.* +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit +import net.corda.nodeapi.internal.persistence.contextTransactionOrNull +import net.corda.nodeapi.internal.persistence.currentDBSession +import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction import org.hibernate.Session +import org.hibernate.query.Query import rx.Observable import rx.exceptions.OnErrorNotImplementedException import rx.subjects.PublishSubject @@ -32,9 +69,11 @@ import java.security.PublicKey import java.sql.SQLException import java.time.Clock import java.time.Instant -import java.util.* +import java.util.Arrays +import java.util.UUID import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.CopyOnWriteArraySet +import java.util.stream.Stream import javax.persistence.PersistenceException import javax.persistence.Tuple import javax.persistence.criteria.CriteriaBuilder @@ -54,9 +93,9 @@ import javax.persistence.criteria.Root class NodeVaultService( private val clock: Clock, private val keyManagementService: KeyManagementService, - private val servicesForResolution: ServicesForResolution, + private val servicesForResolution: NodeServicesForResolution, private val database: CordaPersistence, - private val schemaService: SchemaService, + schemaService: SchemaService, private val appClassloader: ClassLoader ) : SingletonSerializeAsToken(), VaultServiceInternal { companion object { @@ -196,7 +235,7 @@ class NodeVaultService( if (lockId != null) { lockId = null lockUpdateTime = clock.instant() - log.trace("Releasing soft lock on consumed state: $stateRef") + log.trace { "Releasing soft lock on consumed state: $stateRef" } } session.save(state) } @@ -227,7 +266,7 @@ class NodeVaultService( } // we are not inside a flow, we are most likely inside a CordaService; // we will expose, by default, subscribing of -non unsubscribing- rx.Observers to rawUpdates. - return _rawUpdatesPublisher.resilientOnError() + _rawUpdatesPublisher.resilientOnError() } override val updates: Observable> @@ -271,7 +310,7 @@ class NodeVaultService( // This will cause a failure as we can't deserialize such states in the context of the `appClassloader`. // For now we ignore these states. // In the future we will use the AttachmentsClassloader to correctly deserialize and asses the relevancy. - log.debug { "Could not deserialize state $idx from transaction $txId. Cause: $e" } + log.warn("Could not deserialize state $idx from transaction $txId. Cause: $e") null } }.toMap() @@ -639,7 +678,23 @@ class NodeVaultService( @Throws(VaultQueryException::class) override fun _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class): Vault.Page { try { - return _queryBy(criteria, paging, sorting, contractStateType, false) + // We decrement by one if the client requests MAX_VALUE, assuming they can not notice this because they don't have enough memory + // to request MAX_VALUE states at once. + val validPaging = if (paging.pageSize == Integer.MAX_VALUE) { + paging.copy(pageSize = Integer.MAX_VALUE - 1) + } else { + checkVaultQuery(paging.pageSize >= 1) { "Page specification: invalid page size ${paging.pageSize} [minimum is 1]" } + paging + } + if (!validPaging.isDefault) { + checkVaultQuery(validPaging.pageNumber >= DEFAULT_PAGE_NUM) { + "Page specification: invalid page number ${validPaging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]" + } + } + log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $validPaging, sorting: $sorting" } + return database.transaction { + queryBy(criteria, validPaging, sorting, contractStateType) + } } catch (e: VaultQueryException) { throw e } catch (e: Exception) { @@ -647,100 +702,94 @@ class NodeVaultService( } } - @Throws(VaultQueryException::class) - private fun _queryBy(criteria: QueryCriteria, paging_: PageSpecification, sorting: Sort, contractStateType: Class, skipPagingChecks: Boolean): Vault.Page { - // We decrement by one if the client requests MAX_PAGE_SIZE, assuming they can not notice this because they don't have enough memory - // to request `MAX_PAGE_SIZE` states at once. - val paging = if (paging_.pageSize == Integer.MAX_VALUE) { - paging_.copy(pageSize = Integer.MAX_VALUE - 1) - } else { - paging_ + private fun queryBy(criteria: QueryCriteria, + paging: PageSpecification, + sorting: Sort, + contractStateType: Class): Vault.Page { + val (query, stateTypes) = createQuery(criteria, contractStateType, sorting) + query.setResultWindow(paging) + + val statesMetadata: MutableList = mutableListOf() + val otherResults: MutableList = mutableListOf() + + query.resultStream(paging).use { results -> + results.forEach { result -> + val result0 = result[0] + if (result0 is VaultSchemaV1.VaultStates) { + statesMetadata.add(result0.toStateMetadata()) + } else { + log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" } + otherResults.addAll(result.toArray().asList()) + } + } } - log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $paging, sorting: $sorting" } - return database.transaction { - // calculate total results where a page specification has been defined - var totalStates = -1L - if (!skipPagingChecks && !paging.isDefault) { - val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) - val results = _queryBy(criteria.and(countCriteria), PageSpecification(), Sort(emptyList()), contractStateType, true) // only skip pagination checks for total results count query - totalStates = results.otherResults.last() as Long - } - val session = getSession() + val states: List> = servicesForResolution.loadStates( + statesMetadata.mapTo(LinkedHashSet()) { it.ref }, + ArrayList() + ) - val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java) - val queryRootVaultStates = criteriaQuery.from(VaultSchemaV1.VaultStates::class.java) + val totalStatesAvailable = when { + paging.isDefault -> -1L + // If the first page isn't full then we know that's all the states that are available + paging.pageNumber == DEFAULT_PAGE_NUM && states.size < paging.pageSize -> states.size.toLong() + else -> queryTotalStateCount(criteria, contractStateType) + } - // TODO: revisit (use single instance of parser for all queries) - val criteriaParser = HibernateQueryCriteriaParser(contractStateType, contractStateTypeMappings, criteriaBuilder, criteriaQuery, queryRootVaultStates) - - // parse criteria and build where predicates - criteriaParser.parse(criteria, sorting) - - // prepare query for execution - val query = session.createQuery(criteriaQuery) - - // pagination checks - if (!skipPagingChecks && !paging.isDefault) { - // pagination - if (paging.pageNumber < DEFAULT_PAGE_NUM) throw VaultQueryException("Page specification: invalid page number ${paging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]") - if (paging.pageSize < 1) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [minimum is 1]") - if (paging.pageSize > MAX_PAGE_SIZE) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [maximum is $MAX_PAGE_SIZE]") - } - - // For both SQLServer and PostgresSQL, firstResult must be >= 0. So we set a floor at 0. - // TODO: This is a catch-all solution. But why is the default pageNumber set to be -1 in the first place? - // Even if we set the default pageNumber to be 1 instead, that may not cover the non-default cases. - // So the floor may be necessary anyway. - query.firstResult = maxOf(0, (paging.pageNumber - 1) * paging.pageSize) - val pageSize = paging.pageSize + 1 - query.maxResults = if (pageSize > 0) pageSize else Integer.MAX_VALUE // detection too many results, protected against overflow - - // execution - val results = query.resultList + return Vault.Page(states, statesMetadata, totalStatesAvailable, stateTypes, otherResults) + } + private fun Query.resultStream(paging: PageSpecification): Stream { + return if (paging.isDefault) { + val allResults = resultList // final pagination check (fail-fast on too many results when no pagination specified) - if (!skipPagingChecks && paging.isDefault && results.size > DEFAULT_PAGE_SIZE) { - throw VaultQueryException("There are ${results.size} results, which exceeds the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. In order to retrieve these results, provide a `PageSpecification(pageNumber, pageSize)` to the method invoked.") + checkVaultQuery(allResults.size != paging.pageSize + 1) { + "There are more results than the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. " + + "In order to retrieve these results, provide a PageSpecification to the method invoked." } - val statesAndRefs: MutableList> = mutableListOf() - val statesMeta: MutableList = mutableListOf() - val otherResults: MutableList = mutableListOf() - val stateRefs = mutableSetOf() - - results.asSequence() - .forEachIndexed { index, result -> - if (result[0] is VaultSchemaV1.VaultStates) { - if (!paging.isDefault && index == paging.pageSize) // skip last result if paged - return@forEachIndexed - val vaultState = result[0] as VaultSchemaV1.VaultStates - val stateRef = StateRef(SecureHash.create(vaultState.stateRef!!.txId), vaultState.stateRef!!.index) - stateRefs.add(stateRef) - statesMeta.add(Vault.StateMetadata(stateRef, - vaultState.contractStateClassName, - vaultState.recordedTime, - vaultState.consumedTime, - vaultState.stateStatus, - vaultState.notary, - vaultState.lockId, - vaultState.lockUpdateTime, - vaultState.relevancyStatus, - constraintInfo(vaultState.constraintType, vaultState.constraintData) - )) - } else { - // TODO: improve typing of returned other results - log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" } - otherResults.addAll(result.toArray().asList()) - } - } - if (stateRefs.isNotEmpty()) - statesAndRefs.addAll(uncheckedCast(servicesForResolution.loadStates(stateRefs))) - - Vault.Page(states = statesAndRefs, statesMetadata = statesMeta, stateTypes = criteriaParser.stateTypes, totalStatesAvailable = totalStates, otherResults = otherResults) + allResults.stream() + } else { + stream() } } + private fun Query<*>.setResultWindow(paging: PageSpecification) { + if (paging.isDefault) { + // For both SQLServer and PostgresSQL, firstResult must be >= 0. + firstResult = 0 + // Peek ahead and see if there are more results in case pagination should be done + maxResults = paging.pageSize + 1 + } else { + firstResult = (paging.pageNumber - 1) * paging.pageSize + maxResults = paging.pageSize + } + } + + private fun queryTotalStateCount(baseCriteria: QueryCriteria, contractStateType: Class): Long { + val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } + val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) + val criteria = baseCriteria.and(countCriteria) + val (query) = createQuery(criteria, contractStateType, null) + val results = query.resultList + return results.last().toArray().last() as Long + } + + private fun createQuery(criteria: QueryCriteria, + contractStateType: Class, + sorting: Sort?): Pair, Vault.StateStatus> { + val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java) + val criteriaParser = HibernateQueryCriteriaParser( + contractStateType, + contractStateTypeMappings, + criteriaBuilder, + criteriaQuery, + criteriaQuery.from(VaultSchemaV1.VaultStates::class.java) + ) + criteriaParser.parse(criteria, sorting) + val query = getSession().createQuery(criteriaQuery) + return Pair(query, criteriaParser.stateTypes) + } + /** * Returns a [DataFeed] containing the results of the provided query, along with the associated observable, containing any subsequent updates. * @@ -775,6 +824,12 @@ class NodeVaultService( } } + private inline fun checkVaultQuery(value: Boolean, lazyMessage: () -> Any) { + if (!value) { + throw VaultQueryException(lazyMessage().toString()) + } + } + private fun filterContractStates(update: Vault.Update, contractStateType: Class) = update.copy(consumed = filterByContractState(contractStateType, update.consumed), produced = filterByContractState(contractStateType, update.produced)) @@ -802,6 +857,7 @@ class NodeVaultService( } private fun getSession() = database.currentOrNew().session + /** * Derive list from existing vault states and then incrementally update using vault observables */ diff --git a/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt b/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt index 06844d40d0..09c71fe1f7 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt @@ -2,7 +2,9 @@ package net.corda.node.services.vault import net.corda.core.contracts.ContractState import net.corda.core.contracts.MAX_ISSUER_REF_SIZE +import net.corda.core.contracts.StateRef import net.corda.core.contracts.UniqueIdentifier +import net.corda.core.crypto.SecureHash import net.corda.core.crypto.toStringShort import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party @@ -192,3 +194,19 @@ object VaultSchemaV1 : MappedSchema( ) : IndirectStatePersistable } +fun PersistentStateRef.toStateRef(): StateRef = StateRef(SecureHash.create(txId), index) + +fun VaultSchemaV1.VaultStates.toStateMetadata(): Vault.StateMetadata { + return Vault.StateMetadata( + stateRef!!.toStateRef(), + contractStateClassName, + recordedTime, + consumedTime, + stateStatus, + notary, + lockId, + lockUpdateTime, + relevancyStatus, + Vault.ConstraintInfo.constraintInfo(constraintType, constraintData) + ) +} diff --git a/node/src/main/kotlin/net/corda/node/utilities/NodeNamedCache.kt b/node/src/main/kotlin/net/corda/node/utilities/NodeNamedCache.kt index 3bf2bb08cf..9c2e2eaf93 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/NodeNamedCache.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/NodeNamedCache.kt @@ -33,7 +33,8 @@ open class DefaultNamedCacheFactory protected constructor(private val metricRegi override fun bindWithMetrics(metricRegistry: MetricRegistry): BindableNamedCacheFactory = DefaultNamedCacheFactory(metricRegistry, this.nodeConfiguration) override fun bindWithConfig(nodeConfiguration: NodeConfiguration): BindableNamedCacheFactory = DefaultNamedCacheFactory(this.metricRegistry, nodeConfiguration) - open protected fun configuredForNamed(caffeine: Caffeine, name: String): Caffeine { + @Suppress("ComplexMethod") + protected open fun configuredForNamed(caffeine: Caffeine, name: String): Caffeine { return with(nodeConfiguration!!) { when { name.startsWith("RPCSecurityManagerShiroCache_") -> with(security?.authService?.options?.cache!!) { caffeine.maximumSize(maxEntries).expireAfterWrite(expireAfterSecs, TimeUnit.SECONDS) } @@ -43,7 +44,6 @@ open class DefaultNamedCacheFactory protected constructor(private val metricRegi name == "HibernateConfiguration_sessionFactories" -> caffeine.maximumSize(database.mappedSchemaCacheSize) name == "DBTransactionStorage_transactions" -> caffeine.maximumWeight(transactionCacheSizeBytes) name == "NodeAttachmentService_attachmentContent" -> caffeine.maximumWeight(attachmentContentCacheSizeBytes) - name == "NodeAttachmentService_attachmentPresence" -> caffeine.maximumSize(attachmentCacheBound) name == "NodeAttachmentService_contractAttachmentVersions" -> caffeine.maximumSize(defaultCacheSize) name == "PersistentIdentityService_keyToPartyAndCert" -> caffeine.maximumSize(defaultCacheSize) name == "PersistentIdentityService_nameToParty" -> caffeine.maximumSize(defaultCacheSize) @@ -85,7 +85,7 @@ open class DefaultNamedCacheFactory protected constructor(private val metricRegi return configuredForNamed(caffeine, name).build(loader) } - open protected val defaultCacheSize = 1024L + protected open val defaultCacheSize = 1024L private val defaultAttachmentsClassLoaderCacheSize = defaultCacheSize / CACHE_SIZE_DENOMINATOR } -private const val CACHE_SIZE_DENOMINATOR = 4L \ No newline at end of file +private const val CACHE_SIZE_DENOMINATOR = 4L diff --git a/node/src/main/kotlin/net/corda/notary/experimental/bftsmart/BFTSmartNotaryService.kt b/node/src/main/kotlin/net/corda/notary/experimental/bftsmart/BFTSmartNotaryService.kt index a570ccd7b5..76094c2a1d 100644 --- a/node/src/main/kotlin/net/corda/notary/experimental/bftsmart/BFTSmartNotaryService.kt +++ b/node/src/main/kotlin/net/corda/notary/experimental/bftsmart/BFTSmartNotaryService.kt @@ -21,6 +21,7 @@ import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.transactions.PersistentUniquenessProvider +import net.corda.node.services.vault.toStateRef import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import java.security.PublicKey @@ -41,6 +42,8 @@ class BFTSmartNotaryService( ) : NotaryService() { companion object { private val log = contextLogger() + + @Suppress("unused") // Used by NotaryLoader via reflection @JvmStatic val serializationFilter get() = { clazz: Class<*> -> @@ -147,12 +150,7 @@ class BFTSmartNotaryService( toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) }, fromPersistentEntity = { //TODO null check will become obsolete after making DB/JPA columns not nullable - val txId = it.id.txId - val index = it.id.index - Pair( - StateRef(txhash = SecureHash.create(txId), index = index), - SecureHash.create(it.consumingTxHash) - ) + Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash)) }, toPersistentEntity = { (txHash, index): StateRef, id: SecureHash -> CommittedState( diff --git a/node/src/main/kotlin/net/corda/notary/jpa/JPAUniquenessProvider.kt b/node/src/main/kotlin/net/corda/notary/jpa/JPAUniquenessProvider.kt index d38a3f35b7..b678478da6 100644 --- a/node/src/main/kotlin/net/corda/notary/jpa/JPAUniquenessProvider.kt +++ b/node/src/main/kotlin/net/corda/notary/jpa/JPAUniquenessProvider.kt @@ -24,6 +24,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.serialize import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug +import net.corda.node.services.vault.toStateRef import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.notary.common.InternalResult @@ -142,10 +143,6 @@ class JPAUniquenessProvider( fun encodeStateRef(s: StateRef): PersistentStateRef { return PersistentStateRef(s.txhash.toString(), s.index) } - - fun decodeStateRef(s: PersistentStateRef): StateRef { - return StateRef(txhash = SecureHash.create(s.txId), index = s.index) - } } /** @@ -215,15 +212,15 @@ class JPAUniquenessProvider( committedStates.addAll(existing) } - return committedStates.map { - val stateRef = StateRef(txhash = SecureHash.create(it.id.txId), index = it.id.index) + return committedStates.associate { + val stateRef = it.id.toStateRef() val consumingTxId = SecureHash.create(it.consumingTxHash) if (stateRef in references) { stateRef to StateConsumptionDetails(consumingTxId.reHash(), type = StateConsumptionDetails.ConsumedStateType.REFERENCE_INPUT_STATE) } else { stateRef to StateConsumptionDetails(consumingTxId.reHash()) } - }.toMap() + } } private fun withRetry(block: () -> T): T { diff --git a/node/src/test/kotlin/net/corda/node/internal/CordaServiceTest.kt b/node/src/test/kotlin/net/corda/node/internal/CordaServiceTest.kt index 1d06c470c8..ee1d4bf0a6 100644 --- a/node/src/test/kotlin/net/corda/node/internal/CordaServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/internal/CordaServiceTest.kt @@ -123,7 +123,10 @@ class CordaServiceTest { val identityService = makeTestIdentityService(dummyNotary.identity) Assertions.assertThatThrownBy { MockServices(cordappPackages, dummyNotary, identityService, dummyCashIssuer.keyPair, bankOfCorda.keyPair) } - .isInstanceOf(ClassNotFoundException::class.java).hasMessage("Could not create jar file as the given package is not found on the classpath: com.r3.corda.sdk.tokens.money") + .isInstanceOf(ClassNotFoundException::class.java) + .hasMessageStartingWith("Could not create jar file as ") + .hasMessageContaining("com.r3.corda.sdk.tokens.money") + .hasMessageEndingWith(" not found on the classpath") } @StartableByService diff --git a/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt b/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt new file mode 100644 index 0000000000..1837a8fb49 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt @@ -0,0 +1,69 @@ +package net.corda.node.internal + +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.utilities.ByteSequence +import net.corda.node.internal.classloading.scanForCustomSerializationScheme +import org.junit.Test +import org.mockito.Mockito +import kotlin.test.assertFailsWith + +class CustomSerializationSchemeScanningTest { + + class NonSerializationScheme + + open class DummySerializationScheme : CustomSerializationScheme { + override fun getSchemeId(): Int { + return 7 + } + + override fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T { + throw DummySerializationSchemeException("We should never get here.") + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + throw DummySerializationSchemeException("Tried to serialize with DummySerializationScheme") + } + } + + class DummySerializationSchemeException(override val message: String) : RuntimeException(message) + + class DummySerializationSchemeWithoutNoArgConstructor(val myArgument: String) : DummySerializationScheme() + + @Test(timeout = 300_000) + fun `Can scan for custom serialization scheme and build a serialization scheme`() { + val scheme = scanForCustomSerializationScheme(DummySerializationScheme::class.java.name, this::class.java.classLoader) + val mockContext = Mockito.mock(SerializationContext::class.java) + assertFailsWith("Tried to serialize with DummySerializationScheme") { + scheme.serialize(Any::class.java, mockContext) + } + } + + @Test(timeout = 300_000) + fun `verification fails with a helpful error if the class is not found in the classloader`() { + val missingClassName = "org.testing.DoesNotExist" + assertFailsWith("$missingClassName was declared as a custom serialization scheme but could not " + + "be found.") { + scanForCustomSerializationScheme(missingClassName, this::class.java.classLoader) + } + } + + @Test(timeout = 300_000) + fun `verification fails with a helpful error if the class is not a custom serialization scheme`() { + val schemeName = NonSerializationScheme::class.java.name + assertFailsWith("$schemeName was declared as a custom serialization scheme but does not " + + "implement CustomSerializationScheme.") { + scanForCustomSerializationScheme(schemeName, this::class.java.classLoader) + } + } + + @Test(timeout = 300_000) + fun `verification fails with a helpful error if the class does not have a no arg constructor`() { + val schemeName = DummySerializationSchemeWithoutNoArgConstructor::class.java.name + assertFailsWith("$schemeName was declared as a custom serialization scheme but does not " + + "have a no argument constructor.") { + scanForCustomSerializationScheme(schemeName, this::class.java.classLoader) + } + } +} diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 0c49ee44ac..28a5f3b973 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -124,7 +124,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { bobNode.internals.disableDBCloseOnStop() bobNode.database.transaction { - VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, cashIssuer) + VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, cashIssuer, atMostThisManyStates = 10) } val alicesFakePaper = aliceNode.database.transaction { @@ -233,7 +233,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { val issuer = bank.ref(1, 2, 3) bobNode.database.transaction { - VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, issuer) + VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, issuer, atMostThisManyStates = 10) } val alicesFakePaper = aliceNode.database.transaction { fillUpForSeller(false, issuer, alice, diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/HibernateConfigurationTest.kt b/node/src/test/kotlin/net/corda/node/services/persistence/HibernateConfigurationTest.kt index 1efb349ca0..30cdbe7f59 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/HibernateConfigurationTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/HibernateConfigurationTest.kt @@ -28,12 +28,14 @@ import net.corda.finance.schemas.CashSchemaV1 import net.corda.finance.test.SampleCashSchemaV1 import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV3 +import net.corda.node.internal.NodeServicesForResolution import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.services.schema.ContractStateAndRef import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.schema.PersistentStateService import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.VaultSchemaV1 +import net.corda.node.services.vault.toStateRef import net.corda.node.testing.DummyFungibleContract import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig @@ -48,7 +50,6 @@ import net.corda.testing.internal.vault.VaultFiller import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import org.assertj.core.api.Assertions -import org.assertj.core.api.Assertions.`in` import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.hibernate.SessionFactory @@ -122,7 +123,14 @@ class HibernateConfigurationTest { services = object : MockServices(cordappPackages, BOB_NAME, mock().also { doReturn(null).whenever(it).verifyAndRegisterIdentity(argThat { name == BOB_NAME }) }, generateKeyPair(), dummyNotary.keyPair) { - override val vaultService = NodeVaultService(Clock.systemUTC(), keyManagementService, servicesForResolution, database, schemaService, cordappClassloader).apply { start() } + override val vaultService = NodeVaultService( + Clock.systemUTC(), + keyManagementService, + servicesForResolution as NodeServicesForResolution, + database, + schemaService, + cordappClassloader + ).apply { start() } override fun recordTransactions(statesToRecord: StatesToRecord, txs: Iterable) { for (stx in txs) { (validatedTransactions as WritableTransactionStorage).addTransaction(stx) @@ -183,7 +191,7 @@ class HibernateConfigurationTest { // execute query val queryResults = entityManager.createQuery(criteriaQuery).resultList val coins = queryResults.map { - services.loadState(toStateRef(it.stateRef!!)).data + services.loadState(it.stateRef!!.toStateRef()).data }.sumCash() assertThat(coins.toDecimal() >= BigDecimal("50.00")) } @@ -739,7 +747,7 @@ class HibernateConfigurationTest { val queryResults = entityManager.createQuery(criteriaQuery).resultList queryResults.forEach { - val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State + val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State println("${it.stateRef} with owner: ${cashState.owner.owningKey.toBase58String()}") } @@ -823,7 +831,7 @@ class HibernateConfigurationTest { // execute query val queryResults = entityManager.createQuery(criteriaQuery).resultList queryResults.forEach { - val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State + val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State println("${it.stateRef} with owner ${cashState.owner.owningKey.toBase58String()} and participants ${cashState.participants.map { it.owningKey.toBase58String() }}") } @@ -961,10 +969,6 @@ class HibernateConfigurationTest { } } - private fun toStateRef(pStateRef: PersistentStateRef): StateRef { - return StateRef(SecureHash.create(pStateRef.txId), pStateRef.index) - } - @Test(timeout=300_000) fun `schema change`() { fun createNewDB(schemas: Set, initialiseSchema: Boolean = true): CordaPersistence { diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt index 6f0fa3278c..1930e7ffd8 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt @@ -244,7 +244,6 @@ class FlowSoftLocksTests { 100.DOLLARS, bankNode.services, thisManyStates, - thisManyStates, cashIssuer ) } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 09964b6602..b06518667c 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -20,14 +20,13 @@ import net.corda.finance.* import net.corda.finance.contracts.CommercialPaper import net.corda.finance.contracts.Commodity import net.corda.finance.contracts.DealState -import net.corda.finance.workflows.asset.selection.AbstractCashSelection import net.corda.finance.contracts.asset.Cash import net.corda.finance.schemas.CashSchemaV1 -import net.corda.finance.schemas.CashSchemaV1.PersistentCashState import net.corda.finance.schemas.CommercialPaperSchemaV1 import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV3 import net.corda.finance.workflows.CommercialPaperUtils +import net.corda.finance.workflows.asset.selection.AbstractCashSelection import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseTransaction @@ -197,8 +196,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } protected fun consumeCash(amount: Amount) = vaultFiller.consumeCash(amount, CHARLIE) - private fun setUpDb(_database: CordaPersistence, delay: Long = 0) { - _database.transaction { + + private fun setUpDb(database: CordaPersistence, delay: Long = 0) { + database.transaction { // create new states vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER) val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ") @@ -444,7 +444,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates) } - (1..3).forEach { + repeat(3) { val newAllStates = vaultService.queryBy(sorting = sorting, criteria = criteria).states assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates) @@ -485,7 +485,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates) } - (1..3).forEach { + repeat(3) { val newAllStates = vaultService.queryBy(sorting = sorting, criteria = criteria).states assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates) @@ -638,7 +638,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } val sorted = results.states.sortedBy { it.ref.toString() } assertThat(results.states).isEqualTo(sorted) - assertThat(results.states).allSatisfy { !consumed.contains(it.ref.txhash) } + assertThat(results.states).allSatisfy { assertThat(consumed).doesNotContain(it.ref.txhash) } } } @@ -1537,7 +1537,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")) // count fungible assets val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) + val countCriteria = VaultCustomQueryCriteria(count) val fungibleStateCount = vaultService.queryBy>(countCriteria).otherResults.single() as Long assertThat(fungibleStateCount).isEqualTo(10L) @@ -1563,7 +1563,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } // count fungible assets - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) + val countCriteria = VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) val fungibleStateCount = vaultService.queryBy>(countCriteria).otherResults.single() as Long assertThat(fungibleStateCount).isEqualTo(10L) @@ -1583,7 +1583,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // UNCONSUMED states (default) // count fungible assets - val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) + val countCriteriaUnconsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) val fungibleStateCountUnconsumed = vaultService.queryBy>(countCriteriaUnconsumed).otherResults.single() as Long assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size) @@ -1598,7 +1598,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // CONSUMED states // count fungible assets - val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) + val countCriteriaConsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) val fungibleStateCountConsumed = vaultService.queryBy>(countCriteriaConsumed).otherResults.single() as Long assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size) @@ -1622,7 +1622,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val start = TODAY val end = TODAY.plus(30, ChronoUnit.DAYS) val recordedBetweenExpression = TimeCondition( - QueryCriteria.TimeInstantType.RECORDED, + TimeInstantType.RECORDED, ColumnPredicate.Between(start, end)) val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression) val results = vaultService.queryBy(criteria) @@ -1632,7 +1632,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // Future val startFuture = TODAY.plus(1, ChronoUnit.DAYS) val recordedBetweenExpressionFuture = TimeCondition( - QueryCriteria.TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) + TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture) assertThat(vaultService.queryBy(criteriaFuture).states).isEmpty() } @@ -1648,7 +1648,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { consumeCash(100.DOLLARS) val asOfDateTime = TODAY val consumedAfterExpression = TimeCondition( - QueryCriteria.TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) + TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED, timeCondition = consumedAfterExpression) val results = vaultService.queryBy(criteria) @@ -1674,7 +1674,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // pagination: last page @Test(timeout=300_000) - fun `all states with paging specification - last`() { + fun `all states with paging specification - last`() { database.transaction { vaultFiller.fillWithSomeTestCash(95.DOLLARS, notaryServices, 95, DUMMY_CASH_ISSUER) // Last page implies we need to perform a row count for the Query first, @@ -1705,6 +1705,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } // pagination: invalid page size + @Suppress("INTEGER_OVERFLOW") @Test(timeout=300_000) fun `invalid page size`() { expectedEx.expect(VaultQueryException::class.java) @@ -1712,8 +1713,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { database.transaction { vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER) - @Suppress("EXPECTED_CONDITION") - val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, @Suppress("INTEGER_OVERFLOW") Integer.MAX_VALUE + 1) // overflow = -2147483648 + val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, Integer.MAX_VALUE + 1) // overflow = -2147483648 val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) vaultService.queryBy(criteria, paging = pagingSpec) } @@ -1723,7 +1723,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { @Test(timeout=300_000) fun `pagination not specified but more than default results available`() { expectedEx.expect(VaultQueryException::class.java) - expectedEx.expectMessage("provide a `PageSpecification(pageNumber, pageSize)`") + expectedEx.expectMessage("provide a PageSpecification") database.transaction { vaultFiller.fillWithSomeTestCash(201.DOLLARS, notaryServices, 201, DUMMY_CASH_ISSUER) @@ -1781,9 +1781,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { println("$index : $any") } assertThat(results.otherResults.size).isEqualTo(402) - val instants = results.otherResults.filter { it is Instant }.map { it as Instant } + val instants = results.otherResults.filterIsInstance() assertThat(instants).isSorted - val longs = results.otherResults.filter { it is Long }.map { it as Long } + val longs = results.otherResults.filterIsInstance() assertThat(longs.size).isEqualTo(201) assertThat(instants.size).isEqualTo(201) assertThat(longs.sum()).isEqualTo(20100L) @@ -1911,8 +1911,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties { fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() { database.transaction { val uid = UniqueIdentifier("999") - vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, uniqueIdentifier = uid) - vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, externalId = "1234") + vaultFiller.fillWithSomeTestLinearStates(txCount = 1, uniqueIdentifier = uid) + vaultFiller.fillWithSomeTestLinearStates(txCount = 1, externalId = "1234") val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id)) val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234")) @@ -2061,6 +2061,26 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } } + @Test(timeout = 300_000) + fun `unconsumed states which are globally unordered across multiple transactions sorted by custom attribute`() { + val linearNumbers = Array(2) { LongArray(2) } + // Make sure states from the same transaction are not given consecutive linear numbers. + linearNumbers[0][0] = 1L + linearNumbers[0][1] = 3L + linearNumbers[1][0] = 2L + linearNumbers[1][1] = 4L + + val results = database.transaction { + vaultFiller.fillWithTestStates(txCount = 2, statesPerTx = 2) { participantsToUse, txIndex, stateIndex -> + DummyLinearContract.State(participants = participantsToUse, linearNumber = linearNumbers[txIndex][stateIndex]) + } + + val sortColumn = Sort.SortColumn(SortAttribute.Custom(DummyLinearStateSchemaV1.PersistentDummyLinearState::class.java, "linearNumber")) + vaultService.queryBy(VaultQueryCriteria(), sorting = Sort(setOf(sortColumn))) + } + assertThat(results.states.map { it.state.data.linearNumber }).isEqualTo(listOf(1L, 2L, 3L, 4L)) + } + @Test(timeout=300_000) fun `return consumed linear states for a given linear id`() { database.transaction { @@ -2390,7 +2410,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { services.recordTransactions(commercialPaper2) val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) } - val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) + val criteria1 = VaultCustomQueryCriteria(ccyIndex) val result = vaultService.queryBy(criteria1) @@ -2433,9 +2453,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days) val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L) - val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) - val criteria2 = QueryCriteria.VaultCustomQueryCriteria(maturityIndex) - val criteria3 = QueryCriteria.VaultCustomQueryCriteria(faceValueIndex) + val criteria1 = VaultCustomQueryCriteria(ccyIndex) + val criteria2 = VaultCustomQueryCriteria(maturityIndex) + val criteria3 = VaultCustomQueryCriteria(faceValueIndex) vaultService.queryBy(criteria1.and(criteria3).and(criteria2)) } @@ -2458,8 +2478,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL) val results = builder { - val currencyIndex = PersistentCashState::currency.equal(USD.currencyCode) - val quantityIndex = PersistentCashState::pennies.greaterThanOrEqual(10L) + val currencyIndex = CashSchemaV1.PersistentCashState::currency.equal(USD.currencyCode) + val quantityIndex = CashSchemaV1.PersistentCashState::pennies.greaterThanOrEqual(10L) val customCriteria1 = VaultCustomQueryCriteria(currencyIndex) val customCriteria2 = VaultCustomQueryCriteria(quantityIndex) @@ -2710,7 +2730,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // Enrich and override QueryCriteria with additional default attributes (such as soft locks) val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich - softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), + softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), status = Vault.StateStatus.UNCONSUMED) // override // Sorting val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF) @@ -3056,7 +3076,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } @@ -3079,7 +3099,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } @@ -3102,7 +3122,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt index 7e771e9904..ac621c9bff 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt @@ -10,7 +10,6 @@ import net.corda.core.flows.InitiatingFlow import net.corda.core.identity.AbstractParty import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.uncheckedCast -import net.corda.core.node.ServicesForResolution import net.corda.core.node.services.KeyManagementService import net.corda.core.node.services.queryBy import net.corda.core.node.services.vault.QueryCriteria.SoftLockingCondition @@ -29,6 +28,7 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.testing.core.singleIdentity import net.corda.testing.flows.registerCoreFlowFactory import net.corda.coretesting.internal.rigorousMock +import net.corda.node.internal.NodeServicesForResolution import net.corda.testing.node.internal.InternalMockNetwork import net.corda.testing.node.internal.enclosedCordapp import net.corda.testing.node.internal.startFlow @@ -86,7 +86,7 @@ class VaultSoftLockManagerTest { private val mockNet = InternalMockNetwork(cordappsForAllNodes = listOf(enclosedCordapp()), defaultFactory = { args -> object : InternalMockNetwork.MockNode(args) { override fun makeVaultService(keyManagementService: KeyManagementService, - services: ServicesForResolution, + services: NodeServicesForResolution, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal { val node = this diff --git a/samples/trader-demo/build.gradle b/samples/trader-demo/build.gradle index dba40ec2c6..68cffdff69 100644 --- a/samples/trader-demo/build.gradle +++ b/samples/trader-demo/build.gradle @@ -20,6 +20,9 @@ sourceSets { runtimeClasspath += main.output + test.output srcDir file('src/integration-test/kotlin') } + resources { + srcDir file('src/integration-test/resources') + } } } @@ -50,6 +53,7 @@ dependencies { // Corda integration dependencies cordaRuntime project(path: ":node:capsule", configuration: 'runtimeArtifacts') + testCompile "org.slf4j:slf4j-simple:$slf4j_version" testCompile(project(':node-driver')) { // We already have a SLF4J implementation on our runtime classpath, // and we don't need another one. diff --git a/samples/trader-demo/src/integration-test/resources/simplelogger.properties b/samples/trader-demo/src/integration-test/resources/simplelogger.properties new file mode 100644 index 0000000000..49f32f979f --- /dev/null +++ b/samples/trader-demo/src/integration-test/resources/simplelogger.properties @@ -0,0 +1,4 @@ +org.slf4j.simpleLogger.defaultLogLevel=info +org.slf4j.simpleLogger.showDateTime=true +org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss:SSS Z +org.slf4j.simpleLogger.logFile=System.out \ No newline at end of file diff --git a/serialization-deterministic/build.gradle b/serialization-deterministic/build.gradle index 6ad42b0208..7822eb3b23 100644 --- a/serialization-deterministic/build.gradle +++ b/serialization-deterministic/build.gradle @@ -23,7 +23,10 @@ def javaHome = System.getProperty('java.home') def jarBaseName = "corda-${project.name}".toString() configurations { - deterministicLibraries.extendsFrom implementation + deterministicLibraries { + canBeConsumed = false + extendsFrom implementation + } deterministicArtifacts.extendsFrom deterministicLibraries } @@ -55,7 +58,7 @@ def originalJar = serializationJarTask.map { it.outputs.files.singleFile } def patchSerialization = tasks.register('patchSerialization', Zip) { dependsOn serializationJarTask - destinationDirectory = file("$buildDir/source-libs") + destinationDirectory = layout.buildDirectory.dir('source-libs') metadataCharset 'UTF-8' archiveClassifier = 'transient' archiveExtension = 'jar' @@ -157,7 +160,7 @@ def determinise = tasks.register('determinise', ProGuardTask) { def checkDeterminism = tasks.register('checkDeterminism', ProGuardTask) def metafix = tasks.register('metafix', MetaFixerTask) { - outputDir file("$buildDir/libs") + outputDir = layout.buildDirectory.dir('libs') jars determinise suffix "" diff --git a/serialization-djvm/build.gradle b/serialization-djvm/build.gradle index f51557e2a3..8e8870398e 100644 --- a/serialization-djvm/build.gradle +++ b/serialization-djvm/build.gradle @@ -1,6 +1,3 @@ -import org.jetbrains.kotlin.gradle.tasks.KotlinCompile -import static org.gradle.api.JavaVersion.VERSION_1_8 - plugins { id 'org.jetbrains.kotlin.jvm' id 'net.corda.plugins.publish-utils' @@ -17,8 +14,12 @@ apply from: "${rootProject.projectDir}/java8.gradle" description 'Serialization support for the DJVM' configurations { - sandboxTesting - jdkRt + sandboxTesting { + canBeConsumed = false + } + jdkRt { + canBeConsumed = false + } } dependencies { @@ -56,6 +57,11 @@ jar { } } +tasks.withType(Javadoc).configureEach { + // We have no public or protected Java classes to document. + enabled = false +} + tasks.withType(Test).configureEach { useJUnitPlatform() systemProperty 'deterministic-rt.path', configurations.jdkRt.asPath @@ -66,7 +72,7 @@ tasks.withType(Test).configureEach { } publish { - name jar.archiveBaseName.get() + name jar.archiveBaseName } idea { diff --git a/serialization-djvm/src/main/java/net/corda/serialization/djvm/serializers/CacheKey.java b/serialization-djvm/src/main/java/net/corda/serialization/djvm/serializers/CacheKey.java new file mode 100644 index 0000000000..5ef3728e91 --- /dev/null +++ b/serialization-djvm/src/main/java/net/corda/serialization/djvm/serializers/CacheKey.java @@ -0,0 +1,35 @@ +package net.corda.serialization.djvm.serializers; + +import org.jetbrains.annotations.NotNull; + +import java.util.Arrays; + +/** + * This class is deliberately written in Java so + * that it can be package private. + */ +final class CacheKey { + private final byte[] bytes; + private final int hashValue; + + CacheKey(@NotNull byte[] bytes) { + this.bytes = bytes; + this.hashValue = Arrays.hashCode(bytes); + } + + @NotNull + byte[] getBytes() { + return bytes; + } + + @Override + public boolean equals(Object other) { + return (this == other) + || (other instanceof CacheKey && Arrays.equals(bytes, ((CacheKey) other).bytes)); + } + + @Override + public int hashCode() { + return hashValue; + } +} diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxCertPathSerializer.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxCertPathSerializer.kt index 0d6cd7aff5..25710d654e 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxCertPathSerializer.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxCertPathSerializer.kt @@ -1,5 +1,7 @@ package net.corda.serialization.djvm.serializers +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY +import net.corda.core.serialization.SerializationContext import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.CertPathDeserializer import net.corda.serialization.djvm.toSandboxAnyClass @@ -27,4 +29,13 @@ class SandboxCertPathSerializer( override fun fromProxy(proxy: Any): Any { return task.apply(proxy)!! } + + override fun fromProxy(proxy: Any, context: SerializationContext): Any { + // This requires [CertPathProxy] to have correct + // implementations for [equals] and [hashCode]. + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(proxy, ::fromProxy) + ?: fromProxy(proxy) + } } diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxPublicKeySerializer.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxPublicKeySerializer.kt index 6a22e05da6..f826672647 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxPublicKeySerializer.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxPublicKeySerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.djvm.serializers +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.PublicKeyDecoder @@ -27,7 +28,11 @@ class SandboxPublicKeySerializer( override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray - return decoder.apply(bits)!! + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + decoder.apply(key.bytes) + } ?: decoder.apply(bits)!! } override fun writeDescribedObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext) { diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CRLSerializer.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CRLSerializer.kt index aa52234a97..0c19470e25 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CRLSerializer.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CRLSerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.djvm.serializers +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.X509CRLDeserializer @@ -28,7 +29,11 @@ class SandboxX509CRLSerializer( override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray - return generator.apply(bits)!! + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + generator.apply(key.bytes) + } ?: generator.apply(bits)!! } override fun writeDescribedObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext) { diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CertificateSerializer.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CertificateSerializer.kt index cab56d34c6..cf6a78da7e 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CertificateSerializer.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CertificateSerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.djvm.serializers +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.X509CertificateDeserializer @@ -28,7 +29,11 @@ class SandboxX509CertificateSerializer( override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray - return generator.apply(bits)!! + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + generator.apply(key.bytes) + } ?: generator.apply(bits)!! } override fun writeDescribedObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext) { diff --git a/serialization/build.gradle b/serialization/build.gradle index 0ba49a0804..224bd642b4 100644 --- a/serialization/build.gradle +++ b/serialization/build.gradle @@ -52,8 +52,13 @@ configurations { testArtifacts.extendsFrom testRuntimeClasspath } +tasks.withType(Javadoc).configureEach { + // We have no public or protected Java classes to document. + enabled = false +} + task testJar(type: Jar) { - classifier "tests" + archiveClassifier = 'tests' from sourceSets.test.output } @@ -68,5 +73,5 @@ jar { } publish { - name jar.baseName + name jar.archiveBaseName } diff --git a/serialization/src/main/java/net/corda/serialization/internal/amqp/custom/CacheKey.java b/serialization/src/main/java/net/corda/serialization/internal/amqp/custom/CacheKey.java new file mode 100644 index 0000000000..2a341d5130 --- /dev/null +++ b/serialization/src/main/java/net/corda/serialization/internal/amqp/custom/CacheKey.java @@ -0,0 +1,37 @@ +package net.corda.serialization.internal.amqp.custom; + +import net.corda.core.KeepForDJVM; +import org.jetbrains.annotations.NotNull; + +import java.util.Arrays; + +/** + * This class is deliberately written in Java so + * that it can be package private. + */ +@KeepForDJVM +final class CacheKey { + private final byte[] bytes; + private final int hashValue; + + CacheKey(@NotNull byte[] bytes) { + this.bytes = bytes; + this.hashValue = Arrays.hashCode(bytes); + } + + @NotNull + byte[] getBytes() { + return bytes; + } + + @Override + public boolean equals(Object other) { + return (this == other) + || (other instanceof CacheKey && Arrays.equals(bytes, ((CacheKey) other).bytes)); + } + + @Override + public int hashCode() { + return hashValue; + } +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt index 2447ed9642..6b63a46655 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt @@ -6,6 +6,7 @@ import net.corda.core.crypto.SecureHash import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.copyBytes import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getSchemeIdIfCustomSerializationMagic import net.corda.core.utilities.ByteSequence import net.corda.serialization.internal.amqp.amqpMagic import org.slf4j.LoggerFactory @@ -39,6 +40,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe /** * {@inheritDoc} */ + @Suppress("OverridingDeprecatedMember") override fun withAttachmentsClassLoader(attachmentHashes: List): SerializationContext { return this } @@ -47,6 +49,10 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe return copy(properties = properties + (property to value)) } + override fun withProperties(extraProperties: Map): SerializationContext { + return copy(properties = properties + extraProperties) + } + override fun withoutReferences(): SerializationContext { return copy(objectReferencesEnabled = false) } @@ -103,10 +109,13 @@ open class SerializationFactoryImpl( val lookupKey = magic to target // ConcurrentHashMap.get() is lock free, but computeIfAbsent is not, even if the key is in the map already. return (schemes[lookupKey] ?: schemes.computeIfAbsent(lookupKey) { - registeredSchemes.filter { it.canDeserializeVersion(magic, target) }.forEach { return@computeIfAbsent it } // XXX: Not single? - logger.warn("Cannot find serialization scheme for: [$lookupKey, " + - "${if (magic == amqpMagic) "AMQP" else "UNKNOWN MAGIC"}] registeredSchemes are: $registeredSchemes") - throw UnsupportedOperationException("Serialization scheme $lookupKey not supported.") + registeredSchemes.firstOrNull { it.canDeserializeVersion(magic, target) } ?: run { + logger.warn("Cannot find serialization scheme for: [$lookupKey, " + + "${if (magic == amqpMagic) "AMQP" else "UNKNOWN MAGIC"}] registeredSchemes are: $registeredSchemes") + val schemeId = getSchemeIdIfCustomSerializationMagic(magic) ?: throw UnsupportedOperationException("Serialization scheme" + + " $lookupKey not supported.") + throw UnsupportedOperationException("Could not find custom serialization scheme with SchemeId = $schemeId.") + } }) to magic } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt index e9e5eda38a..a55d334b40 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt @@ -88,11 +88,11 @@ class CorDappCustomSerializer( override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext ) = uncheckedCast, SerializationCustomSerializer>( - serializer).fromProxy(uncheckedCast(proxySerializer.readObject(obj, schemas, input, context)))!! + serializer).fromProxy(proxySerializer.readObject(obj, schemas, input, context))!! /** * For 3rd party plugin serializers we are going to exist on exact type matching. i.e. we will - * not support base class serializers for derivedtypes + * not support base class serializers for derived types */ override fun isSerializerFor(clazz: Class<*>) = TypeToken.of(type.asClass()) == TypeToken.of(clazz) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt index 53d521a80a..ee28ca00de 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt @@ -1,7 +1,6 @@ package net.corda.serialization.internal.amqp import net.corda.core.KeepForDJVM -import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.model.FingerprintWriter import net.corda.serialization.internal.model.TypeIdentifier @@ -52,7 +51,8 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { context: SerializationContext, debugIndent: Int ) { data.withDescribed(descriptor) { - writeDescribedObject(uncheckedCast(obj), data, type, output, context) + @Suppress("unchecked_cast") + writeDescribedObject(obj as T, data, type, output, context) } } @@ -178,10 +178,13 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { protected abstract fun fromProxy(proxy: P): T + protected open fun toProxy(obj: T, context: SerializationContext): P = toProxy(obj) + protected open fun fromProxy(proxy: P, context: SerializationContext): T = fromProxy(proxy) + override fun writeDescribedObject(obj: T, data: Data, type: Type, output: SerializationOutput, context: SerializationContext ) { - val proxy = toProxy(obj) + val proxy = toProxy(obj, context) data.withList { proxySerializer.propertySerializers.forEach { (_, serializer) -> serializer.writeProperty(proxy, this, output, context, 0) @@ -192,8 +195,9 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext ): T { - val proxy: P = uncheckedCast(proxySerializer.readObject(obj, schemas, input, context)) - return fromProxy(proxy) + @Suppress("unchecked_cast") + val proxy = proxySerializer.readObject(obj, schemas, input, context) as P + return fromProxy(proxy, context) } } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt index b8f8b55dfd..7be1425d32 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt @@ -3,12 +3,17 @@ package net.corda.serialization.internal.amqp import net.corda.core.KeepForDJVM import net.corda.core.internal.VisibleForTesting import net.corda.core.serialization.EncodingWhitelist +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializedBytes import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.loggerFor import net.corda.core.utilities.trace -import net.corda.serialization.internal.* +import net.corda.serialization.internal.ByteBufferInputStream +import net.corda.serialization.internal.CordaSerializationEncoding +import net.corda.serialization.internal.NullEncodingWhitelist +import net.corda.serialization.internal.SectionId +import net.corda.serialization.internal.encodingNotPermittedFormat import net.corda.serialization.internal.model.TypeIdentifier import org.apache.qpid.proton.amqp.Binary import org.apache.qpid.proton.amqp.DescribedType @@ -118,7 +123,19 @@ class DeserializationInput constructor( @Throws(NotSerializableException::class) fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationContext): T = des { - val envelope = getEnvelope(bytes, context.encodingWhitelist) + /** + * The cache uses object identity rather than [ByteSequence.equals] and + * [ByteSequence.hashCode]. This is for speed: each [ByteSequence] object + * can potentially be large, and we are optimizing for the case when we + * know we will be deserializing the exact same objects multiple times. + * This also means that the cache MUST be short-lived, as otherwise it + * becomes a memory leak. + */ + @Suppress("unchecked_cast") + val envelope = (context.properties[AMQP_ENVELOPE_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(IdentityKey(bytes)) { key -> + getEnvelope(key.bytes, context.encodingWhitelist) + } ?: getEnvelope(bytes, context.encodingWhitelist) logger.trace { "deserialize blob scheme=\"${envelope.schema}\"" } @@ -219,3 +236,16 @@ class DeserializationInput constructor( else -> false } } + +/** + * We cannot use [ByteSequence.equals] and [ByteSequence.hashCode] because + * these consider the contents of the underlying [ByteArray] object. We + * only need the [ByteSequence]'s object identity for our use-case. + */ +private class IdentityKey(val bytes: ByteSequence) { + override fun hashCode() = System.identityHashCode(bytes) + + override fun equals(other: Any?): Boolean { + return (this === other) || (other is IdentityKey && bytes === other.bytes) + } +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/CertPathSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/CertPathSerializer.kt index 5921781ae8..6d7fc6d668 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/CertPathSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/CertPathSerializer.kt @@ -1,6 +1,8 @@ package net.corda.serialization.internal.amqp.custom import net.corda.core.KeepForDJVM +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY +import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.amqp.CustomSerializer import net.corda.serialization.internal.amqp.SerializerFactory import java.io.NotSerializableException @@ -28,7 +30,21 @@ class CertPathSerializer( } } + override fun fromProxy(proxy: CertPathProxy, context: SerializationContext): CertPath { + // This requires [CertPathProxy] to have correct + // implementations for [equals] and [hashCode]. + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(proxy, ::fromProxy) + ?: fromProxy(proxy) + } + @KeepForDJVM - @Suppress("ArrayInDataClass") - data class CertPathProxy(val type: String, val encoded: ByteArray) -} \ No newline at end of file + data class CertPathProxy(val type: String, val encoded: ByteArray) { + override fun hashCode() = (type.hashCode() * 31) + encoded.contentHashCode() + override fun equals(other: Any?): Boolean { + return (this === other) + || (other is CertPathProxy && (type == other.type && encoded.contentEquals(other.encoded))) + } + } +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt index 9663576780..ee4bceb09b 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt @@ -1,6 +1,7 @@ package net.corda.serialization.internal.amqp.custom import net.corda.core.crypto.Crypto +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.amqp.* import org.apache.qpid.proton.codec.Data @@ -34,6 +35,10 @@ object PublicKeySerializer context: SerializationContext ): PublicKey { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray - return Crypto.decodePublicKey(bits) + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + Crypto.decodePublicKey(key.bytes) + } ?: Crypto.decodePublicKey(bits) } -} \ No newline at end of file +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt index 965b8ed40f..0680031096 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.internal.amqp.custom +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.amqp.* import org.apache.qpid.proton.codec.Data @@ -28,6 +29,14 @@ object X509CRLSerializer override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): X509CRL { val bytes = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bytes)) { key -> + generateCRL(key.bytes) + } ?: generateCRL(bytes) + } + + private fun generateCRL(bytes: ByteArray): X509CRL { return CertificateFactory.getInstance("X.509").generateCRL(bytes.inputStream()) as X509CRL } } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt index 9e7a2854b4..f3dbd9438d 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.internal.amqp.custom +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.amqp.* import org.apache.qpid.proton.codec.Data @@ -28,6 +29,14 @@ object X509CertificateSerializer override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): X509Certificate { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + generateCertificate(key.bytes) + } ?: generateCertificate(bits) + } + + private fun generateCertificate(bits: ByteArray): X509Certificate { return CertificateFactory.getInstance("X.509").generateCertificate(bits.inputStream()) as X509Certificate } } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenter.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenter.kt index 49ef897639..940b242064 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenter.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenter.kt @@ -15,7 +15,6 @@ import org.objectweb.asm.Type import java.lang.Character.isJavaIdentifierPart import java.lang.Character.isJavaIdentifierStart import java.lang.reflect.Method -import java.util.* /** * Any object that implements this interface is expected to expose its own fields via the [get] method, exactly @@ -28,8 +27,23 @@ interface SimpleFieldAccess { } @DeleteForDJVM -class CarpenterClassLoader(parentClassLoader: ClassLoader = Thread.currentThread().contextClassLoader) : +class CarpenterClassLoader(private val parentClassLoader: ClassLoader = Thread.currentThread().contextClassLoader) : ClassLoader(parentClassLoader) { + @Throws(ClassNotFoundException::class) + override fun loadClass(name: String?, resolve: Boolean): Class<*>? { + return synchronized(getClassLoadingLock(name)) { + /** + * Search parent classloaders using lock-less [Class.forName], + * bypassing [parent] to avoid its [SecurityManager] overhead. + */ + (findLoadedClass(name) ?: Class.forName(name, false, parentClassLoader)).also { clazz -> + if (resolve) { + resolveClass(clazz) + } + } + } + } + fun load(name: String, bytes: ByteArray): Class<*> { return defineClass(name, bytes, 0, bytes.size) } diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt index 116016b991..6345dd7549 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt @@ -3,9 +3,11 @@ package net.corda.coretesting.internal import net.corda.nodeapi.internal.rpc.client.AMQPClientSerializationScheme import net.corda.core.internal.createInstancesOfClassesImplementing import net.corda.core.serialization.CheckpointCustomSerializer +import net.corda.core.serialization.CustomSerializationScheme import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.internal.SerializationEnvironment +import net.corda.nodeapi.internal.serialization.CustomSerializationSchemeAdapter import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme import net.corda.nodeapi.internal.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.nodeapi.internal.serialization.kryo.KryoCheckpointSerializer @@ -14,6 +16,7 @@ import net.corda.serialization.internal.AMQP_RPC_CLIENT_CONTEXT import net.corda.serialization.internal.AMQP_RPC_SERVER_CONTEXT import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT import net.corda.serialization.internal.SerializationFactoryImpl +import net.corda.serialization.internal.SerializationScheme import net.corda.testing.common.internal.asContextEnv import java.util.ServiceLoader import java.util.concurrent.ConcurrentHashMap @@ -27,20 +30,29 @@ fun createTestSerializationEnv(): SerializationEnvironment { fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment { var customCheckpointSerializers: Set> = emptySet() - val (clientSerializationScheme, serverSerializationScheme) = if (classLoader != null) { + val serializationSchemes: MutableList = mutableListOf() + if (classLoader != null) { val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java) customCheckpointSerializers = createInstancesOfClassesImplementing(classLoader, CheckpointCustomSerializer::class.java) val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet() - Pair(AMQPClientSerializationScheme(customSerializers, serializationWhitelists), - AMQPServerSerializationScheme(customSerializers, serializationWhitelists)) + serializationSchemes.add(AMQPClientSerializationScheme(customSerializers, serializationWhitelists)) + serializationSchemes.add(AMQPServerSerializationScheme(customSerializers, serializationWhitelists)) + + val customSchemes = createInstancesOfClassesImplementing(classLoader, CustomSerializationScheme::class.java) + for (customScheme in customSchemes) { + serializationSchemes.add(CustomSerializationSchemeAdapter(customScheme)) + } } else { - Pair(AMQPClientSerializationScheme(emptyList()), AMQPServerSerializationScheme(emptyList())) + serializationSchemes.add(AMQPClientSerializationScheme(emptyList())) + serializationSchemes.add(AMQPServerSerializationScheme(emptyList())) } + val factory = SerializationFactoryImpl().apply { - registerScheme(clientSerializationScheme) - registerScheme(serverSerializationScheme) + for (serializationScheme in serializationSchemes) { + registerScheme(serializationScheme) + } } return SerializationEnvironment.with( factory, diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/NettyTestClient.kt b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/NettyTestClient.kt index 185a289472..581172a788 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/NettyTestClient.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/NettyTestClient.kt @@ -3,13 +3,14 @@ package net.corda.coretesting.internal import io.netty.bootstrap.Bootstrap import io.netty.channel.ChannelFuture import io.netty.channel.ChannelInboundHandlerAdapter -import io.netty.handler.ssl.SslContext import io.netty.channel.ChannelInitializer import io.netty.channel.ChannelOption import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel +import io.netty.handler.ssl.SslContext import io.netty.handler.ssl.SslHandler +import io.netty.util.concurrent.DefaultThreadFactory import java.io.Closeable import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException @@ -17,7 +18,6 @@ import java.util.concurrent.locks.ReentrantLock import javax.net.ssl.SSLEngine import kotlin.concurrent.thread - class NettyTestClient( val sslContext: SslContext?, val targetHost: String, @@ -49,7 +49,7 @@ class NettyTestClient( private fun run() { // Configure the client. - val group = NioEventLoopGroup() + val group = NioEventLoopGroup(DefaultThreadFactory("NettyTestClient")) try { val b = Bootstrap() b.group(group) diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/NettyTestServer.kt b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/NettyTestServer.kt index 8fa9d23057..1abc3f5c7b 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/NettyTestServer.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/NettyTestServer.kt @@ -11,6 +11,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler import io.netty.handler.ssl.SslContext +import io.netty.util.concurrent.DefaultThreadFactory import java.io.Closeable import java.util.concurrent.TimeUnit import java.util.concurrent.TimeoutException @@ -45,8 +46,8 @@ class NettyTestServer( fun run() { // Configure the server. - val bossGroup = NioEventLoopGroup(1) - val workerGroup = NioEventLoopGroup() + val bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("NettyTestServer-boss")) + val workerGroup = NioEventLoopGroup(DefaultThreadFactory("NettyTestServer-worker")) try { val b = ServerBootstrap() b.group(bossGroup, workerGroup) diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/stubs/CertificateStoreStubs.kt b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/stubs/CertificateStoreStubs.kt index c23d458a80..2f90555c44 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/stubs/CertificateStoreStubs.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/stubs/CertificateStoreStubs.kt @@ -8,6 +8,7 @@ import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.SslConfiguration import java.nio.file.Path +import java.time.Duration class CertificateStoreStubs { @@ -49,11 +50,11 @@ class CertificateStoreStubs { keyStorePassword: String = KeyStore.DEFAULT_STORE_PASSWORD, keyPassword: String = keyStorePassword, trustStoreFileName: String = TrustStore.DEFAULT_STORE_FILE_NAME, trustStorePassword: String = TrustStore.DEFAULT_STORE_PASSWORD, - trustStoreKeyPassword: String = TrustStore.DEFAULT_KEY_PASSWORD): MutualSslConfiguration { - + trustStoreKeyPassword: String = TrustStore.DEFAULT_KEY_PASSWORD, + sslHandshakeTimeout: Duration? = null): MutualSslConfiguration { val keyStore = FileBasedCertificateStoreSupplier(certificatesDirectory / keyStoreFileName, keyStorePassword, keyPassword) val trustStore = FileBasedCertificateStoreSupplier(certificatesDirectory / trustStoreFileName, trustStorePassword, trustStoreKeyPassword) - return SslConfiguration.mutual(keyStore, trustStore) + return SslConfiguration.mutual(keyStore, trustStore, sslHandshakeTimeout) } @JvmStatic diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt b/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt index 1bcf1ac389..175f52f5ae 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt @@ -1,30 +1,50 @@ -@file:Suppress("UNUSED_PARAMETER") @file:JvmName("TestUtils") +@file:Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod", "LongParameterList") package net.corda.testing.core import net.corda.core.contracts.PartyAndReference import net.corda.core.contracts.StateRef -import net.corda.core.crypto.* +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.SignatureScheme +import net.corda.core.crypto.toStringShort import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.toX500Name import net.corda.core.internal.unspecifiedCountry import net.corda.core.node.NodeInfo import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.millis +import net.corda.core.utilities.minutes +import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA +import net.corda.coretesting.internal.DEV_ROOT_CA import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateType import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA -import net.corda.coretesting.internal.DEV_ROOT_CA +import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames +import org.bouncycastle.asn1.x509.CRLReason +import org.bouncycastle.asn1.x509.DistributionPointName +import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.ExtensionsGenerator +import org.bouncycastle.asn1.x509.GeneralName +import org.bouncycastle.asn1.x509.GeneralNames +import org.bouncycastle.asn1.x509.IssuingDistributionPoint +import org.bouncycastle.cert.jcajce.JcaX509CRLConverter +import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils +import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import java.math.BigInteger +import java.net.URI import java.security.KeyPair import java.security.PublicKey +import java.security.cert.X509CRL import java.security.cert.X509Certificate import java.time.Duration import java.time.Instant +import java.util.* import java.util.concurrent.atomic.AtomicInteger import kotlin.test.fail @@ -109,6 +129,44 @@ fun getTestPartyAndCertificate(name: CordaX500Name, publicKey: PublicKey): Party return getTestPartyAndCertificate(Party(name, publicKey)) } +fun createCRL(issuer: CertificateAndKeyPair, + revokedCerts: List, + issuingDistPoint: URI? = null, + thisUpdate: Instant = Instant.now(), + nextUpdate: Instant = thisUpdate + 5.minutes, + indirect: Boolean = false, + revocationDate: Instant = thisUpdate, + crlReason: Int = CRLReason.keyCompromise, + signatureAlgorithm: String = "SHA256withECDSA"): X509CRL { + val builder = JcaX509v2CRLBuilder(issuer.certificate.subjectX500Principal, Date.from(thisUpdate)) + val extensionUtils = JcaX509ExtensionUtils() + builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(issuer.certificate)) + // This is required and needs to match the certificate settings with respect to being indirect + builder.addExtension( + Extension.issuingDistributionPoint, + true, + IssuingDistributionPoint( + issuingDistPoint?.let { DistributionPointName(toGeneralNames(it.toString(), GeneralName.uniformResourceIdentifier)) }, + indirect, + false + ) + ) + builder.setNextUpdate(Date.from(nextUpdate)) + for (revokedCert in revokedCerts) { + val extensionsGenerator = ExtensionsGenerator() + extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(crlReason)) + // Certificate issuer is required for indirect CRL + extensionsGenerator.addExtension( + Extension.certificateIssuer, + true, + GeneralNames(GeneralName(revokedCert.issuerX500Principal.toX500Name())) + ) + builder.addCRLEntry(revokedCert.serialNumber, Date.from(revocationDate), extensionsGenerator.generate()) + } + val bcProvider = Crypto.findProvider("BC") + val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(bcProvider).build(issuer.keyPair.private) + return JcaX509CRLConverter().setProvider(bcProvider).getCRL(builder.build(signer)) +} private val count = AtomicInteger(0) /** @@ -188,7 +246,6 @@ fun NodeInfo.singleIdentity(): Party = singleIdentityAndCert().party * The above will test our expectation that the getWaitingFlows action was executed successfully considering * that it may take a few hundreds of milliseconds for the flow state machine states to settle. */ -@Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod") fun executeTest( timeout: Duration, cleanup: (() -> Unit)? = null, diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt index 899d36cdc3..6dcb0db299 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt @@ -26,6 +26,7 @@ import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.NetworkHostAndPort import net.corda.node.VersionInfo import net.corda.node.internal.ServicesForResolutionImpl +import net.corda.node.internal.NodeServicesForResolution import net.corda.node.internal.cordapp.JarScanningCordappLoader import net.corda.node.services.api.* import net.corda.node.services.diagnostics.NodeDiagnosticsService @@ -251,11 +252,15 @@ open class MockServices private constructor( override fun jdbcSession(): Connection = persistence.createSession() override fun withEntityManager(block: EntityManager.() -> T): T { - return block(contextTransaction.entityManager) + return contextTransaction.entityManager.run { + block(this).also { flush () } + } } override fun withEntityManager(block: Consumer) { - return block.accept(contextTransaction.entityManager) + return contextTransaction.entityManager.run { + block.accept(this).also { flush () } + } } } } @@ -456,7 +461,14 @@ open class MockServices private constructor( get() = ServicesForResolutionImpl(identityService, attachments, cordappProvider, networkParametersService, validatedTransactions) internal fun makeVaultService(schemaService: SchemaService, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal { - return NodeVaultService(clock, keyManagementService, servicesForResolution, database, schemaService, cordappLoader.appClassLoader).apply { start() } + return NodeVaultService( + clock, + keyManagementService, + servicesForResolution as NodeServicesForResolution, + database, + schemaService, + cordappLoader.appClassLoader + ).apply { start() } } // This needs to be internal as MutableClassToInstanceMap is a guava type and shouldn't be part of our public API diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/CustomCordapp.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/CustomCordapp.kt index d461becbfc..d3cdc7043e 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/CustomCordapp.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/CustomCordapp.kt @@ -59,6 +59,12 @@ data class CustomCordapp( } classGraph.enableClassInfo().pooledScan().use { scanResult -> + if (scanResult.allResources.isEmpty()) { + throw ClassNotFoundException( + "Could not create jar file as the given classes(${classes.joinToString()}) / packages(${packages.joinToString()}) were not found on the classpath" + ) + } + val whitelistService = SerializationWhitelist::class.java.name val whitelists = scanResult.getClassesImplementing(whitelistService) @@ -73,13 +79,9 @@ data class CustomCordapp( } } - if (scanResult.allResources.isEmpty()){ - throw ClassNotFoundException("Could not create jar file as the given package is not found on the classpath: ${packages.toList()[0]}") - } - // The same resource may be found in different locations (this will happen when running from gradle) so just // pick the first one found. - scanResult.allResources.asMap().forEach { path, resourceList -> + scanResult.allResourcesAsMap.forEach { (path, resourceList) -> jos.addEntry(testEntry(path), resourceList[0].open()) } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockNetworkParametersService.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockNetworkParametersService.kt index 6938c9d256..390b0b6369 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockNetworkParametersService.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockNetworkParametersService.kt @@ -20,17 +20,28 @@ import java.time.Instant class MockNetworkParametersStorage(private var currentParameters: NetworkParameters = testNetworkParameters(modifiedTime = Instant.MIN)) : NetworkParametersStorage { private val hashToParametersMap: HashMap = HashMap() private val hashToSignedParametersMap: HashMap = HashMap() + override var currentHash = currentParameters.computeHash() + override val defaultHash: SecureHash get() = currentHash init { storeCurrentParameters() } + private fun NetworkParameters.computeHash(): SecureHash { + return withTestSerializationEnvIfNotSet { + this.serialize().hash + } + } + fun setCurrentParametersUnverified(networkParameters: NetworkParameters) { currentParameters = networkParameters + currentHash = currentParameters.computeHash() storeCurrentParameters() } override fun setCurrentParameters(currentSignedParameters: SignedDataWithCert, trustRoots: Set) { - setCurrentParametersUnverified(currentSignedParameters.verifiedNetworkParametersCert(trustRoots)) + currentParameters = currentSignedParameters.verifiedNetworkParametersCert(trustRoots) + currentHash = currentSignedParameters.raw.hash + storeCurrentParameters() } override fun lookupSigned(hash: SecureHash): SignedDataWithCert? { @@ -39,13 +50,6 @@ class MockNetworkParametersStorage(private var currentParameters: NetworkParamet override fun hasParameters(hash: SecureHash): Boolean = hash in hashToParametersMap - override val currentHash: SecureHash - get() { - return withTestSerializationEnvIfNotSet { - currentParameters.serialize().hash - } - } - override val defaultHash: SecureHash get() = currentHash override fun lookup(hash: SecureHash): NetworkParameters? = hashToParametersMap[hash] override fun getEpochFromHash(hash: SecureHash): Int? = lookup(hash)?.epoch override fun saveParameters(signedNetworkParameters: SignedDataWithCert) { diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt new file mode 100644 index 0000000000..b6dee805fa --- /dev/null +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt @@ -0,0 +1,194 @@ +@file:Suppress("MagicNumber") + +package net.corda.testing.node.internal.network + +import net.corda.core.crypto.Crypto +import net.corda.core.internal.CertRole +import net.corda.core.internal.toX500Name +import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.days +import net.corda.core.utilities.minutes +import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA +import net.corda.coretesting.internal.DEV_ROOT_CA +import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair +import net.corda.nodeapi.internal.crypto.ContentSignerBuilder +import net.corda.nodeapi.internal.crypto.X509Utilities +import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames +import net.corda.nodeapi.internal.crypto.certificateType +import net.corda.nodeapi.internal.crypto.toJca +import net.corda.testing.core.createCRL +import org.bouncycastle.asn1.x509.CRLDistPoint +import org.bouncycastle.asn1.x509.DistributionPoint +import org.bouncycastle.asn1.x509.DistributionPointName +import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.GeneralName +import org.bouncycastle.asn1.x509.GeneralNames +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.ServerConnector +import org.eclipse.jetty.server.handler.HandlerCollection +import org.eclipse.jetty.servlet.ServletContextHandler +import org.eclipse.jetty.servlet.ServletHolder +import org.glassfish.jersey.server.ResourceConfig +import org.glassfish.jersey.servlet.ServletContainer +import java.io.Closeable +import java.net.InetSocketAddress +import java.net.URI +import java.security.KeyPair +import java.security.cert.X509CRL +import java.security.cert.X509Certificate +import java.time.Duration +import java.util.* +import javax.security.auth.x500.X500Principal +import javax.ws.rs.GET +import javax.ws.rs.Path +import javax.ws.rs.Produces +import javax.ws.rs.core.Response +import kotlin.collections.ArrayList + +class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { + companion object { + private val logger = contextLogger() + + const val NODE_CRL = "node.crl" + const val FORBIDDEN_CRL = "forbidden.crl" + const val INTERMEDIATE_CRL = "intermediate.crl" + const val EMPTY_CRL = "empty.crl" + + fun X509Certificate.withCrlDistPoint(issuerKeyPair: KeyPair, crlDistPoint: String?, crlIssuer: X500Principal? = null): X509Certificate { + val signatureScheme = Crypto.findSignatureScheme(issuerKeyPair.private) + val provider = Crypto.findProvider(signatureScheme.providerName) + val issuerSigner = ContentSignerBuilder.build(signatureScheme, issuerKeyPair.private, provider) + val builder = X509Utilities.createPartialCertificate( + CertRole.extract(this)!!.certificateType, + issuerX500Principal, + issuerKeyPair.public, + subjectX500Principal, + publicKey, + Pair(Date(System.currentTimeMillis() - 5.minutes.toMillis()), Date(System.currentTimeMillis() + 10.days.toMillis())), + null + ) + if (crlDistPoint != null) { + val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier)) + val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(it.toX500Name())) } + val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames) + builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint))) + } + return builder.build(issuerSigner).toJca() + } + } + + private val server: Server = Server(InetSocketAddress(hostAndPort.host, hostAndPort.port)).apply { + handler = HandlerCollection().apply { + addHandler(buildServletContextHandler()) + } + } + + val revokedNodeCerts: MutableList = ArrayList() + val revokedIntermediateCerts: MutableList = ArrayList() + + val rootCa: CertificateAndKeyPair = DEV_ROOT_CA + + private lateinit var _intermediateCa: CertificateAndKeyPair + val intermediateCa: CertificateAndKeyPair get() = _intermediateCa + + @Volatile + var delay: Duration? = null + + val hostAndPort: NetworkHostAndPort + get() = server.connectors.mapNotNull { it as? ServerConnector } + .map { NetworkHostAndPort(it.host, it.localPort) } + .first() + + fun start() { + server.start() + _intermediateCa = CertificateAndKeyPair( + DEV_INTERMEDIATE_CA.certificate.withCrlDistPoint(rootCa.keyPair, "http://$hostAndPort/crl/$INTERMEDIATE_CRL"), + DEV_INTERMEDIATE_CA.keyPair + ) + logger.info("Network management web services started on $hostAndPort") + } + + fun replaceNodeCertDistPoint(nodeCaCert: X509Certificate, + nodeCaCrlDistPoint: String? = "http://$hostAndPort/crl/$NODE_CRL", + crlIssuer: X500Principal? = null): X509Certificate { + return nodeCaCert.withCrlDistPoint(intermediateCa.keyPair, nodeCaCrlDistPoint, crlIssuer) + } + + private fun createServerCRL(issuer: CertificateAndKeyPair, + endpoint: String, + indirect: Boolean, + revokedCerts: List): X509CRL { + logger.info("Generating CRL for /$endpoint: ${revokedCerts.map { it.serialNumber }}") + return createCRL( + issuer, + revokedCerts, + issuingDistPoint = URI("http://$hostAndPort/crl/$endpoint"), + indirect = indirect + ) + } + + override fun close() { + server.stop() + server.join() + } + + private fun buildServletContextHandler(): ServletContextHandler { + return ServletContextHandler().apply { + contextPath = "/" + val resourceConfig = ResourceConfig().apply { + register(CrlServlet(this@CrlServer)) + } + val jerseyServlet = ServletHolder(ServletContainer(resourceConfig)).apply { initOrder = 0 } + addServlet(jerseyServlet, "/*") + } + } + + @Path("crl") + class CrlServlet(private val crlServer: CrlServer) { + @GET + @Path(NODE_CRL) + @Produces("application/pkcs7-crl") + fun getNodeCRL(): Response { + crlServer.delay?.toMillis()?.let(Thread::sleep) + return Response.ok(crlServer.createServerCRL( + crlServer.intermediateCa, + NODE_CRL, + false, + crlServer.revokedNodeCerts + ).encoded).build() + } + + @GET + @Path(FORBIDDEN_CRL) + @Produces("application/pkcs7-crl") + fun getNodeSlowCRL(): Response { + return Response.status(Response.Status.FORBIDDEN).build() + } + + @GET + @Path(INTERMEDIATE_CRL) + @Produces("application/pkcs7-crl") + fun getIntermediateCRL(): Response { + crlServer.delay?.toMillis()?.let(Thread::sleep) + return Response.ok(crlServer.createServerCRL( + crlServer.rootCa, + INTERMEDIATE_CRL, + false, + crlServer.revokedIntermediateCerts + ).encoded).build() + } + + @GET + @Path(EMPTY_CRL) + @Produces("application/pkcs7-crl") + fun getEmptyCRL(): Response { + return Response.ok(crlServer.createServerCRL( + crlServer.rootCa, + EMPTY_CRL, + true, + emptyList() + ).encoded).build() + } + } +} diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt index 1dbf7249cb..ab080c3d7e 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt @@ -42,6 +42,7 @@ import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.SchemaMigration +import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.serialization.internal.amqp.AMQP_ENABLED import net.corda.testing.core.ALICE_NAME @@ -52,6 +53,8 @@ import java.io.IOException import java.net.ServerSocket import java.nio.file.Path import java.security.KeyPair +import java.security.cert.X509CRL +import java.security.cert.X509Certificate import java.util.* import java.util.jar.JarOutputStream import java.util.jar.Manifest @@ -147,6 +150,12 @@ fun p2pSslOptions(path: Path, name: CordaX500Name = CordaX500Name("MegaCorp", "L return sslConfig } +fun fixedCrlSource(crls: Set): CrlSource { + return object : CrlSource { + override fun fetch(certificate: X509Certificate): Set = crls + } +} + /** This is the same as the deprecated [WireTransaction] c'tor but avoids the deprecation warning. */ @SuppressWarnings("LongParameterList") fun createWireTransaction(inputs: List, diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt index 467b54ea22..f2775e1878 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt @@ -1,6 +1,20 @@ +@file:Suppress("LongParameterList") + package net.corda.testing.internal.vault -import net.corda.core.contracts.* +import net.corda.core.contracts.Amount +import net.corda.core.contracts.AttachmentConstraint +import net.corda.core.contracts.AutomaticPlaceholderConstraint +import net.corda.core.contracts.BelongsToContract +import net.corda.core.contracts.CommandAndState +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.FungibleAsset +import net.corda.core.contracts.Issued +import net.corda.core.contracts.LinearState +import net.corda.core.contracts.PartyAndReference +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.UniqueIdentifier import net.corda.core.crypto.Crypto import net.corda.core.crypto.SignatureMetadata import net.corda.core.identity.AbstractParty @@ -19,9 +33,7 @@ import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Obligation import net.corda.finance.contracts.asset.OnLedgerAsset import net.corda.finance.workflows.asset.CashUtils -import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyState -import net.corda.testing.core.DummyCommandData import net.corda.testing.core.TestIdentity import net.corda.testing.core.dummyCommand import net.corda.testing.core.singleIdentity @@ -32,6 +44,7 @@ import java.time.Duration import java.time.Instant import java.time.Instant.now import java.util.* +import kotlin.math.floor /** * The service hub should provide at least a key management service and a storage service. @@ -46,7 +59,7 @@ class VaultFiller @JvmOverloads constructor( private val rngFactory: () -> Random = { Random(0L) }) { companion object { fun calculateRandomlySizedAmounts(howMuch: Amount, min: Int, max: Int, rng: Random): LongArray { - val numSlots = min + Math.floor(rng.nextDouble() * (max - min)).toInt() + val numSlots = min + floor(rng.nextDouble() * (max - min)).toInt() val baseSize = howMuch.quantity / numSlots check(baseSize > 0) { baseSize } @@ -79,31 +92,18 @@ class VaultFiller @JvmOverloads constructor( issuerServices: ServiceHub = services, participants: List = emptyList(), includeMe: Boolean = true): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val participantsToUse = if (includeMe) participants.plus(me) else participants - - val transactions: List = dealIds.map { - // Issue a deal state - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - addOutputState(DummyDealContract.State(ref = it, participants = participantsToUse), DUMMY_DEAL_PROGRAM_ID) - addCommand(dummyCommand()) - } - val stx = issuerServices.signInitialTransaction(dummyIssue) - return@map services.addSignature(stx, defaultNotary.publicKey) + return fillWithTestStates( + txCount = dealIds.size, + participants = participants, + includeMe = includeMe, + services = issuerServices + ) { participantsToUse, txIndex, _ -> + DummyDealContract.State(ref = dealIds[txIndex], participants = participantsToUse) } - val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE - services.recordTransactions(statesToRecord, transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) } @JvmOverloads - fun fillWithSomeTestLinearStates(numberToCreate: Int, + fun fillWithSomeTestLinearStates(txCount: Int, externalId: String? = null, participants: List = emptyList(), uniqueIdentifier: UniqueIdentifier? = null, @@ -113,81 +113,41 @@ class VaultFiller @JvmOverloads constructor( linearTimestamp: Instant = now(), constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, includeMe: Boolean = true): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val issuerKey = defaultNotary.keyPair - val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) - val participantsToUse = if (includeMe) participants.plus(me) else participants - val transactions: List = (1..numberToCreate).map { - // Issue a Linear state - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - addOutputState(DummyLinearContract.State( - linearId = uniqueIdentifier ?: UniqueIdentifier(externalId), - participants = participantsToUse, - linearString = linearString, - linearNumber = linearNumber, - linearBoolean = linearBoolean, - linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID, - constraint = constraint) - addCommand(dummyCommand()) - } - return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) + return fillWithTestStates(txCount, 1, participants, constraint, includeMe) { participantsToUse, _, _ -> + DummyLinearContract.State( + linearId = uniqueIdentifier ?: UniqueIdentifier(externalId), + participants = participantsToUse, + linearString = linearString, + linearNumber = linearNumber, + linearBoolean = linearBoolean, + linearTimestamp = linearTimestamp + ) } - val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE - services.recordTransactions(statesToRecord, transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) } @JvmOverloads - fun fillWithSomeTestLinearAndDealStates(numberToCreate: Int, + fun fillWithSomeTestLinearAndDealStates(txCount: Int, externalId: String? = null, participants: List = emptyList(), linearString: String = "", linearNumber: Long = 0L, linearBoolean: Boolean = false, - linearTimestamp: Instant = now()): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val issuerKey = defaultNotary.keyPair - val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) - val transactions: List = (1..numberToCreate).map { - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - // Issue a Linear state - addOutputState(DummyLinearContract.State( + linearTimestamp: Instant = now()): Vault { + return fillWithTestStates(txCount, 2, participants) { participantsToUse, _, stateIndex -> + when (stateIndex) { + 0 -> DummyLinearContract.State( linearId = UniqueIdentifier(externalId), - participants = participants.plus(me), + participants = participantsToUse, linearString = linearString, linearNumber = linearNumber, linearBoolean = linearBoolean, - linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID) - // Issue a Deal state - addOutputState(DummyDealContract.State(ref = "test ref", participants = participants.plus(me)), DUMMY_DEAL_PROGRAM_ID) - addCommand(dummyCommand()) + linearTimestamp = linearTimestamp + ) + else -> DummyDealContract.State(ref = "test ref", participants = participantsToUse) } - return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) } - services.recordTransactions(transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - return Vault(states) } - @JvmOverloads - fun fillWithSomeTestCash(howMuch: Amount, - issuerServices: ServiceHub, - thisManyStates: Int, - issuedBy: PartyAndReference, - owner: AbstractParty? = null, - rng: Random? = null, - statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT) = fillWithSomeTestCash(howMuch, issuerServices, thisManyStates, thisManyStates, issuedBy, owner, rng, statesToRecord) - /** * Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them * to the vault. This is intended for unit tests. By default the cash is owned by the legal @@ -196,14 +156,15 @@ class VaultFiller @JvmOverloads constructor( * @param issuerServices service hub of the issuer node, which will be used to sign the transaction. * @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!). */ + @JvmOverloads fun fillWithSomeTestCash(howMuch: Amount, issuerServices: ServiceHub, atLeastThisManyStates: Int, - atMostThisManyStates: Int, issuedBy: PartyAndReference, owner: AbstractParty? = null, rng: Random? = null, - statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault { + statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT, + atMostThisManyStates: Int = atLeastThisManyStates): Vault { val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory()) // We will allocate one state to one transaction, for simplicities sake. val cash = Cash() @@ -212,39 +173,46 @@ class VaultFiller @JvmOverloads constructor( cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary) return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) } - services.recordTransactions(statesToRecord, transactions) - // Get all the StateRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) + return recordTransactions(transactions, statesToRecord) } /** * Records a dummy state in the Vault (useful for creating random states when testing vault queries) */ - fun fillWithDummyState(participants: List = listOf(services.myInfo.singleIdentity())) : Vault { - val outputState = TransactionState( - data = DummyState(Random().nextInt(), participants = participants), - contract = DummyContract.PROGRAM_ID, - notary = defaultNotary.party - ) - val participantKeys : List = participants.map { it.owningKey } - val builder = TransactionBuilder() - .addOutputState(outputState) - .addCommand(DummyCommandData, participantKeys) - val stxn = services.signInitialTransaction(builder) - services.recordTransactions(stxn) - return Vault(setOf(stxn.tx.outRef(0))) + fun fillWithDummyState(participants: List = listOf(services.myInfo.singleIdentity())): Vault { + return fillWithTestStates(participants = participants) { participantsToUse, _, _ -> + DummyState(Random().nextInt(), participants = participantsToUse) + } } - /** - * Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey. - */ - fun generateCommoditiesIssue(tx: TransactionBuilder, amount: Amount>, owner: AbstractParty, notary: Party) - = OnLedgerAsset.generateIssue(tx, TransactionState(CommodityState(amount, owner), Obligation.PROGRAM_ID, notary), Obligation.Commands.Issue()) - + fun fillWithTestStates(txCount: Int = 1, + statesPerTx: Int = 1, + participants: List = emptyList(), + constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, + includeMe: Boolean = true, + services: ServiceHub = this.services, + genOutputState: (participantsToUse: List, txIndex: Int, stateIndex: Int) -> T): Vault { + val issuerKey = defaultNotary.keyPair + val signatureMetadata = SignatureMetadata( + services.myInfo.platformVersion, + Crypto.findSignatureScheme(issuerKey.public).schemeNumberID + ) + val participantsToUse = if (includeMe) { + participants + AnonymousParty(this.services.myInfo.chooseIdentity().owningKey) + } else { + participants + } + val transactions = Array(txCount) { txIndex -> + val builder = TransactionBuilder(notary = defaultNotary.party) + repeat(statesPerTx) { stateIndex -> + builder.addOutputState(genOutputState(participantsToUse, txIndex, stateIndex), constraint) + } + builder.addCommand(dummyCommand()) + services.signInitialTransaction(builder).withAdditionalSignature(issuerKey, signatureMetadata) + } + val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE + return recordTransactions(transactions.asList(), statesToRecord) + } /** * @@ -257,13 +225,16 @@ class VaultFiller @JvmOverloads constructor( val me = AnonymousParty(myKey) val issuance = TransactionBuilder(null as Party?) - generateCommoditiesIssue(issuance, Amount(amount.quantity, Issued(issuedBy, amount.token)), me, altNotary) + OnLedgerAsset.generateIssue( + issuance, + TransactionState(CommodityState(Amount(amount.quantity, Issued(issuedBy, amount.token)), me), Obligation.PROGRAM_ID, altNotary), + Obligation.Commands.Issue() + ) val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) - services.recordTransactions(transaction) - return Vault(setOf(transaction.tx.outRef(0))) + return recordTransactions(listOf(transaction)) } - private fun consume(states: List>) { + fun consumeStates(states: Iterable>) { // Create a txn consuming different contract types states.forEach { val builder = TransactionBuilder(notary = altNotary).apply { @@ -300,10 +271,11 @@ class VaultFiller @JvmOverloads constructor( } } - fun consumeDeals(dealStates: List>) = consume(dealStates) - fun consumeLinearStates(linearStates: List>) = consume(linearStates) + fun consumeDeals(dealStates: List>) = consumeStates(dealStates) + fun consumeLinearStates(linearStates: List>) = consumeStates(linearStates) fun evolveLinearStates(linearStates: List>) = consumeAndProduce(linearStates) fun evolveLinearState(linearState: StateAndRef): StateAndRef = consumeAndProduce(linearState) + /** * Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios, * where nodes have a default identity. @@ -319,6 +291,16 @@ class VaultFiller @JvmOverloads constructor( services.recordTransactions(spendTx) return update.getOrThrow(Duration.ofSeconds(3)) } + + private fun recordTransactions(transactions: Iterable, + statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault { + services.recordTransactions(statesToRecord, transactions) + // Get all the StateAndRefs of all the generated transactions. + val states = transactions.flatMap { stx -> + stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } + } + return Vault(states) + } } @@ -344,4 +326,3 @@ data class CommodityState( override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner)) } -