Introduce ContractFactory to replace Map<SecureHash, Contract>. It allows for lazy loading of contracts.

This commit is contained in:
Mike Hearn 2016-02-09 19:36:52 +01:00
parent 2ccbd5db3e
commit 31964f8695
7 changed files with 41 additions and 18 deletions

View File

@ -10,14 +10,14 @@ package contracts
import core.Contract import core.Contract
import core.ContractState import core.ContractState
import core.crypto.SecureHash
import core.TransactionForVerification import core.TransactionForVerification
import core.crypto.SecureHash
// The dummy contract doesn't do anything useful. It exists for testing purposes. // The dummy contract doesn't do anything useful. It exists for testing purposes.
val DUMMY_PROGRAM_ID = SecureHash.sha256("dummy") val DUMMY_PROGRAM_ID = SecureHash.sha256("dummy")
object DummyContract : Contract { class DummyContract : Contract {
class State : ContractState { class State : ContractState {
override val programRef: SecureHash = DUMMY_PROGRAM_ID override val programRef: SecureHash = DUMMY_PROGRAM_ID
} }

View File

@ -134,3 +134,16 @@ interface Contract {
*/ */
val legalContractReference: SecureHash val legalContractReference: SecureHash
} }
/** A contract factory knows how to lazily load and instantiate contract objects. */
interface ContractFactory {
/**
* Loads, instantiates and returns a contract object from its class bytecodes, given the hash of that bytecode.
*
* @throws UnknownContractException if the hash doesn't map to any known contract.
* @throws ClassCastException if the hash mapped to a contract, but it was not of type T
*/
operator fun <T : Contract> get(hash: SecureHash): T
}
class UnknownContractException : Exception()

View File

@ -27,7 +27,7 @@ class TransactionGroup(val transactions: Set<LedgerTransaction>, val nonVerified
/** /**
* Verifies the group and returns the set of resolved transactions. * Verifies the group and returns the set of resolved transactions.
*/ */
fun verify(programMap: Map<SecureHash, Contract>): Set<TransactionForVerification> { fun verify(programMap: ContractFactory): Set<TransactionForVerification> {
// Check that every input can be resolved to an output. // Check that every input can be resolved to an output.
// Check that no output is referenced by more than one input. // Check that no output is referenced by more than one input.
// Cycles should be impossible due to the use of hashes as pointers. // Cycles should be impossible due to the use of hashes as pointers.
@ -73,12 +73,12 @@ data class TransactionForVerification(val inStates: List<ContractState>,
* @throws IllegalStateException if a state refers to an unknown contract. * @throws IllegalStateException if a state refers to an unknown contract.
*/ */
@Throws(TransactionVerificationException::class, IllegalStateException::class) @Throws(TransactionVerificationException::class, IllegalStateException::class)
fun verify(programMap: Map<SecureHash, Contract>) { fun verify(programMap: ContractFactory) {
// For each input and output state, locate the program to run. Then execute the verification function. If any // For each input and output state, locate the program to run. Then execute the verification function. If any
// throws an exception, the entire transaction is invalid. // throws an exception, the entire transaction is invalid.
val programHashes = (inStates.map { it.programRef } + outStates.map { it.programRef }).toSet() val programHashes = (inStates.map { it.programRef } + outStates.map { it.programRef }).toSet()
for (hash in programHashes) { for (hash in programHashes) {
val program = programMap[hash] ?: throw IllegalStateException("Unknown program hash $hash") val program: Contract = programMap[hash]
try { try {
program.verify(this) program.verify(this)
} catch(e: Throwable) { } catch(e: Throwable) {

View File

@ -216,11 +216,11 @@ class CommercialPaperTestsGeneric {
val validRedemption = makeRedeemTX(TEST_TX_TIME + 31.days) val validRedemption = makeRedeemTX(TEST_TX_TIME + 31.days)
val e = assertFailsWith(TransactionVerificationException::class) { val e = assertFailsWith(TransactionVerificationException::class) {
TransactionGroup(setOf(issueTX, moveTX, tooEarlyRedemption), setOf(corpWalletTX, alicesWalletTX)).verify(TEST_PROGRAM_MAP) TransactionGroup(setOf(issueTX, moveTX, tooEarlyRedemption), setOf(corpWalletTX, alicesWalletTX)).verify(MockContractFactory)
} }
assertTrue(e.cause!!.message!!.contains("paper must have matured")) assertTrue(e.cause!!.message!!.contains("paper must have matured"))
TransactionGroup(setOf(issueTX, moveTX, validRedemption), setOf(corpWalletTX, alicesWalletTX)).verify(TEST_PROGRAM_MAP) TransactionGroup(setOf(issueTX, moveTX, validRedemption), setOf(corpWalletTX, alicesWalletTX)).verify(MockContractFactory)
} }
// Generate a trade lifecycle with various parameters. // Generate a trade lifecycle with various parameters.

View File

@ -157,11 +157,11 @@ class CrowdFundTests {
val validClose = makeFundedTX(TEST_TX_TIME + 8.days) val validClose = makeFundedTX(TEST_TX_TIME + 8.days)
val e = assertFailsWith(TransactionVerificationException::class) { val e = assertFailsWith(TransactionVerificationException::class) {
TransactionGroup(setOf(registerTX, pledgeTX, tooEarlyClose), setOf(miniCorpWalletTx, aliceWalletTX)).verify(TEST_PROGRAM_MAP) TransactionGroup(setOf(registerTX, pledgeTX, tooEarlyClose), setOf(miniCorpWalletTx, aliceWalletTX)).verify(MockContractFactory)
} }
assertTrue(e.cause!!.message!!.contains("the closing date has past")) assertTrue(e.cause!!.message!!.contains("the closing date has past"))
// This verification passes // This verification passes
TransactionGroup(setOf(registerTX, pledgeTX, validClose), setOf(aliceWalletTX)).verify(TEST_PROGRAM_MAP) TransactionGroup(setOf(registerTX, pledgeTX, validClose), setOf(aliceWalletTX)).verify(MockContractFactory)
} }
} }

View File

@ -9,6 +9,7 @@
package core package core
import core.crypto.DigitalSignature import core.crypto.DigitalSignature
import core.crypto.SecureHash
import core.crypto.generateKeyPair import core.crypto.generateKeyPair
import core.crypto.signWithECDSA import core.crypto.signWithECDSA
import core.messaging.MessagingService import core.messaging.MessagingService
@ -18,6 +19,7 @@ import core.node.TimestampingError
import core.serialization.SerializedBytes import core.serialization.SerializedBytes
import core.serialization.deserialize import core.serialization.deserialize
import core.testutils.TEST_KEYS_TO_CORP_MAP import core.testutils.TEST_KEYS_TO_CORP_MAP
import core.testutils.TEST_PROGRAM_MAP
import core.testutils.TEST_TX_TIME import core.testutils.TEST_TX_TIME
import java.security.KeyPair import java.security.KeyPair
import java.security.PrivateKey import java.security.PrivateKey
@ -77,6 +79,14 @@ class MockStorageService : StorageService {
} }
} }
object MockContractFactory : ContractFactory {
override operator fun <T : Contract> get(hash: SecureHash): T {
val clazz = TEST_PROGRAM_MAP[hash] ?: throw UnknownContractException()
@Suppress("UNCHECKED_CAST")
return clazz.newInstance() as T
}
}
class MockServices( class MockServices(
val wallet: WalletService? = null, val wallet: WalletService? = null,
val keyManagement: KeyManagementService? = null, val keyManagement: KeyManagementService? = null,

View File

@ -52,13 +52,13 @@ val TEST_KEYS_TO_CORP_MAP: Map<PublicKey, Party> = mapOf(
val TEST_TX_TIME = Instant.parse("2015-04-17T12:00:00.00Z") val TEST_TX_TIME = Instant.parse("2015-04-17T12:00:00.00Z")
// In a real system this would be a persistent map of hash to bytecode and we'd instantiate the object as needed inside // In a real system this would be a persistent map of hash to bytecode and we'd instantiate the object as needed inside
// a sandbox. For now we just instantiate right at the start of the program. // a sandbox. For unit tests we just have a hard-coded list.
val TEST_PROGRAM_MAP: Map<SecureHash, Contract> = mapOf( val TEST_PROGRAM_MAP: Map<SecureHash, Class<out Contract>> = mapOf(
CASH_PROGRAM_ID to Cash(), CASH_PROGRAM_ID to Cash::class.java,
CP_PROGRAM_ID to CommercialPaper(), CP_PROGRAM_ID to CommercialPaper::class.java,
JavaCommercialPaper.JCP_PROGRAM_ID to JavaCommercialPaper(), JavaCommercialPaper.JCP_PROGRAM_ID to JavaCommercialPaper::class.java,
CROWDFUND_PROGRAM_ID to CrowdFund(), CROWDFUND_PROGRAM_ID to CrowdFund::class.java,
DUMMY_PROGRAM_ID to DummyContract DUMMY_PROGRAM_ID to DummyContract::class.java
) )
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -135,7 +135,7 @@ open class TransactionForTest : AbstractTransactionForTest() {
protected fun run(time: Instant) { protected fun run(time: Instant) {
val cmds = commandsToAuthenticatedObjects() val cmds = commandsToAuthenticatedObjects()
val tx = TransactionForVerification(inStates, outStates.map { it.state }, cmds, SecureHash.randomSHA256()) val tx = TransactionForVerification(inStates, outStates.map { it.state }, cmds, SecureHash.randomSHA256())
tx.verify(TEST_PROGRAM_MAP) tx.verify(MockContractFactory)
} }
fun accepts(time: Instant = TEST_TX_TIME) = run(time) fun accepts(time: Instant = TEST_TX_TIME) = run(time)
@ -297,7 +297,7 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
fun verify() { fun verify() {
val group = toTransactionGroup() val group = toTransactionGroup()
try { try {
group.verify(TEST_PROGRAM_MAP) group.verify(MockContractFactory)
} catch (e: TransactionVerificationException) { } catch (e: TransactionVerificationException) {
// Let the developer know the index of the transaction that failed. // Let the developer know the index of the transaction that failed.
val ltx: LedgerTransaction = txns.find { it.hash == e.tx.origHash }!! val ltx: LedgerTransaction = txns.find { it.hash == e.tx.origHash }!!