CORDA-2871: Refactor the contract verification code into a separate class,

and allow LedgerTransaction to choose different Verifier objects.
This commit is contained in:
Chris Rankin 2019-07-18 12:44:11 +01:00
parent 444881d536
commit 840e717ccf
2 changed files with 61 additions and 22 deletions

View File

@ -1,12 +1,15 @@
package net.corda.core.internal package net.corda.core.internal
import net.corda.core.DeleteForDJVM import net.corda.core.DeleteForDJVM
import net.corda.core.KeepForDJVM
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.contracts.TransactionVerificationException.TransactionContractConflictException import net.corda.core.contracts.TransactionVerificationException.TransactionContractConflictException
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.rules.StateContractValidationEnforcementRule import net.corda.core.internal.rules.StateContractValidationEnforcementRule
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import java.util.function.Function
@DeleteForDJVM @DeleteForDJVM
interface TransactionVerifierServiceInternal { interface TransactionVerifierServiceInternal {
@ -25,10 +28,8 @@ fun LedgerTransaction.prepareVerify(extraAttachments: List<Attachment>) = this.i
/** /**
* Because we create a separate [LedgerTransaction] onto which we need to perform verification, it becomes important we don't verify the * Because we create a separate [LedgerTransaction] onto which we need to perform verification, it becomes important we don't verify the
* wrong object instance. This class helps avoid that. * wrong object instance. This class helps avoid that.
*
* @param inputVersions A map linking each contract class name to the advertised version of the JAR that defines it. Used for downgrade protection.
*/ */
class Verifier(val ltx: LedgerTransaction, private val transactionClassLoader: ClassLoader) { open class Verifier(val ltx: LedgerTransaction, protected open val transactionClassLoader: ClassLoader) {
private val inputStates: List<TransactionState<*>> = ltx.inputs.map { it.state } private val inputStates: List<TransactionState<*>> = ltx.inputs.map { it.state }
private val allStates: List<TransactionState<*>> = inputStates + ltx.references.map { it.state } + ltx.outputs private val allStates: List<TransactionState<*>> = inputStates + ltx.references.map { it.state } + ltx.outputs
@ -140,7 +141,7 @@ class Verifier(val ltx: LedgerTransaction, private val transactionClassLoader: C
.withIndex() .withIndex()
.filter { it.value.encumbrance != null } .filter { it.value.encumbrance != null }
.map { Pair(it.index, it.value.encumbrance!!) } .map { Pair(it.index, it.value.encumbrance!!) }
if (!statesAndEncumbrance.isEmpty()) { if (statesAndEncumbrance.isNotEmpty()) {
checkBidirectionalOutputEncumbrances(statesAndEncumbrance) checkBidirectionalOutputEncumbrances(statesAndEncumbrance)
checkNotariesOutputEncumbrance(statesAndEncumbrance) checkNotariesOutputEncumbrance(statesAndEncumbrance)
} }
@ -349,22 +350,42 @@ class Verifier(val ltx: LedgerTransaction, private val transactionClassLoader: C
* *
* Note: Reference states are not verified. * Note: Reference states are not verified.
*/ */
private fun verifyContracts() { open fun verifyContracts() {
try {
// Loads the contract class from the transactionClassLoader. ContractVerifier(transactionClassLoader).apply(ltx)
fun contractClassFor(className: ContractClassName) = try { } catch (e: TransactionVerificationException.ContractRejection) {
transactionClassLoader.loadClass(className).asSubclass(Contract::class.java) logger.error("Error validating transaction ${ltx.id}.", e.cause)
} catch (e: Exception) { throw e
throw TransactionVerificationException.ContractCreationError(ltx.id, className, e) }
}
} }
val contractClasses: Map<ContractClassName, Class<out Contract>> = (inputStates + ltx.outputs) /**
.map { it.contract } * Verify all of the contracts on the given [LedgerTransaction].
.toSet() */
.map { contract -> contract to contractClassFor(contract) } @KeepForDJVM
.toMap() class ContractVerifier(private val transactionClassLoader: ClassLoader) : Function<LedgerTransaction, Unit> {
// This constructor is used inside the DJVM's sandbox.
@Suppress("unused")
constructor() : this(ClassLoader.getSystemClassLoader())
val contractInstances: List<Contract> = contractClasses.map { (contractClassName, contractClass) -> // Loads the contract class from the transactionClassLoader.
private fun createContractClass(id: SecureHash, contractClassName: ContractClassName): Class<out Contract> {
return try {
Class.forName(contractClassName, false, transactionClassLoader).asSubclass(Contract::class.java)
} catch (e: Exception) {
throw TransactionVerificationException.ContractCreationError(id, contractClassName, e)
}
}
override fun apply(ltx: LedgerTransaction) {
val contractClassNames = (ltx.inputs.map(StateAndRef<ContractState>::state) + ltx.outputs)
.map(TransactionState<*>::contract)
.toSet()
contractClassNames.associateBy(
{ it }, { createContractClass(ltx.id, it) }
).map { (contractClassName, contractClass) ->
try { try {
/** /**
* This function must execute with the DJVM's sandbox, which does not * This function must execute with the DJVM's sandbox, which does not
@ -377,13 +398,10 @@ class Verifier(val ltx: LedgerTransaction, private val transactionClassLoader: C
} catch (e: Exception) { } catch (e: Exception) {
throw TransactionVerificationException.ContractCreationError(ltx.id, contractClassName, e) throw TransactionVerificationException.ContractCreationError(ltx.id, contractClassName, e)
} }
} }.forEach { contract ->
contractInstances.forEach { contract ->
try { try {
contract.verify(ltx) contract.verify(ltx)
} catch (e: Exception) { } catch (e: Exception) {
logger.error("Error validating transaction ${ltx.id}.", e)
throw TransactionVerificationException.ContractRejection(ltx.id, contract, e) throw TransactionVerificationException.ContractRejection(ltx.id, contract, e)
} }
} }

View File

@ -75,6 +75,7 @@ private constructor(
private var serializedInputs: List<SerializedStateAndRef>? = null private var serializedInputs: List<SerializedStateAndRef>? = null
private var serializedReferences: List<SerializedStateAndRef>? = null private var serializedReferences: List<SerializedStateAndRef>? = null
private var isAttachmentTrusted: (Attachment) -> Boolean = { it.isUploaderTrusted() } private var isAttachmentTrusted: (Attachment) -> Boolean = { it.isUploaderTrusted() }
private var verifierFactory: (LedgerTransaction, ClassLoader) -> Verifier = ::Verifier
init { init {
if (timeWindow != null) check(notary != null) { "Transactions with time-windows must be notarised" } if (timeWindow != null) check(notary != null) { "Transactions with time-windows must be notarised" }
@ -151,10 +152,30 @@ private constructor(
// Create a copy of the outer LedgerTransaction which deserializes all fields inside the [transactionClassLoader]. // Create a copy of the outer LedgerTransaction which deserializes all fields inside the [transactionClassLoader].
// Only the copy will be used for verification, and the outer shell will be discarded. // Only the copy will be used for verification, and the outer shell will be discarded.
// This artifice is required to preserve backwards compatibility. // This artifice is required to preserve backwards compatibility.
Verifier(createLtxForVerification(), transactionClassLoader) verifierFactory(createLtxForVerification(), transactionClassLoader)
} }
} }
/**
* We need a way to customise transaction verification inside the
* Node without changing either the wire format or any public APIs.
*/
@CordaInternal
fun specialise(alternateVerifier: (LedgerTransaction, ClassLoader) -> Verifier): LedgerTransaction = LedgerTransaction(
inputs = inputs,
outputs = outputs,
commands = commands,
attachments = attachments,
id = id,
notary = notary,
timeWindow = timeWindow,
privacySalt = privacySalt,
networkParameters = networkParameters,
references = references
).also { ltx ->
ltx.verifierFactory = alternateVerifier
}
// Read network parameters with backwards compatibility goo. // Read network parameters with backwards compatibility goo.
private fun getParamsWithGoo(): NetworkParameters { private fun getParamsWithGoo(): NetworkParameters {
var params = networkParameters var params = networkParameters