diff --git a/core-deterministic/build.gradle b/core-deterministic/build.gradle index 71ba8d4efa..48dac3afd0 100644 --- a/core-deterministic/build.gradle +++ b/core-deterministic/build.gradle @@ -23,7 +23,10 @@ def javaHome = System.getProperty('java.home') def jarBaseName = "corda-${project.name}".toString() configurations { - deterministicLibraries.extendsFrom api + deterministicLibraries { + canBeConsumed = false + extendsFrom api + } deterministicArtifacts.extendsFrom deterministicLibraries } @@ -59,7 +62,7 @@ def originalJar = coreJarTask.map { it.outputs.files.singleFile } def patchCore = tasks.register('patchCore', Zip) { dependsOn coreJarTask - destinationDirectory = file("$buildDir/source-libs") + destinationDirectory = layout.buildDirectory.dir('source-libs') metadataCharset 'UTF-8' archiveClassifier = 'transient' archiveExtension = 'jar' @@ -169,7 +172,7 @@ def determinise = tasks.register('determinise', ProGuardTask) { def checkDeterminism = tasks.register('checkDeterminism', ProGuardTask) def metafix = tasks.register('metafix', MetaFixerTask) { - outputDir file("$buildDir/libs") + outputDir = layout.buildDirectory.dir('libs') jars determinise suffix "" diff --git a/core-deterministic/src/main/kotlin/net/corda/core/serialization/SerializationFactory.kt b/core-deterministic/src/main/kotlin/net/corda/core/serialization/SerializationFactory.kt index 69ddb8887b..2d6d8d3b09 100644 --- a/core-deterministic/src/main/kotlin/net/corda/core/serialization/SerializationFactory.kt +++ b/core-deterministic/src/main/kotlin/net/corda/core/serialization/SerializationFactory.kt @@ -55,12 +55,16 @@ abstract class SerializationFactory { * Change the current context inside the block to that supplied. */ fun withCurrentContext(context: SerializationContext?, block: () -> T): T { - val priorContext = _currentContext - if (context != null) _currentContext = context - try { - return block() - } finally { - if (context != null) _currentContext = priorContext + return if (context == null) { + block() + } else { + val priorContext = _currentContext + _currentContext = context + try { + block() + } finally { + _currentContext = priorContext + } } } diff --git a/core-deterministic/testing/data/build.gradle b/core-deterministic/testing/data/build.gradle index ab3acd9249..0141dc3c61 100644 --- a/core-deterministic/testing/data/build.gradle +++ b/core-deterministic/testing/data/build.gradle @@ -3,7 +3,9 @@ plugins { } configurations { - testData + testData { + canBeResolved = false + } } dependencies { diff --git a/core-deterministic/testing/verifier/build.gradle b/core-deterministic/testing/verifier/build.gradle index 774592c6a4..334191cb9f 100644 --- a/core-deterministic/testing/verifier/build.gradle +++ b/core-deterministic/testing/verifier/build.gradle @@ -9,7 +9,12 @@ apply from: "${rootProject.projectDir}/deterministic.gradle" description 'Test utilities for deterministic contract verification' configurations { - deterministicArtifacts + deterministicArtifacts { + canBeResolved = false + } + + // Compile against the deterministic artifacts to ensure that we use only the deterministic API subset. + compileOnly.extendsFrom deterministicArtifacts runtimeArtifacts.extendsFrom api } @@ -20,8 +25,6 @@ dependencies { runtimeArtifacts project(':serialization') runtimeArtifacts project(':core') - // Compile against the deterministic artifacts to ensure that we use only the deterministic API subset. - compileOnly configurations.deterministicArtifacts api "junit:junit:$junit_version" runtimeOnly "org.junit.vintage:junit-vintage-engine:$junit_vintage_version" } diff --git a/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/TransactionVerificationRequest.kt b/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/TransactionVerificationRequest.kt index c259f2791c..3c9fde9c06 100644 --- a/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/TransactionVerificationRequest.kt +++ b/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/TransactionVerificationRequest.kt @@ -13,7 +13,7 @@ import net.corda.core.transactions.WireTransaction @Suppress("MemberVisibilityCanBePrivate") //TODO the use of deprecated toLedgerTransaction need to be revisited as resolveContractAttachment requires attachments of the transactions which created input states... -//TODO ...to check contract version non downgrade rule, curretly dummy Attachment if not fund is used which sets contract version to '1' +//TODO ...to check contract version non downgrade rule, currently dummy Attachment if not fund is used which sets contract version to '1' @CordaSerializable class TransactionVerificationRequest(val wtxToVerify: SerializedBytes, val dependencies: Array>, diff --git a/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderSerializationTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderSerializationTests.kt index 4ca58d6b46..63f5461e46 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderSerializationTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/transactions/AttachmentsClassLoaderSerializationTests.kt @@ -55,7 +55,8 @@ class AttachmentsClassLoaderSerializationTests { arrayOf(isolatedId, att1, att2).map { storage.openAttachment(it)!! }, testNetworkParameters(), SecureHash.zeroHash, - { attachmentTrustCalculator.calculate(it) }, attachmentsClassLoaderCache = null) { classLoader -> + { attachmentTrustCalculator.calculate(it) }, attachmentsClassLoaderCache = null) { serializationContext -> + val classLoader = serializationContext.deserializationClassLoader val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, classLoader) val contract = contractClass.getDeclaredConstructor().newInstance() as Contract assertEquals("helloworld", contract.declaredField("magicString").value) diff --git a/core/src/main/kotlin/net/corda/core/crypto/internal/DigestAlgorithmFactory.kt b/core/src/main/kotlin/net/corda/core/crypto/internal/DigestAlgorithmFactory.kt index 532b95f4c1..892506aa76 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/internal/DigestAlgorithmFactory.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/internal/DigestAlgorithmFactory.kt @@ -28,9 +28,7 @@ sealed class DigestAlgorithmFactory { } private class CustomAlgorithmFactory(className: String) : DigestAlgorithmFactory() { - val constructor: Constructor = javaClass - .classLoader - .loadClass(className) + val constructor: Constructor = Class.forName(className, false, javaClass.classLoader) .asSubclass(DigestAlgorithm::class.java) .getConstructor() override val algorithm: String = constructor.newInstance().algorithm diff --git a/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt b/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt index 32ae2608d8..5ead87ca59 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ClassLoadingUtils.kt @@ -23,7 +23,7 @@ import net.corda.core.serialization.internal.AttachmentURLStreamHandlerFactory.a fun createInstancesOfClassesImplementing(classloader: ClassLoader, clazz: Class, classVersionRange: IntRange? = null): Set { return getNamesOfClassesImplementing(classloader, clazz, classVersionRange) - .map { classloader.loadClass(it).asSubclass(clazz) } + .map { Class.forName(it, false, classloader).asSubclass(clazz) } .mapTo(LinkedHashSet()) { it.kotlin.objectOrNewInstance() } } diff --git a/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt b/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt index 4aa68e650b..9dd32fb54b 100644 --- a/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt @@ -54,7 +54,7 @@ fun combinedHash(components: Iterable, digestService: DigestService) components.forEach { stream.write(it.bytes) } - return digestService.hash(stream.toByteArray()); + return digestService.hash(stream.toByteArray()) } /** diff --git a/core/src/main/kotlin/net/corda/core/internal/TransactionVerifierServiceInternal.kt b/core/src/main/kotlin/net/corda/core/internal/TransactionVerifierServiceInternal.kt index e7ca576618..2a8c13036e 100644 --- a/core/src/main/kotlin/net/corda/core/internal/TransactionVerifierServiceInternal.kt +++ b/core/src/main/kotlin/net/corda/core/internal/TransactionVerifierServiceInternal.kt @@ -3,14 +3,39 @@ package net.corda.core.internal import net.corda.core.DeleteForDJVM import net.corda.core.KeepForDJVM import net.corda.core.concurrent.CordaFuture -import net.corda.core.contracts.* +import net.corda.core.contracts.Attachment +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractAttachment +import net.corda.core.contracts.ContractClassName +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.HashAttachmentConstraint +import net.corda.core.contracts.SignatureAttachmentConstraint +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.TransactionVerificationException.ConflictingAttachmentsRejection +import net.corda.core.contracts.TransactionVerificationException.ConstraintPropagationRejection +import net.corda.core.contracts.TransactionVerificationException.ContractCreationError +import net.corda.core.contracts.TransactionVerificationException.ContractRejection +import net.corda.core.contracts.TransactionVerificationException.ContractConstraintRejection +import net.corda.core.contracts.TransactionVerificationException.Direction +import net.corda.core.contracts.TransactionVerificationException.DuplicateAttachmentsRejection +import net.corda.core.contracts.TransactionVerificationException.InvalidConstraintRejection +import net.corda.core.contracts.TransactionVerificationException.MissingAttachmentRejection +import net.corda.core.contracts.TransactionVerificationException.NotaryChangeInWrongTransactionType import net.corda.core.contracts.TransactionVerificationException.TransactionContractConflictException +import net.corda.core.contracts.TransactionVerificationException.TransactionDuplicateEncumbranceException +import net.corda.core.contracts.TransactionVerificationException.TransactionMissingEncumbranceException +import net.corda.core.contracts.TransactionVerificationException.TransactionNonMatchingEncumbranceException +import net.corda.core.contracts.TransactionVerificationException.TransactionNotaryMismatchEncumbranceException +import net.corda.core.contracts.TransactionVerificationException.TransactionRequiredContractUnspecifiedException import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.SecureHash import net.corda.core.internal.rules.StateContractValidationEnforcementRule import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.contextLogger import java.util.function.Function +import java.util.function.Supplier @DeleteForDJVM interface TransactionVerifierServiceInternal { @@ -27,8 +52,8 @@ fun LedgerTransaction.prepareVerify(attachments: List) = internalPre * wrong object instance. This class helps avoid that. */ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionClassLoader: ClassLoader) { - private val inputStates: List> = ltx.inputs.map { it.state } - private val allStates: List> = inputStates + ltx.references.map { it.state } + ltx.outputs + private val inputStates: List> = ltx.inputs.map(StateAndRef::state) + private val allStates: List> = inputStates + ltx.references.map(StateAndRef::state) + ltx.outputs companion object { val logger = contextLogger() @@ -39,7 +64,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla * * It is a critical piece of the security of the platform. * - * @throws TransactionVerificationException + * @throws net.corda.core.contracts.TransactionVerificationException */ fun verify() { // checkNoNotaryChange and checkEncumbrancesValid are called here, and not in the c'tor, as they need access to the "outputs" @@ -82,10 +107,10 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla * This is an important piece of the security of transactions. */ private fun getUniqueContractAttachmentsByContract(): Map { - val contractClasses = allStates.map { it.contract }.toSet() + val contractClasses = allStates.mapTo(LinkedHashSet(), TransactionState<*>::contract) // Check that there are no duplicate attachments added. - if (ltx.attachments.size != ltx.attachments.toSet().size) throw TransactionVerificationException.DuplicateAttachmentsRejection(ltx.id, ltx.attachments.groupBy { it }.filterValues { it.size > 1 }.keys.first()) + if (ltx.attachments.size != ltx.attachments.toSet().size) throw DuplicateAttachmentsRejection(ltx.id, ltx.attachments.groupBy { it }.filterValues { it.size > 1 }.keys.first()) // For each attachment this finds all the relevant state contracts that it provides. // And then maps them to the attachment. @@ -103,12 +128,12 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla .groupBy { it.first } // Group by contract. .filter { (_, attachments) -> attachments.size > 1 } // And only keep contracts that are in multiple attachments. It's guaranteed that attachments were unique by a previous check. .keys.firstOrNull() // keep the first one - if any - to throw a meaningful exception. - if (contractWithMultipleAttachments != null) throw TransactionVerificationException.ConflictingAttachmentsRejection(ltx.id, contractWithMultipleAttachments) + if (contractWithMultipleAttachments != null) throw ConflictingAttachmentsRejection(ltx.id, contractWithMultipleAttachments) val result = contractAttachmentsPerContract.toMap() // Check that there is an attachment for each contract. - if (result.keys != contractClasses) throw TransactionVerificationException.MissingAttachmentRejection(ltx.id, contractClasses.minus(result.keys).first()) + if (result.keys != contractClasses) throw MissingAttachmentRejection(ltx.id, contractClasses.minus(result.keys).first()) return result } @@ -124,7 +149,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla if (ltx.notary != null && (ltx.inputs.isNotEmpty() || ltx.references.isNotEmpty())) { ltx.outputs.forEach { if (it.notary != ltx.notary) { - throw TransactionVerificationException.NotaryChangeInWrongTransactionType(ltx.id, ltx.notary, it.notary) + throw NotaryChangeInWrongTransactionType(ltx.id, ltx.notary, it.notary) } } } @@ -156,10 +181,10 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla it.ref.txhash == ref.txhash && it.ref.index == state.encumbrance } if (!encumbranceStateExists) { - throw TransactionVerificationException.TransactionMissingEncumbranceException( + throw TransactionMissingEncumbranceException( ltx.id, state.encumbrance!!, - TransactionVerificationException.Direction.INPUT + Direction.INPUT ) } } @@ -194,15 +219,15 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla statesAndEncumbrance.forEach { (statePosition, encumbrance) -> // Check it does not refer to itself. if (statePosition == encumbrance || encumbrance >= ltx.outputs.size) { - throw TransactionVerificationException.TransactionMissingEncumbranceException( + throw TransactionMissingEncumbranceException( ltx.id, encumbrance, - TransactionVerificationException.Direction.OUTPUT + Direction.OUTPUT ) } else { encumberedSet.add(statePosition) // Guaranteed to have unique elements. if (!encumbranceSet.add(encumbrance)) { - throw TransactionVerificationException.TransactionDuplicateEncumbranceException(ltx.id, encumbrance) + throw TransactionDuplicateEncumbranceException(ltx.id, encumbrance) } } } @@ -211,7 +236,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla val symmetricDifference = (encumberedSet union encumbranceSet).subtract(encumberedSet intersect encumbranceSet) if (symmetricDifference.isNotEmpty()) { // At least one encumbered state is not in the [encumbranceSet] and vice versa. - throw TransactionVerificationException.TransactionNonMatchingEncumbranceException(ltx.id, symmetricDifference) + throw TransactionNonMatchingEncumbranceException(ltx.id, symmetricDifference) } } @@ -235,7 +260,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla if (indicesAlreadyChecked.add(index)) { val encumbranceIndex = ltx.outputs[index].encumbrance!! if (ltx.outputs[index].notary != ltx.outputs[encumbranceIndex].notary) { - throw TransactionVerificationException.TransactionNotaryMismatchEncumbranceException( + throw TransactionNotaryMismatchEncumbranceException( ltx.id, index, encumbranceIndex, @@ -263,7 +288,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla val shouldEnforce = StateContractValidationEnforcementRule.shouldEnforce(state.data) val requiredContractClassName = state.data.requiredContractClassName - ?: if (shouldEnforce) throw TransactionVerificationException.TransactionRequiredContractUnspecifiedException(ltx.id, state) else return + ?: if (shouldEnforce) throw TransactionRequiredContractUnspecifiedException(ltx.id, state) else return if (state.contract != requiredContractClassName) if (shouldEnforce) { @@ -310,7 +335,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla outputConstraints.forEach { outputConstraint -> inputConstraints.forEach { inputConstraint -> if (!(outputConstraint.canBeTransitionedFrom(inputConstraint, contractAttachment))) { - throw TransactionVerificationException.ConstraintPropagationRejection( + throw ConstraintPropagationRejection( ltx.id, contractClassName, inputConstraint, @@ -331,7 +356,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla @Suppress("NestedBlockDepth", "MagicNumber") private fun verifyConstraints(contractAttachmentsByContract: Map) { // For each contract/constraint pair check that the relevant attachment is valid. - allStates.map { it.contract to it.constraint }.toSet().forEach { (contract, constraint) -> + allStates.mapTo(LinkedHashSet()) { it.contract to it.constraint }.forEach { (contract, constraint) -> if (constraint is SignatureAttachmentConstraint) { /** * Support for signature constraints has been added on @@ -346,9 +371,9 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla "Signature constraints" ) val constraintKey = constraint.key - if (ltx.networkParameters?.minimumPlatformVersion ?: 1 >= PlatformVersionSwitches.LIMIT_KEYS_IN_SIGNATURE_CONSTRAINTS) { + if ((ltx.networkParameters?.minimumPlatformVersion ?: 1) >= PlatformVersionSwitches.LIMIT_KEYS_IN_SIGNATURE_CONSTRAINTS) { if (constraintKey is CompositeKey && constraintKey.leafKeys.size > MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT) { - throw TransactionVerificationException.InvalidConstraintRejection(ltx.id, contract, + throw InvalidConstraintRejection(ltx.id, contract, "Signature constraint contains composite key with ${constraintKey.leafKeys.size} leaf keys, " + "which is more than the maximum allowed number of keys " + "($MAX_NUMBER_OF_KEYS_IN_SIGNATURE_CONSTRAINT).") @@ -364,7 +389,7 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla if (HashAttachmentConstraint.disableHashConstraints && constraint is HashAttachmentConstraint) logger.warnOnce("Skipping hash constraints verification.") else if (!constraint.isSatisfiedBy(constraintAttachment)) - throw TransactionVerificationException.ContractConstraintRejection(ltx.id, contract) + throw ContractConstraintRejection(ltx.id, contract) } } @@ -374,29 +399,12 @@ abstract class Verifier(val ltx: LedgerTransaction, protected val transactionCla abstract fun verifyContracts() } -class BasicVerifier(ltx: LedgerTransaction, transactionClassLoader: ClassLoader) : Verifier(ltx, transactionClassLoader) { - /** - * Check the transaction is contract-valid by running the verify() for each input and output state contract. - * If any contract fails to verify, the whole transaction is considered to be invalid. - * - * Note: Reference states are not verified. - */ - override fun verifyContracts() { - try { - ContractVerifier(transactionClassLoader).apply(ltx) - } catch (e: TransactionVerificationException.ContractRejection) { - logger.error("Error validating transaction ${ltx.id}.", e.cause) - throw e - } - } -} - /** * Verify all of the contracts on the given [LedgerTransaction]. */ @Suppress("TooGenericExceptionCaught") @KeepForDJVM -class ContractVerifier(private val transactionClassLoader: ClassLoader) : Function { +class ContractVerifier(private val transactionClassLoader: ClassLoader) : Function, Unit> { // This constructor is used inside the DJVM's sandbox. @Suppress("unused") constructor() : this(ClassLoader.getSystemClassLoader()) @@ -406,34 +414,45 @@ class ContractVerifier(private val transactionClassLoader: ClassLoader) : Functi return try { Class.forName(contractClassName, false, transactionClassLoader).asSubclass(Contract::class.java) } catch (e: Exception) { - throw TransactionVerificationException.ContractCreationError(id, contractClassName, e) + throw ContractCreationError(id, contractClassName, e) } } - override fun apply(ltx: LedgerTransaction) { - val contractClassNames = (ltx.inputs.map(StateAndRef::state) + ltx.outputs) + private fun generateContracts(ltx: LedgerTransaction): List { + return (ltx.inputs.map(StateAndRef::state) + ltx.outputs) .mapTo(LinkedHashSet(), TransactionState<*>::contract) - - contractClassNames.associateBy( - { it }, { createContractClass(ltx.id, it) } - ).map { (contractClassName, contractClass) -> - try { - /** - * This function must execute with the DJVM's sandbox, which does not - * permit user code to invoke [java.lang.Class.getDeclaredConstructor]. - * - * [Class.newInstance] is deprecated as of Java 9. - */ - @Suppress("deprecation") - contractClass.newInstance() - } catch (e: Exception) { - throw TransactionVerificationException.ContractCreationError(ltx.id, contractClassName, e) + .map { contractClassName -> + createContractClass(ltx.id, contractClassName) + }.map { contractClass -> + try { + /** + * This function must execute within the DJVM's sandbox, which does not + * permit user code to invoke [java.lang.reflect.Constructor.newInstance]. + * (This would be fixable now, provided the constructor is public.) + * + * [Class.newInstance] is deprecated as of Java 9. + */ + @Suppress("deprecation") + contractClass.newInstance() + } catch (e: Exception) { + throw ContractCreationError(ltx.id, contractClass.name, e) + } } + } + + override fun apply(transactionFactory: Supplier) { + var firstLtx: LedgerTransaction? = null + + transactionFactory.get().let { ltx -> + firstLtx = ltx + generateContracts(ltx) }.forEach { contract -> + val ltx = firstLtx ?: transactionFactory.get() + firstLtx = null try { contract.verify(ltx) } catch (e: Exception) { - throw TransactionVerificationException.ContractRejection(ltx.id, contract, e) + throw ContractRejection(ltx.id, contract, e) } } } diff --git a/core/src/main/kotlin/net/corda/core/node/NetworkParameters.kt b/core/src/main/kotlin/net/corda/core/node/NetworkParameters.kt index 4c85a4dc07..50fe5eb258 100644 --- a/core/src/main/kotlin/net/corda/core/node/NetworkParameters.kt +++ b/core/src/main/kotlin/net/corda/core/node/NetworkParameters.kt @@ -13,6 +13,8 @@ import net.corda.core.utilities.days import java.security.PublicKey import java.time.Duration import java.time.Instant +import java.util.Collections.unmodifiableList +import java.util.Collections.unmodifiableMap // DOCSTART 1 /** @@ -166,6 +168,38 @@ data class NetworkParameters( epoch=$epoch }""" } + + fun toImmutable(): NetworkParameters { + return NetworkParameters( + minimumPlatformVersion = minimumPlatformVersion, + notaries = unmodifiable(notaries), + maxMessageSize = maxMessageSize, + maxTransactionSize = maxTransactionSize, + modifiedTime = modifiedTime, + epoch = epoch, + whitelistedContractImplementations = unmodifiable(whitelistedContractImplementations) { entry -> + unmodifiableList(entry.value) + }, + eventHorizon = eventHorizon, + packageOwnership = unmodifiable(packageOwnership) + ) + } +} + +private fun unmodifiable(list: List): List { + return if (list.isEmpty()) { + emptyList() + } else { + unmodifiableList(list) + } +} + +private inline fun unmodifiable(map: Map, transform: (Map.Entry) -> V = Map.Entry::value): Map { + return if (map.isEmpty()) { + emptyMap() + } else { + unmodifiableMap(map.mapValues(transform)) + } } /** diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt index bcae581e66..77289d8c8e 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -13,6 +13,10 @@ import net.corda.core.utilities.sequence import java.io.NotSerializableException import java.sql.Blob +const val DESERIALIZATION_CACHE_PROPERTY = "DESERIALIZATION_CACHE" +const val AMQP_ENVELOPE_CACHE_PROPERTY = "AMQP_ENVELOPE_CACHE" +const val AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY = 256 + data class ObjectWithCompatibleContext(val obj: T, val context: SerializationContext) /** @@ -65,12 +69,16 @@ abstract class SerializationFactory { * Change the current context inside the block to that supplied. */ fun withCurrentContext(context: SerializationContext?, block: () -> T): T { - val priorContext = _currentContext.get() - if (context != null) _currentContext.set(context) - try { - return block() - } finally { - if (context != null) _currentContext.set(priorContext) + return if (context == null) { + block() + } else { + val priorContext = _currentContext.get() + _currentContext.set(context) + try { + block() + } finally { + _currentContext.set(priorContext) + } } } diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt index 49d7a48508..0a287a7f7d 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt @@ -20,6 +20,9 @@ import net.corda.core.internal.createInstancesOfClassesImplementing import net.corda.core.internal.createSimpleCache import net.corda.core.internal.toSynchronised import net.corda.core.node.NetworkParameters +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_PROPERTY +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationFactory @@ -39,7 +42,9 @@ import java.net.URLStreamHandler import java.net.URLStreamHandlerFactory import java.security.MessageDigest import java.security.Permission -import java.util.* +import java.util.Locale +import java.util.ServiceLoader +import java.util.WeakHashMap import java.util.function.Function /** @@ -67,12 +72,15 @@ class AttachmentsClassLoader(attachments: List, init { // Apply our own URLStreamHandlerFactory to resolve attachments setOrDecorateURLStreamHandlerFactory() + + // Allow AttachmentsClassLoader to be used concurrently. + registerAsParallelCapable() } // Jolokia and Json-simple are dependencies that were bundled by mistake within contract jars. // In the AttachmentsClassLoader we just block any class in those 2 packages. private val ignoreDirectories = listOf("org/jolokia/", "org/json/simple/") - private val ignorePackages = ignoreDirectories.map { it.replace("/", ".") } + private val ignorePackages = ignoreDirectories.map { it.replace('/', '.') } /** * Apply our custom factory either directly, if `URL.setURLStreamHandlerFactory` has not been called yet, @@ -176,10 +184,10 @@ class AttachmentsClassLoader(attachments: List, // TODO - investigate potential exploits. private fun shouldCheckForNoOverlap(path: String, targetPlatformVersion: Int): Boolean { require(path.toLowerCase() == path) - require(!path.contains("\\")) + require(!path.contains('\\')) return when { - path.endsWith("/") -> false // Directories (packages) can overlap. + path.endsWith('/') -> false // Directories (packages) can overlap. targetPlatformVersion < PlatformVersionSwitches.IGNORE_JOLOKIA_JSON_SIMPLE_IN_CORDAPPS && ignoreDirectories.any { path.startsWith(it) } -> false // Ignore jolokia and json-simple for old cordapps. path.endsWith(".class") -> true // All class files need to be unique. @@ -219,7 +227,7 @@ class AttachmentsClassLoader(attachments: List, // attacks on externally connected systems that only consider type names, we allow people to formally // claim their parts of the Java package namespace via registration with the zone operator. - val classLoaderEntries = mutableMapOf() + val classLoaderEntries = mutableMapOf() val ctx = AttachmentHashContext(sampleTxId) for (attachment in attachments) { // We may have been given an attachment loaded from the database in which case, important info like @@ -238,7 +246,7 @@ class AttachmentsClassLoader(attachments: List, // signed by the owners of the packages, even if it's not. We'd eventually discover that fact // when trying to read the class file to use it, but if we'd made any decisions based on // perceived correctness of the signatures or package ownership already, that would be too late. - attachment.openAsJAR().use { JarSignatureCollector.collectSigners(it) } + attachment.openAsJAR().use(JarSignatureCollector::collectSigners) } // Now open it again to compute the overlap and package ownership data. @@ -309,11 +317,11 @@ class AttachmentsClassLoader(attachments: List, * Required to prevent classes that were excluded from the no-overlap check from being loaded by contract code. * As it can lead to non-determinism. */ - override fun loadClass(name: String?): Class<*> { - if (ignorePackages.any { name!!.startsWith(it) }) { + override fun loadClass(name: String, resolve: Boolean): Class<*>? { + if (ignorePackages.any { name.startsWith(it) }) { throw ClassNotFoundException(name) } - return super.loadClass(name) + return super.loadClass(name, resolve) } } @@ -323,7 +331,7 @@ class AttachmentsClassLoader(attachments: List, */ @VisibleForTesting object AttachmentsClassLoaderBuilder { - const val CACHE_SIZE = 16 + private const val CACHE_SIZE = 16 private val fallBackCache: AttachmentsClassLoaderCache = AttachmentsClassLoaderSimpleCacheImpl(CACHE_SIZE) @@ -339,13 +347,13 @@ object AttachmentsClassLoaderBuilder { isAttachmentTrusted: (Attachment) -> Boolean, parent: ClassLoader = ClassLoader.getSystemClassLoader(), attachmentsClassLoaderCache: AttachmentsClassLoaderCache?, - block: (ClassLoader) -> T): T { - val attachmentIds = attachments.map(Attachment::id).toSet() + block: (SerializationContext) -> T): T { + val attachmentIds = attachments.mapTo(LinkedHashSet(), Attachment::id) val cache = attachmentsClassLoaderCache ?: fallBackCache - val serializationContext = cache.computeIfAbsent(AttachmentsClassLoaderKey(attachmentIds, params), Function { + val serializationContext = cache.computeIfAbsent(AttachmentsClassLoaderKey(attachmentIds, params), Function { key -> // Create classloader and load serializers, whitelisted classes - val transactionClassLoader = AttachmentsClassLoader(attachments, params, txId, isAttachmentTrusted, parent) + val transactionClassLoader = AttachmentsClassLoader(attachments, key.params, txId, isAttachmentTrusted, parent) val serializers = try { createInstancesOfClassesImplementing(transactionClassLoader, SerializationCustomSerializer::class.java, JDK1_2_CLASS_FILE_FORMAT_MAJOR_VERSION..JDK8_CLASS_FILE_FORMAT_MAJOR_VERSION) @@ -366,11 +374,16 @@ object AttachmentsClassLoaderBuilder { .withWhitelist(whitelistedClasses) .withCustomSerializers(serializers) .withoutCarpenter() - }) + }).withProperties(mapOf( + // Duplicate the SerializationContext from the cache and give + // it these extra properties, just for this transaction. + AMQP_ENVELOPE_CACHE_PROPERTY to HashMap(AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY), + DESERIALIZATION_CACHE_PROPERTY to HashMap() + )) // Deserialize all relevant classes in the transaction classloader. return SerializationFactory.defaultFactory.withCurrentContext(serializationContext) { - block(serializationContext.deserializationClassLoader) + block(serializationContext) } } } @@ -495,4 +508,4 @@ private class AttachmentURLConnection(url: URL, private val attachment: Attachme override fun connect() { connected = true } -} \ No newline at end of file +} diff --git a/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt b/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt index f1f53f90b8..911c67d0da 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/ContractUpgradeTransactions.kt @@ -145,7 +145,7 @@ data class ContractUpgradeWireTransaction( private fun upgradedContract(className: ContractClassName, classLoader: ClassLoader): UpgradedContract = try { @Suppress("UNCHECKED_CAST") - classLoader.loadClass(className).asSubclass(UpgradedContract::class.java).getDeclaredConstructor().newInstance() as UpgradedContract + Class.forName(className, false, classLoader).asSubclass(UpgradedContract::class.java).getDeclaredConstructor().newInstance() as UpgradedContract } catch (e: Exception) { throw TransactionVerificationException.ContractCreationError(id, className, e) } @@ -166,9 +166,9 @@ data class ContractUpgradeWireTransaction( params, id, { (services as ServiceHubCoreInternal).attachmentTrustCalculator.calculate(it) }, - attachmentsClassLoaderCache = (services as ServiceHubCoreInternal).attachmentsClassLoaderCache) { transactionClassLoader -> + attachmentsClassLoaderCache = (services as ServiceHubCoreInternal).attachmentsClassLoaderCache) { serializationContext -> val resolvedInput = binaryInput.deserialize() - val upgradedContract = upgradedContract(upgradedContractClassName, transactionClassLoader) + val upgradedContract = upgradedContract(upgradedContractClassName, serializationContext.deserializationClassLoader) val outputState = calculateUpgradedState(resolvedInput, upgradedContract, upgradedAttachment) outputState.serialize() } @@ -311,8 +311,7 @@ private constructor( @CordaInternal internal fun loadUpgradedContract(upgradedContractClassName: ContractClassName, classLoader: ClassLoader): UpgradedContract { @Suppress("UNCHECKED_CAST") - return classLoader - .loadClass(upgradedContractClassName) + return Class.forName(upgradedContractClassName, false, classLoader) .asSubclass(Contract::class.java) .getConstructor() .newInstance() as UpgradedContract diff --git a/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt index 717b9b5937..c7b00f60bd 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt @@ -18,21 +18,25 @@ import net.corda.core.crypto.DigestService import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowLogic import net.corda.core.identity.Party -import net.corda.core.internal.BasicVerifier +import net.corda.core.internal.ContractVerifier import net.corda.core.internal.SerializedStateAndRef import net.corda.core.internal.Verifier import net.corda.core.internal.castIfPossible import net.corda.core.internal.deserialiseCommands import net.corda.core.internal.deserialiseComponentGroup +import net.corda.core.internal.eagerDeserialise import net.corda.core.internal.isUploaderTrusted import net.corda.core.internal.uncheckedCast import net.corda.core.node.NetworkParameters import net.corda.core.serialization.DeprecatedConstructorForDeserialization +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.internal.AttachmentsClassLoaderCache import net.corda.core.serialization.internal.AttachmentsClassLoaderBuilder import net.corda.core.utilities.contextLogger import java.util.Collections.unmodifiableList import java.util.function.Predicate +import java.util.function.Supplier /** * A LedgerTransaction is derived from a [WireTransaction]. It is the result of doing the following operations: @@ -90,7 +94,7 @@ private constructor( private val serializedInputs: List?, private val serializedReferences: List?, private val isAttachmentTrusted: (Attachment) -> Boolean, - private val verifierFactory: (LedgerTransaction, ClassLoader) -> Verifier, + private val verifierFactory: (LedgerTransaction, SerializationContext) -> Verifier, private val attachmentsClassLoaderCache: AttachmentsClassLoaderCache?, val digestService: DigestService ) : FullTransaction() { @@ -100,22 +104,23 @@ private constructor( */ @DeprecatedConstructorForDeserialization(1) private constructor( - inputs: List>, - outputs: List>, - commands: List>, - attachments: List, - id: SecureHash, - notary: Party?, - timeWindow: TimeWindow?, - privacySalt: PrivacySalt, - networkParameters: NetworkParameters?, - references: List>, - componentGroups: List?, - serializedInputs: List?, - serializedReferences: List?, - isAttachmentTrusted: (Attachment) -> Boolean, - verifierFactory: (LedgerTransaction, ClassLoader) -> Verifier, - attachmentsClassLoaderCache: AttachmentsClassLoaderCache?) : this( + inputs: List>, + outputs: List>, + commands: List>, + attachments: List, + id: SecureHash, + notary: Party?, + timeWindow: TimeWindow?, + privacySalt: PrivacySalt, + networkParameters: NetworkParameters?, + references: List>, + componentGroups: List?, + serializedInputs: List?, + serializedReferences: List?, + isAttachmentTrusted: (Attachment) -> Boolean, + verifierFactory: (LedgerTransaction, SerializationContext) -> Verifier, + attachmentsClassLoaderCache: AttachmentsClassLoaderCache? + ) : this( inputs, outputs, commands, attachments, id, notary, timeWindow, privacySalt, networkParameters, references, componentGroups, serializedInputs, serializedReferences, isAttachmentTrusted, verifierFactory, attachmentsClassLoaderCache, DigestService.sha2_256) @@ -124,8 +129,8 @@ private constructor( companion object { private val logger = contextLogger() - private fun protect(list: List?): List? { - return list?.run { + private fun protect(list: List): List { + return list.run { if (isEmpty()) { emptyList() } else { @@ -134,6 +139,8 @@ private constructor( } } + private fun protectOrNull(list: List?): List? = list?.let(::protect) + @CordaInternal internal fun create( inputs: List>, @@ -164,9 +171,9 @@ private constructor( privacySalt = privacySalt, networkParameters = networkParameters, references = references, - componentGroups = protect(componentGroups), - serializedInputs = protect(serializedInputs), - serializedReferences = protect(serializedReferences), + componentGroups = protectOrNull(componentGroups), + serializedInputs = protectOrNull(serializedInputs), + serializedReferences = protectOrNull(serializedReferences), isAttachmentTrusted = isAttachmentTrusted, verifierFactory = ::BasicVerifier, attachmentsClassLoaderCache = attachmentsClassLoaderCache, @@ -176,10 +183,11 @@ private constructor( /** * This factory function will create an instance of [LedgerTransaction] - * that will be used inside the DJVM sandbox. + * that will be used for contract verification. See [BasicVerifier] and + * [DeterministicVerifier][net.corda.node.internal.djvm.DeterministicVerifier]. */ @CordaInternal - fun createForSandbox( + fun createForContractVerify( inputs: List>, outputs: List>, commands: List>, @@ -188,28 +196,31 @@ private constructor( notary: Party?, timeWindow: TimeWindow?, privacySalt: PrivacySalt, - networkParameters: NetworkParameters, + networkParameters: NetworkParameters?, references: List>, digestService: DigestService): LedgerTransaction { return LedgerTransaction( - inputs = inputs, - outputs = outputs, - commands = commands, - attachments = attachments, + inputs = protect(inputs), + outputs = protect(outputs), + commands = protect(commands), + attachments = protect(attachments), id = id, notary = notary, timeWindow = timeWindow, privacySalt = privacySalt, networkParameters = networkParameters, - references = references, + references = protect(references), componentGroups = null, serializedInputs = null, serializedReferences = null, isAttachmentTrusted = { true }, - verifierFactory = ::BasicVerifier, + verifierFactory = ::NoOpVerifier, attachmentsClassLoaderCache = null, digestService = digestService - ) + // This check accesses input states and must run on the LedgerTransaction + // instance that is verified, not on the outer LedgerTransaction shell. + // All states must also deserialize using the correct SerializationContext. + ).also(LedgerTransaction::checkBaseInvariants) } } @@ -251,11 +262,17 @@ private constructor( getParamsWithGoo(), id, isAttachmentTrusted = isAttachmentTrusted, - attachmentsClassLoaderCache = attachmentsClassLoaderCache) { transactionClassLoader -> - // Create a copy of the outer LedgerTransaction which deserializes all fields inside the [transactionClassLoader]. + attachmentsClassLoaderCache = attachmentsClassLoaderCache) { serializationContext -> + + // Legacy check - warns if the LedgerTransaction was created incorrectly. + checkLtxForVerification() + + // Create a copy of the outer LedgerTransaction which deserializes all fields using + // the serialization context (or its deserializationClassloader). // Only the copy will be used for verification, and the outer shell will be discarded. // This artifice is required to preserve backwards compatibility. - verifierFactory(createLtxForVerification(), transactionClassLoader) + // NOTE: The Verifier creates the copies of the LedgerTransaction object now. + verifierFactory(this, serializationContext) } } @@ -272,7 +289,7 @@ private constructor( * Node without changing either the wire format or any public APIs. */ @CordaInternal - fun specialise(alternateVerifier: (LedgerTransaction, ClassLoader) -> Verifier): LedgerTransaction = LedgerTransaction( + fun specialise(alternateVerifier: (LedgerTransaction, SerializationContext) -> Verifier): LedgerTransaction = LedgerTransaction( inputs = inputs, outputs = outputs, commands = commands, @@ -287,7 +304,11 @@ private constructor( serializedInputs = serializedInputs, serializedReferences = serializedReferences, isAttachmentTrusted = isAttachmentTrusted, - verifierFactory = alternateVerifier, + verifierFactory = if (verifierFactory == ::NoOpVerifier) { + throw IllegalStateException("Cannot specialise transaction while verifying contracts") + } else { + alternateVerifier + }, attachmentsClassLoaderCache = attachmentsClassLoaderCache, digestService = digestService ) @@ -319,58 +340,12 @@ private constructor( } /** - * Create the [LedgerTransaction] instance that will be used by contract verification. - * - * This method needs to run in the special transaction attachments classloader context. */ - private fun createLtxForVerification(): LedgerTransaction { - val serializedInputs = this.serializedInputs - val serializedReferences = this.serializedReferences - val componentGroups = this.componentGroups - - val transaction= if (serializedInputs != null && serializedReferences != null && componentGroups != null) { - // Deserialize all relevant classes in the transaction classloader. - val deserializedInputs = serializedInputs.map { it.toStateAndRef() } - val deserializedReferences = serializedReferences.map { it.toStateAndRef() } - val deserializedOutputs = deserialiseComponentGroup(componentGroups, TransactionState::class, ComponentGroupEnum.OUTPUTS_GROUP, forceDeserialize = true) - val deserializedCommands = deserialiseCommands(componentGroups, forceDeserialize = true, digestService = digestService) - val authenticatedDeserializedCommands = deserializedCommands.map { cmd -> - @Suppress("DEPRECATION") // Deprecated feature. - val parties = commands.find { it.value.javaClass.name == cmd.value.javaClass.name }!!.signingParties - CommandWithParties(cmd.signers, parties, cmd.value) - } - - LedgerTransaction( - inputs = deserializedInputs, - outputs = deserializedOutputs, - commands = authenticatedDeserializedCommands, - attachments = this.attachments, - id = this.id, - notary = this.notary, - timeWindow = this.timeWindow, - privacySalt = this.privacySalt, - networkParameters = this.networkParameters, - references = deserializedReferences, - componentGroups = componentGroups, - serializedInputs = serializedInputs, - serializedReferences = serializedReferences, - isAttachmentTrusted = isAttachmentTrusted, - verifierFactory = verifierFactory, - attachmentsClassLoaderCache = attachmentsClassLoaderCache, - digestService = digestService - ) - } else { - // This branch is only present for backwards compatibility. + private fun checkLtxForVerification() { + if (serializedInputs == null || serializedReferences == null || componentGroups == null) { logger.warn("The LedgerTransaction should not be instantiated directly from client code. Please use WireTransaction.toLedgerTransaction." + "The result of the verify method might not be accurate.") - this } - - // This check accesses input states and must be run in this context. - // It must run on the instance that is verified, not on the outer LedgerTransaction shell. - transaction.checkBaseInvariants() - - return transaction } /** @@ -740,7 +715,7 @@ private constructor( componentGroups = null, serializedInputs = null, serializedReferences = null, - isAttachmentTrusted = { it.isUploaderTrusted() }, + isAttachmentTrusted = Attachment::isUploaderTrusted, verifierFactory = ::BasicVerifier, attachmentsClassLoaderCache = null ) @@ -770,7 +745,7 @@ private constructor( componentGroups = null, serializedInputs = null, serializedReferences = null, - isAttachmentTrusted = { it.isUploaderTrusted() }, + isAttachmentTrusted = Attachment::isUploaderTrusted, verifierFactory = ::BasicVerifier, attachmentsClassLoaderCache = null ) @@ -838,3 +813,89 @@ private constructor( ) } } + +/** + * This is the default [Verifier] that configures Corda + * to execute [Contract.verify(LedgerTransaction)]. + * + * THIS CLASS IS NOT PUBLIC API, AND IS DELIBERATELY PRIVATE! + */ +@CordaInternal +private class BasicVerifier( + ltx: LedgerTransaction, + private val serializationContext: SerializationContext +) : Verifier(ltx, serializationContext.deserializationClassLoader) { + + init { + // This is a sanity check: We should only instantiate this + // class from [LedgerTransaction.internalPrepareVerify]. + require(serializationContext === SerializationFactory.defaultFactory.currentContext) { + "BasicVerifier for TX ${ltx.id} created outside its SerializationContext" + } + + // Fetch these commands' signing parties from the database. + // Corda forbids database access during contract verification, + // and so we must load the commands here eagerly instead. + ltx.commands.eagerDeserialise() + } + + private fun createTransaction(): LedgerTransaction { + // Deserialize all relevant classes using the serializationContext. + return SerializationFactory.defaultFactory.withCurrentContext(serializationContext) { + ltx.transform { componentGroups, serializedInputs, serializedReferences -> + val deserializedInputs = serializedInputs.map(SerializedStateAndRef::toStateAndRef) + val deserializedReferences = serializedReferences.map(SerializedStateAndRef::toStateAndRef) + val deserializedOutputs = deserialiseComponentGroup(componentGroups, TransactionState::class, ComponentGroupEnum.OUTPUTS_GROUP, forceDeserialize = true) + val deserializedCommands = deserialiseCommands(componentGroups, forceDeserialize = true, digestService = ltx.digestService) + val authenticatedDeserializedCommands = deserializedCommands.mapIndexed { idx, cmd -> + // Requires ltx.commands to have been deserialized already. + @Suppress("DEPRECATION") // Deprecated feature. + val parties = ltx.commands[idx].signingParties + CommandWithParties(cmd.signers, parties, cmd.value) + } + + LedgerTransaction.createForContractVerify( + inputs = deserializedInputs, + outputs = deserializedOutputs, + commands = authenticatedDeserializedCommands, + attachments = ltx.attachments, + id = ltx.id, + notary = ltx.notary, + timeWindow = ltx.timeWindow, + privacySalt = ltx.privacySalt, + networkParameters = ltx.networkParameters, + references = deserializedReferences, + digestService = ltx.digestService + ) + } + } + } + + /** + * Check the transaction is contract-valid by running verify() for each input and output state contract. + * If any contract fails to verify, the whole transaction is considered to be invalid. + * + * Note: Reference states are not verified. + */ + override fun verifyContracts() { + try { + ContractVerifier(transactionClassLoader).apply(Supplier(::createTransaction)) + } catch (e: TransactionVerificationException) { + logger.error("Error validating transaction ${ltx.id}.", e.cause) + throw e + } + } +} + +/** + * A "do nothing" [Verifier] installed for contract verification. + * + * THIS CLASS IS NOT PUBLIC API, AND IS DELIBERATELY PRIVATE! + */ +@CordaInternal +private class NoOpVerifier(ltx: LedgerTransaction, serializationContext: SerializationContext) + : Verifier(ltx, serializationContext.deserializationClassLoader) { + // Invoking LedgerTransaction.verify() from Contract.verify(LedgerTransaction) + // will execute this function. But why would anyone do that?! + override fun verifyContracts() {} +} diff --git a/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt index 22bfb19be2..194a7020e7 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt @@ -154,7 +154,7 @@ class WireTransaction(componentGroups: List, val privacySalt: Pr resolveAttachment, { stateRef -> resolveStateRef(stateRef)?.serialize() }, { null }, - { it.isUploaderTrusted() }, + Attachment::isUploaderTrusted, null ) } @@ -214,7 +214,7 @@ class WireTransaction(componentGroups: List, val privacySalt: Pr notary, timeWindow, privacySalt, - resolvedNetworkParameters, + resolvedNetworkParameters.toImmutable(), resolvedReferences, componentGroups, serializedResolvedInputs, diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt index 06698d99ad..178682e088 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt @@ -90,7 +90,7 @@ object KryoCheckpointSerializer : CheckpointSerializer { */ private fun getInputClassForCustomSerializer(classLoader: ClassLoader, customSerializer: CustomSerializerCheckpointAdaptor<*, *>): Class<*> { val typeNameWithoutGenerics = customSerializer.cordappType.typeName.substringBefore('<') - return classLoader.loadClass(typeNameWithoutGenerics) + return Class.forName(typeNameWithoutGenerics, false, classLoader) } /** diff --git a/node/djvm/src/main/kotlin/net/corda/node/djvm/CommandBuilder.kt b/node/djvm/src/main/kotlin/net/corda/node/djvm/CommandBuilder.kt index 311f8e69ff..247fef3ec6 100644 --- a/node/djvm/src/main/kotlin/net/corda/node/djvm/CommandBuilder.kt +++ b/node/djvm/src/main/kotlin/net/corda/node/djvm/CommandBuilder.kt @@ -5,35 +5,43 @@ import net.corda.core.contracts.CommandWithParties import net.corda.core.internal.lazyMapped import java.security.PublicKey import java.util.function.Function +import java.util.function.Supplier -class CommandBuilder : Function, List>> { +class CommandBuilder : Function, Supplier>>> { @Suppress("unchecked_cast") - override fun apply(inputs: Array): List> { - val signers = inputs[0] as? List> ?: emptyList() - val commandsData = inputs[1] as? List ?: emptyList() + override fun apply(inputs: Array): Supplier>> { + val signersProvider = inputs[0] as? Supplier>> ?: Supplier(::emptyList) + val commandsDataProvider = inputs[1] as? Supplier> ?: Supplier(::emptyList) val partialMerkleLeafIndices = inputs[2] as? IntArray /** * This logic has been lovingly reproduced from [net.corda.core.internal.deserialiseCommands]. */ - return if (partialMerkleLeafIndices != null) { - check(commandsData.size <= signers.size) { - "Invalid Transaction. Fewer Signers (${signers.size}) than CommandData (${commandsData.size}) objects" - } - if (partialMerkleLeafIndices.isNotEmpty()) { - check(partialMerkleLeafIndices.max()!! < signers.size) { - "Invalid Transaction. A command with no corresponding signer detected" + return Supplier { + val signers = signersProvider.get() + val commandsData = commandsDataProvider.get() + + if (partialMerkleLeafIndices != null) { + check(commandsData.size <= signers.size) { + "Invalid Transaction. Fewer Signers (${signers.size}) than CommandData (${commandsData.size}) objects" + } + if (partialMerkleLeafIndices.isNotEmpty()) { + check(partialMerkleLeafIndices.max()!! < signers.size) { + "Invalid Transaction. A command with no corresponding signer detected" + } + } + commandsData.lazyMapped { commandData, index -> + // Deprecated signingParties property not supported. + CommandWithParties(signers[partialMerkleLeafIndices[index]], emptyList(), commandData) + } + } else { + check(commandsData.size == signers.size) { + "Invalid Transaction. Sizes of CommandData (${commandsData.size}) and Signers (${signers.size}) do not match" + } + commandsData.lazyMapped { commandData, index -> + // Deprecated signingParties property not supported. + CommandWithParties(signers[index], emptyList(), commandData) } - } - commandsData.lazyMapped { commandData, index -> - CommandWithParties(signers[partialMerkleLeafIndices[index]], emptyList(), commandData) - } - } else { - check(commandsData.size == signers.size) { - "Invalid Transaction. Sizes of CommandData (${commandsData.size}) and Signers (${signers.size}) do not match" - } - commandsData.lazyMapped { commandData, index -> - CommandWithParties(signers[index], emptyList(), commandData) } } } diff --git a/node/djvm/src/main/kotlin/net/corda/node/djvm/ComponentBuilder.kt b/node/djvm/src/main/kotlin/net/corda/node/djvm/ComponentBuilder.kt index 78c3efc737..f0e2e476aa 100644 --- a/node/djvm/src/main/kotlin/net/corda/node/djvm/ComponentBuilder.kt +++ b/node/djvm/src/main/kotlin/net/corda/node/djvm/ComponentBuilder.kt @@ -5,19 +5,22 @@ import net.corda.core.internal.TransactionDeserialisationException import net.corda.core.internal.lazyMapped import net.corda.core.utilities.OpaqueBytes import java.util.function.Function +import java.util.function.Supplier -class ComponentBuilder : Function, List<*>> { +class ComponentBuilder : Function, Supplier>> { @Suppress("unchecked_cast", "TooGenericExceptionCaught") - override fun apply(inputs: Array): List<*> { + override fun apply(inputs: Array): Supplier> { val deserializer = inputs[0] as Function val groupType = inputs[1] as ComponentGroupEnum val components = (inputs[2] as Array).map(::OpaqueBytes) - return components.lazyMapped { component, index -> - try { - deserializer.apply(component.bytes) - } catch (e: Exception) { - throw TransactionDeserialisationException(groupType, index, e) + return Supplier { + components.lazyMapped { component, index -> + try { + deserializer.apply(component.bytes) + } catch (e: Exception) { + throw TransactionDeserialisationException(groupType, index, e) + } } } } diff --git a/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxFactory.kt b/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxFactory.kt deleted file mode 100644 index f12f0cb108..0000000000 --- a/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxFactory.kt +++ /dev/null @@ -1,54 +0,0 @@ -@file:JvmName("LtxConstants") -package net.corda.node.djvm - -import net.corda.core.contracts.Attachment -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.CommandWithParties -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.PrivacySalt -import net.corda.core.contracts.StateAndRef -import net.corda.core.contracts.StateRef -import net.corda.core.contracts.TimeWindow -import net.corda.core.contracts.TransactionState -import net.corda.core.crypto.DigestService -import net.corda.core.crypto.SecureHash -import net.corda.core.identity.Party -import net.corda.core.node.NetworkParameters -import net.corda.core.transactions.LedgerTransaction -import java.util.function.Function - -private const val TX_INPUTS = 0 -private const val TX_OUTPUTS = 1 -private const val TX_COMMANDS = 2 -private const val TX_ATTACHMENTS = 3 -private const val TX_ID = 4 -private const val TX_NOTARY = 5 -private const val TX_TIME_WINDOW = 6 -private const val TX_PRIVACY_SALT = 7 -private const val TX_NETWORK_PARAMETERS = 8 -private const val TX_REFERENCES = 9 -private const val TX_DIGEST_SERVICE = 10 - -class LtxFactory : Function, LedgerTransaction> { - - @Suppress("unchecked_cast") - override fun apply(txArgs: Array): LedgerTransaction { - return LedgerTransaction.createForSandbox( - inputs = (txArgs[TX_INPUTS] as Array>).map { it.toStateAndRef() }, - outputs = (txArgs[TX_OUTPUTS] as? List>) ?: emptyList(), - commands = (txArgs[TX_COMMANDS] as? List>) ?: emptyList(), - attachments = (txArgs[TX_ATTACHMENTS] as? List) ?: emptyList(), - id = txArgs[TX_ID] as SecureHash, - notary = txArgs[TX_NOTARY] as? Party, - timeWindow = txArgs[TX_TIME_WINDOW] as? TimeWindow, - privacySalt = txArgs[TX_PRIVACY_SALT] as PrivacySalt, - networkParameters = txArgs[TX_NETWORK_PARAMETERS] as NetworkParameters, - references = (txArgs[TX_REFERENCES] as Array>).map { it.toStateAndRef() }, - digestService = if (txArgs.size > TX_DIGEST_SERVICE) (txArgs[TX_DIGEST_SERVICE] as DigestService) else DigestService.sha2_256 - ) - } - - private fun Array<*>.toStateAndRef(): StateAndRef { - return StateAndRef(this[0] as TransactionState<*>, this[1] as StateRef) - } -} diff --git a/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxSupplierFactory.kt b/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxSupplierFactory.kt new file mode 100644 index 0000000000..fd982d610a --- /dev/null +++ b/node/djvm/src/main/kotlin/net/corda/node/djvm/LtxSupplierFactory.kt @@ -0,0 +1,73 @@ +@file:JvmName("LtxTools") +package net.corda.node.djvm + +import net.corda.core.contracts.Attachment +import net.corda.core.contracts.CommandData +import net.corda.core.contracts.CommandWithParties +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.PrivacySalt +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TimeWindow +import net.corda.core.contracts.TransactionState +import net.corda.core.crypto.DigestService +import net.corda.core.crypto.SecureHash +import net.corda.core.identity.Party +import net.corda.core.node.NetworkParameters +import net.corda.core.transactions.LedgerTransaction +import java.util.function.Function +import java.util.function.Supplier + +private const val TX_INPUTS = 0 +private const val TX_OUTPUTS = 1 +private const val TX_COMMANDS = 2 +private const val TX_ATTACHMENTS = 3 +private const val TX_ID = 4 +private const val TX_NOTARY = 5 +private const val TX_TIME_WINDOW = 6 +private const val TX_PRIVACY_SALT = 7 +private const val TX_NETWORK_PARAMETERS = 8 +private const val TX_REFERENCES = 9 +private const val TX_DIGEST_SERVICE = 10 + +class LtxSupplierFactory : Function, Supplier> { + @Suppress("unchecked_cast") + override fun apply(txArgs: Array): Supplier { + val inputProvider = (txArgs[TX_INPUTS] as Function>>) + .andThen(Function(Array>::toContractStatesAndRef)) + .toSupplier() + val outputProvider = txArgs[TX_OUTPUTS] as? Supplier>> ?: Supplier(::emptyList) + val commandsProvider = txArgs[TX_COMMANDS] as Supplier>> + val referencesProvider = (txArgs[TX_REFERENCES] as Function>>) + .andThen(Function(Array>::toContractStatesAndRef)) + .toSupplier() + val networkParameters = (txArgs[TX_NETWORK_PARAMETERS] as? NetworkParameters)?.toImmutable() + return Supplier { + LedgerTransaction.createForContractVerify( + inputs = inputProvider.get(), + outputs = outputProvider.get(), + commands = commandsProvider.get(), + attachments = txArgs[TX_ATTACHMENTS] as? List ?: emptyList(), + id = txArgs[TX_ID] as SecureHash, + notary = txArgs[TX_NOTARY] as? Party, + timeWindow = txArgs[TX_TIME_WINDOW] as? TimeWindow, + privacySalt = txArgs[TX_PRIVACY_SALT] as PrivacySalt, + networkParameters = networkParameters, + references = referencesProvider.get(), + digestService = txArgs[TX_DIGEST_SERVICE] as DigestService + ) + } + } +} + +private fun Function.toSupplier(): Supplier { + return Supplier { apply(null) } +} + +private fun Array>.toContractStatesAndRef(): List> { + return map(Array::toStateAndRef) +} + +private fun Array<*>.toStateAndRef(): StateAndRef { + return StateAndRef(this[0] as TransactionState<*>, this[1] as StateRef) +} diff --git a/node/src/integration-test/kotlin/net/corda/contracts/multiple/evil/EvilContract.kt b/node/src/integration-test/kotlin/net/corda/contracts/multiple/evil/EvilContract.kt new file mode 100644 index 0000000000..a5c547bd2d --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/multiple/evil/EvilContract.kt @@ -0,0 +1,46 @@ +package net.corda.contracts.multiple.evil + +import net.corda.contracts.multiple.vulnerable.MutableDataObject +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract.VulnerablePurchase +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract.VulnerableState +import net.corda.core.contracts.CommandData +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.identity.AbstractParty +import net.corda.core.transactions.LedgerTransaction + +@Suppress("unused") +class EvilContract : Contract { + override fun verify(tx: LedgerTransaction) { + val vulnerableStates = tx.outputsOfType(VulnerableState::class.java) + val vulnerablePurchases = tx.commandsOfType(VulnerablePurchase::class.java) + + val addExtras = tx.commandsOfType(AddExtra::class.java) + addExtras.forEach { extra -> + val extraValue = extra.value.payment.value + + // And our extra value to every vulnerable output state. + vulnerableStates.forEach { state -> + state.data?.also { data -> + data.value += extraValue + } + } + + // Add our extra value to every vulnerable command too. + vulnerablePurchases.forEach { purchase -> + purchase.value.payment.value += extraValue + } + } + } + + class EvilState(val owner: AbstractParty) : ContractState { + override val participants: List = listOf(owner) + + @Override + override fun toString(): String { + return "Money For Nothing!" + } + } + + class AddExtra(val payment: MutableDataObject) : CommandData +} diff --git a/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/MutableDataObject.kt b/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/MutableDataObject.kt new file mode 100644 index 0000000000..50ffebfa17 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/MutableDataObject.kt @@ -0,0 +1,14 @@ +package net.corda.contracts.multiple.vulnerable + +import net.corda.core.serialization.CordaSerializable + +@CordaSerializable +data class MutableDataObject(var value: Long) : Comparable { + override fun toString(): String { + return "$value data points" + } + + override fun compareTo(other: MutableDataObject): Int { + return value.compareTo(other.value) + } +} diff --git a/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/VulnerablePaymentContract.kt b/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/VulnerablePaymentContract.kt new file mode 100644 index 0000000000..6e90bb5725 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/multiple/vulnerable/VulnerablePaymentContract.kt @@ -0,0 +1,43 @@ +package net.corda.contracts.multiple.vulnerable + +import net.corda.core.contracts.CommandData +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.requireThat +import net.corda.core.identity.AbstractParty +import net.corda.core.transactions.LedgerTransaction + +@Suppress("unused") +class VulnerablePaymentContract : Contract { + companion object { + const val BASE_PAYMENT = 2000L + } + + override fun verify(tx: LedgerTransaction) { + val states = tx.outputsOfType() + requireThat { + "Requires at least one data state" using states.isNotEmpty() + } + val purchases = tx.commandsOfType() + requireThat { + "Requires at least one purchase" using purchases.isNotEmpty() + } + for (purchase in purchases) { + val payment = purchase.value.payment + requireThat { + "Purchase payment of $payment should be at least $BASE_PAYMENT" using (payment.value >= BASE_PAYMENT) + } + } + } + + class VulnerableState(val owner: AbstractParty, val data: MutableDataObject?) : ContractState { + override val participants: List = listOf(owner) + + @Override + override fun toString(): String { + return data.toString() + } + } + + class VulnerablePurchase(val payment: MutableDataObject) : CommandData +} diff --git a/node/src/integration-test/kotlin/net/corda/contracts/mutator/MutatorContract.kt b/node/src/integration-test/kotlin/net/corda/contracts/mutator/MutatorContract.kt new file mode 100644 index 0000000000..cffcf18b3d --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/contracts/mutator/MutatorContract.kt @@ -0,0 +1,113 @@ +package net.corda.contracts.mutator + +import net.corda.core.contracts.CommandData +import net.corda.core.contracts.CommandWithParties +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.requireSingleCommand +import net.corda.core.contracts.requireThat +import net.corda.core.identity.AbstractParty +import net.corda.core.internal.Verifier +import net.corda.core.serialization.SerializationContext +import net.corda.core.transactions.LedgerTransaction + +class MutatorContract : Contract { + override fun verify(tx: LedgerTransaction) { + tx.transform { componentGroups, serializedInputs, serializedReferences -> + requireThat { + "component groups are protected" using componentGroups.isImmutableAnd(isEmpty = true) + "serialized inputs are protected" using serializedInputs.isImmutableAnd(isEmpty = true) + "serialized references are protected" using serializedReferences.isImmutableAnd(isEmpty = true) + } + } + + requireThat { + "Cannot add/remove inputs" using tx.inputs.isImmutable() + "Cannot add/remove outputs" using failToMutateOutputs(tx) + "Cannot add/remove commands" using failToMutateCommands(tx) + "Cannot add/remove references" using tx.references.isImmutable() + "Cannot add/remove attachments" using tx.attachments.isImmutableAnd(isEmpty = false) + "Cannot specialise transaction" using failToSpecialise(tx) + } + + requireNotNull(tx.networkParameters).also { networkParameters -> + requireThat { + "Cannot add/remove notaries" using networkParameters.notaries.isImmutableAnd(isEmpty = false) + "Cannot add/remove package ownerships" using networkParameters.packageOwnership.isImmutable() + "Cannot add/remove whitelisted contracts" using networkParameters.whitelistedContractImplementations.isImmutable() + } + } + } + + private fun List<*>.isImmutableAnd(isEmpty: Boolean): Boolean { + return isImmutable() && (this.isEmpty() == isEmpty) + } + + private fun List<*>.isImmutable(): Boolean { + return try { + @Suppress("platform_class_mapped_to_kotlin") + (this as java.util.List<*>).clear() + false + } catch (e: UnsupportedOperationException) { + true + } + } + + private fun failToMutateOutputs(tx: LedgerTransaction): Boolean { + val output = tx.outputsOfType().single() + val mutableOutputs = tx.outputs as MutableList> + return try { + mutableOutputs += TransactionState(MutateState(output.owner), MutatorContract::class.java.name, tx.notary!!, 0) + false + } catch (e: UnsupportedOperationException) { + true + } + } + + private fun failToMutateCommands(tx: LedgerTransaction): Boolean { + val mutate = tx.commands.requireSingleCommand() + val mutableCommands = tx.commands as MutableList> + return try { + mutableCommands += CommandWithParties(mutate.signers, emptyList(), MutateCommand()) + false + } catch (e: UnsupportedOperationException) { + true + } + } + + private fun Map<*, *>.isImmutable(): Boolean { + return try { + @Suppress("platform_class_mapped_to_kotlin") + (this as java.util.Map<*, *>).clear() + false + } catch (e: UnsupportedOperationException) { + true + } + } + + private fun failToSpecialise(ltx: LedgerTransaction): Boolean { + return try { + ltx.specialise(::ExtraSpecialise) + false + } catch (e: IllegalStateException) { + true + } + } + + private class ExtraSpecialise(ltx: LedgerTransaction, ctx: SerializationContext) + : Verifier(ltx, ctx.deserializationClassLoader) { + override fun verifyContracts() {} + } + + class MutateState(val owner: AbstractParty) : ContractState { + override val participants: List = listOf(owner) + + @Override + override fun toString(): String { + return "All change!" + } + } + + class MutateCommand : CommandData +} diff --git a/node/src/integration-test/kotlin/net/corda/flows/multiple/evil/EvilFlow.kt b/node/src/integration-test/kotlin/net/corda/flows/multiple/evil/EvilFlow.kt new file mode 100644 index 0000000000..0c627c7e46 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/flows/multiple/evil/EvilFlow.kt @@ -0,0 +1,39 @@ +package net.corda.flows.multiple.evil + +import co.paralleluniverse.fibers.Suspendable +import net.corda.contracts.multiple.evil.EvilContract.EvilState +import net.corda.contracts.multiple.evil.EvilContract.AddExtra +import net.corda.contracts.multiple.vulnerable.MutableDataObject +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract.VulnerablePurchase +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract.VulnerableState +import net.corda.core.contracts.Command +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StartableByRPC +import net.corda.core.transactions.TransactionBuilder + +@StartableByRPC +class EvilFlow( + private val purchase: MutableDataObject +) : FlowLogic() { + private companion object { + private val NOTHING = MutableDataObject(0) + } + + @Suspendable + override fun call(): SecureHash { + val notary = serviceHub.networkMapCache.notaryIdentities[0] + val stx = serviceHub.signInitialTransaction( + TransactionBuilder(notary) + // Add Evil objects first, so that Corda will verify EvilContract first. + .addCommand(Command(AddExtra(purchase), ourIdentity.owningKey)) + .addOutputState(EvilState(ourIdentity)) + + // Now add the VulnerablePaymentContract objects with NO PAYMENT! + .addCommand(Command(VulnerablePurchase(NOTHING), ourIdentity.owningKey)) + .addOutputState(VulnerableState(ourIdentity, NOTHING)) + ) + stx.verify(serviceHub, checkSufficientSignatures = false) + return stx.id + } +} diff --git a/node/src/integration-test/kotlin/net/corda/flows/mutator/MutatorFlow.kt b/node/src/integration-test/kotlin/net/corda/flows/mutator/MutatorFlow.kt new file mode 100644 index 0000000000..c131dfa433 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/flows/mutator/MutatorFlow.kt @@ -0,0 +1,26 @@ +package net.corda.flows.mutator + +import co.paralleluniverse.fibers.Suspendable +import net.corda.contracts.mutator.MutatorContract.MutateCommand +import net.corda.contracts.mutator.MutatorContract.MutateState +import net.corda.core.contracts.Command +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StartableByRPC +import net.corda.core.transactions.TransactionBuilder + +@StartableByRPC +class MutatorFlow : FlowLogic() { + @Suspendable + override fun call(): SecureHash { + val notary = serviceHub.networkMapCache.notaryIdentities[0] + val stx = serviceHub.signInitialTransaction( + TransactionBuilder(notary) + // Create some content for the LedgerTransaction. + .addOutputState(MutateState(ourIdentity)) + .addCommand(Command(MutateCommand(), ourIdentity.owningKey)) + ) + stx.verify(serviceHub, checkSufficientSignatures = false) + return stx.id + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/ContractCannotMutateTransactionTest.kt b/node/src/integration-test/kotlin/net/corda/node/ContractCannotMutateTransactionTest.kt new file mode 100644 index 0000000000..eecfe203ec --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/ContractCannotMutateTransactionTest.kt @@ -0,0 +1,48 @@ +package net.corda.node + +import net.corda.client.rpc.CordaRPCClient +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.loggerFor +import net.corda.flows.mutator.MutatorFlow +import net.corda.node.services.Permissions +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.User +import net.corda.testing.node.internal.cordappWithPackages +import org.junit.Test + +class ContractCannotMutateTransactionTest { + companion object { + private val logger = loggerFor() + private val user = User("u", "p", setOf(Permissions.all())) + private val mutatorFlowCorDapp = cordappWithPackages("net.corda.flows.mutator").signed() + private val mutatorContractCorDapp = cordappWithPackages("net.corda.contracts.mutator").signed() + + fun driverParameters(runInProcess: Boolean): DriverParameters { + return DriverParameters( + portAllocation = incrementalPortAllocation(), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + cordappsForAllNodes = listOf(mutatorContractCorDapp, mutatorFlowCorDapp) + ) + } + } + + @Test(timeout = 300_000) + fun testContractCannotModifyTransaction() { + driver(driverParameters(runInProcess = false)) { + val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val txID = CordaRPCClient(hostAndPort = alice.rpcAddress) + .start(user.username, user.password) + .use { client -> + client.proxy.startFlow(::MutatorFlow).returnValue.getOrThrow() + } + logger.info("TX-ID: {}", txID) + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/EvilContractCannotModifyStatesTest.kt b/node/src/integration-test/kotlin/net/corda/node/EvilContractCannotModifyStatesTest.kt new file mode 100644 index 0000000000..c365552553 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/EvilContractCannotModifyStatesTest.kt @@ -0,0 +1,60 @@ +package net.corda.node + +import net.corda.client.rpc.CordaRPCClient +import net.corda.contracts.multiple.vulnerable.MutableDataObject +import net.corda.core.contracts.TransactionVerificationException.ContractRejection +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.flows.multiple.evil.EvilFlow +import net.corda.node.services.Permissions +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.User +import net.corda.testing.node.internal.cordappWithPackages +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test +import kotlin.test.assertFailsWith + +class EvilContractCannotModifyStatesTest { + companion object { + private val user = User("u", "p", setOf(Permissions.all())) + private val evilFlowCorDapp = cordappWithPackages("net.corda.flows.multiple.evil").signed() + private val evilContractCorDapp = cordappWithPackages("net.corda.contracts.multiple.evil").signed() + private val vulnerableContractCorDapp = cordappWithPackages("net.corda.contracts.multiple.vulnerable").signed() + + private val NOTHING = MutableDataObject(0) + + fun driverParameters(runInProcess: Boolean): DriverParameters { + return DriverParameters( + portAllocation = incrementalPortAllocation(), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + cordappsForAllNodes = listOf( + vulnerableContractCorDapp, + evilContractCorDapp, + evilFlowCorDapp + ) + ) + } + } + + @Test(timeout = 300_000) + fun testContractThatTriesToModifyStates() { + val evilData = MutableDataObject(5000) + driver(driverParameters(runInProcess = false)) { + val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val ex = assertFailsWith { + CordaRPCClient(hostAndPort = alice.rpcAddress) + .start(user.username, user.password) + .use { client -> + client.proxy.startFlow(::EvilFlow, evilData).returnValue.getOrThrow() + } + } + assertThat(ex).hasMessageContaining("Purchase payment of $NOTHING should be at least ") + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCannotMutateTransactionTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCannotMutateTransactionTest.kt new file mode 100644 index 0000000000..68b8c3531c --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicContractCannotMutateTransactionTest.kt @@ -0,0 +1,55 @@ +package net.corda.node.services + +import net.corda.client.rpc.CordaRPCClient +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.loggerFor +import net.corda.flows.mutator.MutatorFlow +import net.corda.node.DeterministicSourcesRule +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.User +import net.corda.testing.node.internal.cordappWithPackages +import org.junit.ClassRule +import org.junit.Test + +class DeterministicContractCannotMutateTransactionTest { + companion object { + private val logger = loggerFor() + private val user = User("u", "p", setOf(Permissions.all())) + private val mutatorFlowCorDapp = cordappWithPackages("net.corda.flows.mutator").signed() + private val mutatorContractCorDapp = cordappWithPackages("net.corda.contracts.mutator").signed() + + @ClassRule + @JvmField + val djvmSources = DeterministicSourcesRule() + + fun driverParameters(runInProcess: Boolean): DriverParameters { + return DriverParameters( + portAllocation = incrementalPortAllocation(), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + cordappsForAllNodes = listOf(mutatorContractCorDapp, mutatorFlowCorDapp), + djvmBootstrapSource = djvmSources.bootstrap, + djvmCordaSource = djvmSources.corda + ) + } + } + + @Test(timeout = 300_000) + fun testContractCannotModifyTransaction() { + driver(driverParameters(runInProcess = false)) { + val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val txID = CordaRPCClient(hostAndPort = alice.rpcAddress) + .start(user.username, user.password) + .use { client -> + client.proxy.startFlow(::MutatorFlow).returnValue.getOrThrow() + } + logger.info("TX-ID: {}", txID) + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DeterministicEvilContractCannotModifyStatesTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicEvilContractCannotModifyStatesTest.kt new file mode 100644 index 0000000000..f2d455dce4 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/services/DeterministicEvilContractCannotModifyStatesTest.kt @@ -0,0 +1,71 @@ +package net.corda.node.services + +import net.corda.client.rpc.CordaRPCClient +import net.corda.contracts.multiple.vulnerable.MutableDataObject +import net.corda.contracts.multiple.vulnerable.VulnerablePaymentContract +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.flows.multiple.evil.EvilFlow +import net.corda.node.DeterministicSourcesRule +import net.corda.node.internal.djvm.DeterministicVerificationException +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.node.NotarySpec +import net.corda.testing.node.User +import net.corda.testing.node.internal.cordappWithPackages +import org.assertj.core.api.Assertions.assertThat +import org.junit.ClassRule +import org.junit.Test +import kotlin.test.assertFailsWith + +class DeterministicEvilContractCannotModifyStatesTest { + companion object { + private val user = User("u", "p", setOf(Permissions.all())) + private val evilFlowCorDapp = cordappWithPackages("net.corda.flows.multiple.evil").signed() + private val evilContractCorDapp = cordappWithPackages("net.corda.contracts.multiple.evil").signed() + private val vulnerableContractCorDapp = cordappWithPackages("net.corda.contracts.multiple.vulnerable").signed() + + private val NOTHING = MutableDataObject(0) + + @ClassRule + @JvmField + val djvmSources = DeterministicSourcesRule() + + fun driverParameters(runInProcess: Boolean): DriverParameters { + return DriverParameters( + portAllocation = incrementalPortAllocation(), + startNodesInProcess = runInProcess, + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, startInProcess = runInProcess, validating = true)), + cordappsForAllNodes = listOf( + vulnerableContractCorDapp, + evilContractCorDapp, + evilFlowCorDapp + ), + djvmBootstrapSource = djvmSources.bootstrap, + djvmCordaSource = djvmSources.corda + ) + } + } + + @Test(timeout = 300_000) + fun testContractThatTriesToModifyStates() { + val evilData = MutableDataObject(5000) + driver(driverParameters(runInProcess = false)) { + val alice = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val ex = assertFailsWith { + CordaRPCClient(hostAndPort = alice.rpcAddress) + .start(user.username, user.password) + .use { client -> + client.proxy.startFlow(::EvilFlow, evilData).returnValue.getOrThrow() + } + } + assertThat(ex) + .hasMessageStartingWith("sandbox.net.corda.core.contracts.TransactionVerificationException\$ContractRejection -> ") + .hasMessageContaining(" Contract verification failed: Failed requirement: Purchase payment of $NOTHING should be at least ") + .hasMessageContaining(", contract: sandbox.${VulnerablePaymentContract::class.java.name}, ") + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/services/SandboxAttachmentsTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/SandboxAttachmentsTest.kt index 85dac332a9..c825357581 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/SandboxAttachmentsTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/SandboxAttachmentsTest.kt @@ -5,7 +5,6 @@ import net.corda.contracts.djvm.attachment.SandboxAttachmentContract.ExtractFile import net.corda.core.messaging.startFlow import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor -import net.corda.djvm.code.asResourcePath import net.corda.flows.djvm.attachment.SandboxAttachmentFlow import net.corda.node.DeterministicSourcesRule import net.corda.node.internal.djvm.DeterministicVerificationException @@ -52,7 +51,7 @@ class SandboxAttachmentsTest { @Test(timeout=300_000) fun `test attachment accessible within sandbox`() { - val extractFile = ExtractFile(SandboxAttachmentContract::class.java.name.asResourcePath + ".class") + val extractFile = ExtractFile(SandboxAttachmentContract::class.java.name.replace('.', '/') + ".class") driver(parametersFor(djvmSources)) { val alice = startNode(providedName = ALICE_NAME).getOrThrow() val txId = assertDoesNotThrow { diff --git a/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt b/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt index 07d1631097..bb49eeb179 100644 --- a/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt +++ b/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt @@ -14,7 +14,7 @@ inline fun Class<*>.requireAnnotation(): A { fun scanForCustomSerializationScheme(className: String, classLoader: ClassLoader) : SerializationScheme { val schemaClass = try { - classLoader.loadClass(className) + Class.forName(className, false, classLoader) } catch (exception: ClassNotFoundException) { throw ConfigurationException("$className was declared as a custom serialization scheme but could not be found.") } diff --git a/node/src/main/kotlin/net/corda/node/internal/djvm/DeterministicVerifier.kt b/node/src/main/kotlin/net/corda/node/internal/djvm/DeterministicVerifier.kt index 654b218aee..f1880fbf54 100644 --- a/node/src/main/kotlin/net/corda/node/internal/djvm/DeterministicVerifier.kt +++ b/node/src/main/kotlin/net/corda/node/internal/djvm/DeterministicVerifier.kt @@ -21,7 +21,7 @@ import net.corda.djvm.execution.SandboxException import net.corda.djvm.messages.Message import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.djvm.source.ClassSource -import net.corda.node.djvm.LtxFactory +import net.corda.node.djvm.LtxSupplierFactory import java.util.function.Function import kotlin.collections.LinkedHashSet @@ -93,14 +93,14 @@ class DeterministicVerifier( val networkingParametersData = ltx.networkParameters?.serialize() val digestServiceData = ltx.digestService.serialize() - val createSandboxTx = taskFactory.apply(LtxFactory::class.java) + val createSandboxTx = taskFactory.apply(LtxSupplierFactory::class.java) createSandboxTx.apply(arrayOf( - serializer.deserialize(serializedInputs), + classLoader.createForImport(Function { serializer.deserialize(serializedInputs) }), componentFactory.toSandbox(OUTPUTS_GROUP, TransactionState::class.java), CommandFactory(taskFactory).toSandbox( componentFactory.toSandbox(SIGNERS_GROUP, List::class.java), componentFactory.toSandbox(COMMANDS_GROUP, CommandData::class.java), - componentFactory.calculateLeafIndicesFor(COMMANDS_GROUP, digestService = ltx.digestService) + componentFactory.calculateLeafIndicesFor(COMMANDS_GROUP, ltx.digestService) ), attachmentFactory.toSandbox(ltx.attachments), serializer.deserialize(idData), @@ -108,7 +108,7 @@ class DeterministicVerifier( serializer.deserialize(timeWindowData), serializer.deserialize(privacySaltData), serializer.deserialize(networkingParametersData), - serializer.deserialize(serializedReferences), + classLoader.createForImport(Function { serializer.deserialize(serializedReferences) }), serializer.deserialize(digestServiceData) )) } diff --git a/node/src/main/kotlin/net/corda/node/internal/djvm/Serializer.kt b/node/src/main/kotlin/net/corda/node/internal/djvm/Serializer.kt index 40a5522a28..32fedf52fa 100644 --- a/node/src/main/kotlin/net/corda/node/internal/djvm/Serializer.kt +++ b/node/src/main/kotlin/net/corda/node/internal/djvm/Serializer.kt @@ -1,6 +1,9 @@ package net.corda.node.internal.djvm import net.corda.core.internal.SerializedStateAndRef +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_PROPERTY +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.SerializedBytes @@ -22,14 +25,20 @@ class Serializer( init { val env = createSandboxSerializationEnv(classLoader, customSerializerNames, serializationWhitelists) factory = env.serializationFactory - context = env.p2pContext + context = env.p2pContext.withProperties(mapOf( + // Duplicate the P2P SerializationContext and give it + // these extra properties, just for this transaction. + AMQP_ENVELOPE_CACHE_PROPERTY to HashMap(AMQP_ENVELOPE_CACHE_INITIAL_CAPACITY), + DESERIALIZATION_CACHE_PROPERTY to HashMap() + )) } /** * Convert a list of [SerializedStateAndRef] objects into arrays * of deserialized sandbox objects. We will pass this array into - * [net.corda.node.djvm.LtxFactory] to be transformed finally to - * a list of [net.corda.core.contracts.StateAndRef] objects, + * [LtxSupplierFactory][net.corda.node.djvm.LtxSupplierFactory] + * to be transformed finally to a list of + * [StateAndRef][net.corda.core.contracts.StateAndRef] objects, */ fun deserialize(stateRefs: List): Array> { return stateRefs.map { diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt b/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt index c49f0b5a63..eb9258d44e 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt @@ -315,7 +315,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri * the checkpoint agent source code */ private fun checkpointAgentRunning() = try { - javaClass.classLoader.loadClass("net.corda.tools.CheckpointAgent").kotlin.companionObject + Class.forName("net.corda.tools.CheckpointAgent", false, javaClass.classLoader).kotlin.companionObject } catch (e: ClassNotFoundException) { null }?.let { cls -> diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt index d514335c92..5b016d5734 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/DeterministicVerifierFactoryService.kt @@ -1,6 +1,5 @@ package net.corda.node.services.transactions -import net.corda.core.internal.BasicVerifier import net.corda.core.internal.Verifier import net.corda.core.serialization.ConstructorForDeserialization import net.corda.core.serialization.CordaSerializable @@ -9,6 +8,7 @@ import net.corda.core.serialization.CordaSerializationTransformEnumDefaults import net.corda.core.serialization.CordaSerializationTransformRename import net.corda.core.serialization.CordaSerializationTransformRenames import net.corda.core.serialization.DeprecatedConstructorForDeserialization +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.transactions.LedgerTransaction import net.corda.djvm.SandboxConfiguration @@ -80,13 +80,13 @@ class DeterministicVerifierFactoryService( override fun apply(ledgerTransaction: LedgerTransaction): LedgerTransaction { // Specialise the LedgerTransaction here so that // contracts are verified inside the DJVM! - return ledgerTransaction.specialise(::specialise) + return ledgerTransaction.specialise(::createDeterministicVerifier) } - private fun specialise(ltx: LedgerTransaction, classLoader: ClassLoader): Verifier { - return (classLoader as? URLClassLoader)?.run { + private fun createDeterministicVerifier(ltx: LedgerTransaction, serializationContext: SerializationContext): Verifier { + return (serializationContext.deserializationClassLoader as? URLClassLoader)?.let { classLoader -> DeterministicVerifier(ltx, classLoader, createSandbox(classLoader.urLs)) - } ?: BasicVerifier(ltx, classLoader) + } ?: throw IllegalStateException("Unsupported deserialization classloader type") } private fun createSandbox(userSource: Array): SandboxConfiguration { diff --git a/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt b/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt index 79eb969b21..1837a8fb49 100644 --- a/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt +++ b/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt @@ -1,6 +1,5 @@ package net.corda.node.internal -import com.nhaarman.mockito_kotlin.whenever import net.corda.core.serialization.CustomSerializationScheme import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationSchemeContext @@ -16,7 +15,7 @@ class CustomSerializationSchemeScanningTest { open class DummySerializationScheme : CustomSerializationScheme { override fun getSchemeId(): Int { - return 7; + return 7 } override fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T { @@ -34,9 +33,7 @@ class CustomSerializationSchemeScanningTest { @Test(timeout = 300_000) fun `Can scan for custom serialization scheme and build a serialization scheme`() { - val classLoader = Mockito.mock(ClassLoader::class.java) - whenever(classLoader.loadClass(DummySerializationScheme::class.java.canonicalName)).thenAnswer { DummySerializationScheme::class.java } - val scheme = scanForCustomSerializationScheme(DummySerializationScheme::class.java.canonicalName, classLoader) + val scheme = scanForCustomSerializationScheme(DummySerializationScheme::class.java.name, this::class.java.classLoader) val mockContext = Mockito.mock(SerializationContext::class.java) assertFailsWith("Tried to serialize with DummySerializationScheme") { scheme.serialize(Any::class.java, mockContext) @@ -45,34 +42,28 @@ class CustomSerializationSchemeScanningTest { @Test(timeout = 300_000) fun `verification fails with a helpful error if the class is not found in the classloader`() { - val classLoader = Mockito.mock(ClassLoader::class.java) - val missingClassName = DummySerializationScheme::class.java.canonicalName - whenever(classLoader.loadClass(missingClassName)).thenAnswer { throw ClassNotFoundException()} + val missingClassName = "org.testing.DoesNotExist" assertFailsWith("$missingClassName was declared as a custom serialization scheme but could not " + "be found.") { - scanForCustomSerializationScheme(missingClassName, classLoader) + scanForCustomSerializationScheme(missingClassName, this::class.java.classLoader) } } @Test(timeout = 300_000) fun `verification fails with a helpful error if the class is not a custom serialization scheme`() { - val canonicalName = NonSerializationScheme::class.java.canonicalName - val classLoader = Mockito.mock(ClassLoader::class.java) - whenever(classLoader.loadClass(canonicalName)).thenAnswer { NonSerializationScheme::class.java } - assertFailsWith("$canonicalName was declared as a custom serialization scheme but does not " + + val schemeName = NonSerializationScheme::class.java.name + assertFailsWith("$schemeName was declared as a custom serialization scheme but does not " + "implement CustomSerializationScheme.") { - scanForCustomSerializationScheme(canonicalName, classLoader) + scanForCustomSerializationScheme(schemeName, this::class.java.classLoader) } } @Test(timeout = 300_000) fun `verification fails with a helpful error if the class does not have a no arg constructor`() { - val classLoader = Mockito.mock(ClassLoader::class.java) - val canonicalName = DummySerializationSchemeWithoutNoArgConstructor::class.java.canonicalName - whenever(classLoader.loadClass(canonicalName)).thenAnswer { DummySerializationSchemeWithoutNoArgConstructor::class.java } - assertFailsWith("$canonicalName was declared as a custom serialization scheme but does not " + + val schemeName = DummySerializationSchemeWithoutNoArgConstructor::class.java.name + assertFailsWith("$schemeName was declared as a custom serialization scheme but does not " + "have a no argument constructor.") { - scanForCustomSerializationScheme(canonicalName, classLoader) + scanForCustomSerializationScheme(schemeName, this::class.java.classLoader) } } -} \ No newline at end of file +} diff --git a/serialization-deterministic/build.gradle b/serialization-deterministic/build.gradle index 6ad42b0208..7822eb3b23 100644 --- a/serialization-deterministic/build.gradle +++ b/serialization-deterministic/build.gradle @@ -23,7 +23,10 @@ def javaHome = System.getProperty('java.home') def jarBaseName = "corda-${project.name}".toString() configurations { - deterministicLibraries.extendsFrom implementation + deterministicLibraries { + canBeConsumed = false + extendsFrom implementation + } deterministicArtifacts.extendsFrom deterministicLibraries } @@ -55,7 +58,7 @@ def originalJar = serializationJarTask.map { it.outputs.files.singleFile } def patchSerialization = tasks.register('patchSerialization', Zip) { dependsOn serializationJarTask - destinationDirectory = file("$buildDir/source-libs") + destinationDirectory = layout.buildDirectory.dir('source-libs') metadataCharset 'UTF-8' archiveClassifier = 'transient' archiveExtension = 'jar' @@ -157,7 +160,7 @@ def determinise = tasks.register('determinise', ProGuardTask) { def checkDeterminism = tasks.register('checkDeterminism', ProGuardTask) def metafix = tasks.register('metafix', MetaFixerTask) { - outputDir file("$buildDir/libs") + outputDir = layout.buildDirectory.dir('libs') jars determinise suffix "" diff --git a/serialization-djvm/build.gradle b/serialization-djvm/build.gradle index f51557e2a3..8e8870398e 100644 --- a/serialization-djvm/build.gradle +++ b/serialization-djvm/build.gradle @@ -1,6 +1,3 @@ -import org.jetbrains.kotlin.gradle.tasks.KotlinCompile -import static org.gradle.api.JavaVersion.VERSION_1_8 - plugins { id 'org.jetbrains.kotlin.jvm' id 'net.corda.plugins.publish-utils' @@ -17,8 +14,12 @@ apply from: "${rootProject.projectDir}/java8.gradle" description 'Serialization support for the DJVM' configurations { - sandboxTesting - jdkRt + sandboxTesting { + canBeConsumed = false + } + jdkRt { + canBeConsumed = false + } } dependencies { @@ -56,6 +57,11 @@ jar { } } +tasks.withType(Javadoc).configureEach { + // We have no public or protected Java classes to document. + enabled = false +} + tasks.withType(Test).configureEach { useJUnitPlatform() systemProperty 'deterministic-rt.path', configurations.jdkRt.asPath @@ -66,7 +72,7 @@ tasks.withType(Test).configureEach { } publish { - name jar.archiveBaseName.get() + name jar.archiveBaseName } idea { diff --git a/serialization-djvm/src/main/java/net/corda/serialization/djvm/serializers/CacheKey.java b/serialization-djvm/src/main/java/net/corda/serialization/djvm/serializers/CacheKey.java new file mode 100644 index 0000000000..5ef3728e91 --- /dev/null +++ b/serialization-djvm/src/main/java/net/corda/serialization/djvm/serializers/CacheKey.java @@ -0,0 +1,35 @@ +package net.corda.serialization.djvm.serializers; + +import org.jetbrains.annotations.NotNull; + +import java.util.Arrays; + +/** + * This class is deliberately written in Java so + * that it can be package private. + */ +final class CacheKey { + private final byte[] bytes; + private final int hashValue; + + CacheKey(@NotNull byte[] bytes) { + this.bytes = bytes; + this.hashValue = Arrays.hashCode(bytes); + } + + @NotNull + byte[] getBytes() { + return bytes; + } + + @Override + public boolean equals(Object other) { + return (this == other) + || (other instanceof CacheKey && Arrays.equals(bytes, ((CacheKey) other).bytes)); + } + + @Override + public int hashCode() { + return hashValue; + } +} diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxCertPathSerializer.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxCertPathSerializer.kt index 0d6cd7aff5..25710d654e 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxCertPathSerializer.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxCertPathSerializer.kt @@ -1,5 +1,7 @@ package net.corda.serialization.djvm.serializers +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY +import net.corda.core.serialization.SerializationContext import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.CertPathDeserializer import net.corda.serialization.djvm.toSandboxAnyClass @@ -27,4 +29,13 @@ class SandboxCertPathSerializer( override fun fromProxy(proxy: Any): Any { return task.apply(proxy)!! } + + override fun fromProxy(proxy: Any, context: SerializationContext): Any { + // This requires [CertPathProxy] to have correct + // implementations for [equals] and [hashCode]. + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(proxy, ::fromProxy) + ?: fromProxy(proxy) + } } diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxPublicKeySerializer.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxPublicKeySerializer.kt index 6a22e05da6..f826672647 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxPublicKeySerializer.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxPublicKeySerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.djvm.serializers +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.PublicKeyDecoder @@ -27,7 +28,11 @@ class SandboxPublicKeySerializer( override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray - return decoder.apply(bits)!! + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + decoder.apply(key.bytes) + } ?: decoder.apply(bits)!! } override fun writeDescribedObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext) { diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CRLSerializer.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CRLSerializer.kt index aa52234a97..0c19470e25 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CRLSerializer.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CRLSerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.djvm.serializers +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.X509CRLDeserializer @@ -28,7 +29,11 @@ class SandboxX509CRLSerializer( override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray - return generator.apply(bits)!! + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + generator.apply(key.bytes) + } ?: generator.apply(bits)!! } override fun writeDescribedObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext) { diff --git a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CertificateSerializer.kt b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CertificateSerializer.kt index cab56d34c6..cf6a78da7e 100644 --- a/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CertificateSerializer.kt +++ b/serialization-djvm/src/main/kotlin/net/corda/serialization/djvm/serializers/SandboxX509CertificateSerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.djvm.serializers +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.djvm.rewiring.SandboxClassLoader import net.corda.serialization.djvm.deserializers.X509CertificateDeserializer @@ -28,7 +29,11 @@ class SandboxX509CertificateSerializer( override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray - return generator.apply(bits)!! + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + generator.apply(key.bytes) + } ?: generator.apply(bits)!! } override fun writeDescribedObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext) { diff --git a/serialization/build.gradle b/serialization/build.gradle index 0ba49a0804..224bd642b4 100644 --- a/serialization/build.gradle +++ b/serialization/build.gradle @@ -52,8 +52,13 @@ configurations { testArtifacts.extendsFrom testRuntimeClasspath } +tasks.withType(Javadoc).configureEach { + // We have no public or protected Java classes to document. + enabled = false +} + task testJar(type: Jar) { - classifier "tests" + archiveClassifier = 'tests' from sourceSets.test.output } @@ -68,5 +73,5 @@ jar { } publish { - name jar.baseName + name jar.archiveBaseName } diff --git a/serialization/src/main/java/net/corda/serialization/internal/amqp/custom/CacheKey.java b/serialization/src/main/java/net/corda/serialization/internal/amqp/custom/CacheKey.java new file mode 100644 index 0000000000..2a341d5130 --- /dev/null +++ b/serialization/src/main/java/net/corda/serialization/internal/amqp/custom/CacheKey.java @@ -0,0 +1,37 @@ +package net.corda.serialization.internal.amqp.custom; + +import net.corda.core.KeepForDJVM; +import org.jetbrains.annotations.NotNull; + +import java.util.Arrays; + +/** + * This class is deliberately written in Java so + * that it can be package private. + */ +@KeepForDJVM +final class CacheKey { + private final byte[] bytes; + private final int hashValue; + + CacheKey(@NotNull byte[] bytes) { + this.bytes = bytes; + this.hashValue = Arrays.hashCode(bytes); + } + + @NotNull + byte[] getBytes() { + return bytes; + } + + @Override + public boolean equals(Object other) { + return (this == other) + || (other instanceof CacheKey && Arrays.equals(bytes, ((CacheKey) other).bytes)); + } + + @Override + public int hashCode() { + return hashValue; + } +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt index dbadd68339..6b63a46655 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt @@ -40,6 +40,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe /** * {@inheritDoc} */ + @Suppress("OverridingDeprecatedMember") override fun withAttachmentsClassLoader(attachmentHashes: List): SerializationContext { return this } @@ -108,12 +109,13 @@ open class SerializationFactoryImpl( val lookupKey = magic to target // ConcurrentHashMap.get() is lock free, but computeIfAbsent is not, even if the key is in the map already. return (schemes[lookupKey] ?: schemes.computeIfAbsent(lookupKey) { - registeredSchemes.filter { it.canDeserializeVersion(magic, target) }.forEach { return@computeIfAbsent it } // XXX: Not single? - logger.warn("Cannot find serialization scheme for: [$lookupKey, " + - "${if (magic == amqpMagic) "AMQP" else "UNKNOWN MAGIC"}] registeredSchemes are: $registeredSchemes") - val schemeId = getSchemeIdIfCustomSerializationMagic(magic) ?: throw UnsupportedOperationException("Serialization scheme" + - " $lookupKey not supported.") - throw UnsupportedOperationException("Could not find custom serialization scheme with SchemeId = $schemeId.") + registeredSchemes.firstOrNull { it.canDeserializeVersion(magic, target) } ?: run { + logger.warn("Cannot find serialization scheme for: [$lookupKey, " + + "${if (magic == amqpMagic) "AMQP" else "UNKNOWN MAGIC"}] registeredSchemes are: $registeredSchemes") + val schemeId = getSchemeIdIfCustomSerializationMagic(magic) ?: throw UnsupportedOperationException("Serialization scheme" + + " $lookupKey not supported.") + throw UnsupportedOperationException("Could not find custom serialization scheme with SchemeId = $schemeId.") + } }) to magic } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt index e9e5eda38a..a55d334b40 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt @@ -88,11 +88,11 @@ class CorDappCustomSerializer( override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext ) = uncheckedCast, SerializationCustomSerializer>( - serializer).fromProxy(uncheckedCast(proxySerializer.readObject(obj, schemas, input, context)))!! + serializer).fromProxy(proxySerializer.readObject(obj, schemas, input, context))!! /** * For 3rd party plugin serializers we are going to exist on exact type matching. i.e. we will - * not support base class serializers for derivedtypes + * not support base class serializers for derived types */ override fun isSerializerFor(clazz: Class<*>) = TypeToken.of(type.asClass()) == TypeToken.of(clazz) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt index 53d521a80a..ee28ca00de 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt @@ -1,7 +1,6 @@ package net.corda.serialization.internal.amqp import net.corda.core.KeepForDJVM -import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.model.FingerprintWriter import net.corda.serialization.internal.model.TypeIdentifier @@ -52,7 +51,8 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { context: SerializationContext, debugIndent: Int ) { data.withDescribed(descriptor) { - writeDescribedObject(uncheckedCast(obj), data, type, output, context) + @Suppress("unchecked_cast") + writeDescribedObject(obj as T, data, type, output, context) } } @@ -178,10 +178,13 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { protected abstract fun fromProxy(proxy: P): T + protected open fun toProxy(obj: T, context: SerializationContext): P = toProxy(obj) + protected open fun fromProxy(proxy: P, context: SerializationContext): T = fromProxy(proxy) + override fun writeDescribedObject(obj: T, data: Data, type: Type, output: SerializationOutput, context: SerializationContext ) { - val proxy = toProxy(obj) + val proxy = toProxy(obj, context) data.withList { proxySerializer.propertySerializers.forEach { (_, serializer) -> serializer.writeProperty(proxy, this, output, context, 0) @@ -192,8 +195,9 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext ): T { - val proxy: P = uncheckedCast(proxySerializer.readObject(obj, schemas, input, context)) - return fromProxy(proxy) + @Suppress("unchecked_cast") + val proxy = proxySerializer.readObject(obj, schemas, input, context) as P + return fromProxy(proxy, context) } } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt index b8f8b55dfd..7be1425d32 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt @@ -3,12 +3,17 @@ package net.corda.serialization.internal.amqp import net.corda.core.KeepForDJVM import net.corda.core.internal.VisibleForTesting import net.corda.core.serialization.EncodingWhitelist +import net.corda.core.serialization.AMQP_ENVELOPE_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializedBytes import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.loggerFor import net.corda.core.utilities.trace -import net.corda.serialization.internal.* +import net.corda.serialization.internal.ByteBufferInputStream +import net.corda.serialization.internal.CordaSerializationEncoding +import net.corda.serialization.internal.NullEncodingWhitelist +import net.corda.serialization.internal.SectionId +import net.corda.serialization.internal.encodingNotPermittedFormat import net.corda.serialization.internal.model.TypeIdentifier import org.apache.qpid.proton.amqp.Binary import org.apache.qpid.proton.amqp.DescribedType @@ -118,7 +123,19 @@ class DeserializationInput constructor( @Throws(NotSerializableException::class) fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationContext): T = des { - val envelope = getEnvelope(bytes, context.encodingWhitelist) + /** + * The cache uses object identity rather than [ByteSequence.equals] and + * [ByteSequence.hashCode]. This is for speed: each [ByteSequence] object + * can potentially be large, and we are optimizing for the case when we + * know we will be deserializing the exact same objects multiple times. + * This also means that the cache MUST be short-lived, as otherwise it + * becomes a memory leak. + */ + @Suppress("unchecked_cast") + val envelope = (context.properties[AMQP_ENVELOPE_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(IdentityKey(bytes)) { key -> + getEnvelope(key.bytes, context.encodingWhitelist) + } ?: getEnvelope(bytes, context.encodingWhitelist) logger.trace { "deserialize blob scheme=\"${envelope.schema}\"" } @@ -219,3 +236,16 @@ class DeserializationInput constructor( else -> false } } + +/** + * We cannot use [ByteSequence.equals] and [ByteSequence.hashCode] because + * these consider the contents of the underlying [ByteArray] object. We + * only need the [ByteSequence]'s object identity for our use-case. + */ +private class IdentityKey(val bytes: ByteSequence) { + override fun hashCode() = System.identityHashCode(bytes) + + override fun equals(other: Any?): Boolean { + return (this === other) || (other is IdentityKey && bytes === other.bytes) + } +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/CertPathSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/CertPathSerializer.kt index 5921781ae8..6d7fc6d668 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/CertPathSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/CertPathSerializer.kt @@ -1,6 +1,8 @@ package net.corda.serialization.internal.amqp.custom import net.corda.core.KeepForDJVM +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY +import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.amqp.CustomSerializer import net.corda.serialization.internal.amqp.SerializerFactory import java.io.NotSerializableException @@ -28,7 +30,21 @@ class CertPathSerializer( } } + override fun fromProxy(proxy: CertPathProxy, context: SerializationContext): CertPath { + // This requires [CertPathProxy] to have correct + // implementations for [equals] and [hashCode]. + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(proxy, ::fromProxy) + ?: fromProxy(proxy) + } + @KeepForDJVM - @Suppress("ArrayInDataClass") - data class CertPathProxy(val type: String, val encoded: ByteArray) -} \ No newline at end of file + data class CertPathProxy(val type: String, val encoded: ByteArray) { + override fun hashCode() = (type.hashCode() * 31) + encoded.contentHashCode() + override fun equals(other: Any?): Boolean { + return (this === other) + || (other is CertPathProxy && (type == other.type && encoded.contentEquals(other.encoded))) + } + } +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt index 9663576780..ee4bceb09b 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt @@ -1,6 +1,7 @@ package net.corda.serialization.internal.amqp.custom import net.corda.core.crypto.Crypto +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.amqp.* import org.apache.qpid.proton.codec.Data @@ -34,6 +35,10 @@ object PublicKeySerializer context: SerializationContext ): PublicKey { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray - return Crypto.decodePublicKey(bits) + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + Crypto.decodePublicKey(key.bytes) + } ?: Crypto.decodePublicKey(bits) } -} \ No newline at end of file +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt index 965b8ed40f..0680031096 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.internal.amqp.custom +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.amqp.* import org.apache.qpid.proton.codec.Data @@ -28,6 +29,14 @@ object X509CRLSerializer override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): X509CRL { val bytes = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bytes)) { key -> + generateCRL(key.bytes) + } ?: generateCRL(bytes) + } + + private fun generateCRL(bytes: ByteArray): X509CRL { return CertificateFactory.getInstance("X.509").generateCRL(bytes.inputStream()) as X509CRL } } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt index 9e7a2854b4..f3dbd9438d 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt @@ -1,5 +1,6 @@ package net.corda.serialization.internal.amqp.custom +import net.corda.core.serialization.DESERIALIZATION_CACHE_PROPERTY import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.amqp.* import org.apache.qpid.proton.codec.Data @@ -28,6 +29,14 @@ object X509CertificateSerializer override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): X509Certificate { val bits = input.readObject(obj, schemas, ByteArray::class.java, context) as ByteArray + @Suppress("unchecked_cast") + return (context.properties[DESERIALIZATION_CACHE_PROPERTY] as? MutableMap) + ?.computeIfAbsent(CacheKey(bits)) { key -> + generateCertificate(key.bytes) + } ?: generateCertificate(bits) + } + + private fun generateCertificate(bits: ByteArray): X509Certificate { return CertificateFactory.getInstance("X.509").generateCertificate(bits.inputStream()) as X509Certificate } } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenter.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenter.kt index 49ef897639..940b242064 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenter.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenter.kt @@ -15,7 +15,6 @@ import org.objectweb.asm.Type import java.lang.Character.isJavaIdentifierPart import java.lang.Character.isJavaIdentifierStart import java.lang.reflect.Method -import java.util.* /** * Any object that implements this interface is expected to expose its own fields via the [get] method, exactly @@ -28,8 +27,23 @@ interface SimpleFieldAccess { } @DeleteForDJVM -class CarpenterClassLoader(parentClassLoader: ClassLoader = Thread.currentThread().contextClassLoader) : +class CarpenterClassLoader(private val parentClassLoader: ClassLoader = Thread.currentThread().contextClassLoader) : ClassLoader(parentClassLoader) { + @Throws(ClassNotFoundException::class) + override fun loadClass(name: String?, resolve: Boolean): Class<*>? { + return synchronized(getClassLoadingLock(name)) { + /** + * Search parent classloaders using lock-less [Class.forName], + * bypassing [parent] to avoid its [SecurityManager] overhead. + */ + (findLoadedClass(name) ?: Class.forName(name, false, parentClassLoader)).also { clazz -> + if (resolve) { + resolveClass(clazz) + } + } + } + } + fun load(name: String, bytes: ByteArray): Class<*> { return defineClass(name, bytes, 0, bytes.size) }