CORDA-2871: Refactor the contract verification code into a separate class,

and allow LedgerTransaction to choose different Verifier objects.
This commit is contained in:
Chris Rankin 2019-07-18 12:44:11 +01:00
parent 444881d536
commit 840e717ccf
2 changed files with 61 additions and 22 deletions

View File

@ -1,12 +1,15 @@
package net.corda.core.internal
import net.corda.core.DeleteForDJVM
import net.corda.core.KeepForDJVM
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.*
import net.corda.core.contracts.TransactionVerificationException.TransactionContractConflictException
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.rules.StateContractValidationEnforcementRule
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.contextLogger
import java.util.function.Function
@DeleteForDJVM
interface TransactionVerifierServiceInternal {
@ -25,10 +28,8 @@ fun LedgerTransaction.prepareVerify(extraAttachments: List<Attachment>) = this.i
/**
* Because we create a separate [LedgerTransaction] onto which we need to perform verification, it becomes important we don't verify the
* wrong object instance. This class helps avoid that.
*
* @param inputVersions A map linking each contract class name to the advertised version of the JAR that defines it. Used for downgrade protection.
*/
class Verifier(val ltx: LedgerTransaction, private val transactionClassLoader: ClassLoader) {
open class Verifier(val ltx: LedgerTransaction, protected open val transactionClassLoader: ClassLoader) {
private val inputStates: List<TransactionState<*>> = ltx.inputs.map { it.state }
private val allStates: List<TransactionState<*>> = inputStates + ltx.references.map { it.state } + ltx.outputs
@ -140,7 +141,7 @@ class Verifier(val ltx: LedgerTransaction, private val transactionClassLoader: C
.withIndex()
.filter { it.value.encumbrance != null }
.map { Pair(it.index, it.value.encumbrance!!) }
if (!statesAndEncumbrance.isEmpty()) {
if (statesAndEncumbrance.isNotEmpty()) {
checkBidirectionalOutputEncumbrances(statesAndEncumbrance)
checkNotariesOutputEncumbrance(statesAndEncumbrance)
}
@ -349,22 +350,42 @@ class Verifier(val ltx: LedgerTransaction, private val transactionClassLoader: C
*
* Note: Reference states are not verified.
*/
private fun verifyContracts() {
open fun verifyContracts() {
try {
ContractVerifier(transactionClassLoader).apply(ltx)
} catch (e: TransactionVerificationException.ContractRejection) {
logger.error("Error validating transaction ${ltx.id}.", e.cause)
throw e
}
}
}
/**
* Verify all of the contracts on the given [LedgerTransaction].
*/
@KeepForDJVM
class ContractVerifier(private val transactionClassLoader: ClassLoader) : Function<LedgerTransaction, Unit> {
// This constructor is used inside the DJVM's sandbox.
@Suppress("unused")
constructor() : this(ClassLoader.getSystemClassLoader())
// Loads the contract class from the transactionClassLoader.
fun contractClassFor(className: ContractClassName) = try {
transactionClassLoader.loadClass(className).asSubclass(Contract::class.java)
private fun createContractClass(id: SecureHash, contractClassName: ContractClassName): Class<out Contract> {
return try {
Class.forName(contractClassName, false, transactionClassLoader).asSubclass(Contract::class.java)
} catch (e: Exception) {
throw TransactionVerificationException.ContractCreationError(ltx.id, className, e)
throw TransactionVerificationException.ContractCreationError(id, contractClassName, e)
}
}
val contractClasses: Map<ContractClassName, Class<out Contract>> = (inputStates + ltx.outputs)
.map { it.contract }
override fun apply(ltx: LedgerTransaction) {
val contractClassNames = (ltx.inputs.map(StateAndRef<ContractState>::state) + ltx.outputs)
.map(TransactionState<*>::contract)
.toSet()
.map { contract -> contract to contractClassFor(contract) }
.toMap()
val contractInstances: List<Contract> = contractClasses.map { (contractClassName, contractClass) ->
contractClassNames.associateBy(
{ it }, { createContractClass(ltx.id, it) }
).map { (contractClassName, contractClass) ->
try {
/**
* This function must execute with the DJVM's sandbox, which does not
@ -377,13 +398,10 @@ class Verifier(val ltx: LedgerTransaction, private val transactionClassLoader: C
} catch (e: Exception) {
throw TransactionVerificationException.ContractCreationError(ltx.id, contractClassName, e)
}
}
contractInstances.forEach { contract ->
}.forEach { contract ->
try {
contract.verify(ltx)
} catch (e: Exception) {
logger.error("Error validating transaction ${ltx.id}.", e)
throw TransactionVerificationException.ContractRejection(ltx.id, contract, e)
}
}

View File

@ -75,6 +75,7 @@ private constructor(
private var serializedInputs: List<SerializedStateAndRef>? = null
private var serializedReferences: List<SerializedStateAndRef>? = null
private var isAttachmentTrusted: (Attachment) -> Boolean = { it.isUploaderTrusted() }
private var verifierFactory: (LedgerTransaction, ClassLoader) -> Verifier = ::Verifier
init {
if (timeWindow != null) check(notary != null) { "Transactions with time-windows must be notarised" }
@ -151,10 +152,30 @@ private constructor(
// Create a copy of the outer LedgerTransaction which deserializes all fields inside the [transactionClassLoader].
// Only the copy will be used for verification, and the outer shell will be discarded.
// This artifice is required to preserve backwards compatibility.
Verifier(createLtxForVerification(), transactionClassLoader)
verifierFactory(createLtxForVerification(), transactionClassLoader)
}
}
/**
* We need a way to customise transaction verification inside the
* Node without changing either the wire format or any public APIs.
*/
@CordaInternal
fun specialise(alternateVerifier: (LedgerTransaction, ClassLoader) -> Verifier): LedgerTransaction = LedgerTransaction(
inputs = inputs,
outputs = outputs,
commands = commands,
attachments = attachments,
id = id,
notary = notary,
timeWindow = timeWindow,
privacySalt = privacySalt,
networkParameters = networkParameters,
references = references
).also { ltx ->
ltx.verifierFactory = alternateVerifier
}
// Read network parameters with backwards compatibility goo.
private fun getParamsWithGoo(): NetworkParameters {
var params = networkParameters