diff --git a/src/main/kotlin/contracts/DummyContract.kt b/src/main/kotlin/contracts/DummyContract.kt index 45c0705f47..52956d71c5 100644 --- a/src/main/kotlin/contracts/DummyContract.kt +++ b/src/main/kotlin/contracts/DummyContract.kt @@ -10,14 +10,14 @@ package contracts import core.Contract import core.ContractState -import core.crypto.SecureHash import core.TransactionForVerification +import core.crypto.SecureHash // The dummy contract doesn't do anything useful. It exists for testing purposes. val DUMMY_PROGRAM_ID = SecureHash.sha256("dummy") -object DummyContract : Contract { +class DummyContract : Contract { class State : ContractState { override val programRef: SecureHash = DUMMY_PROGRAM_ID } diff --git a/src/main/kotlin/core/Structures.kt b/src/main/kotlin/core/Structures.kt index 2195f0564c..498e4ea4ef 100644 --- a/src/main/kotlin/core/Structures.kt +++ b/src/main/kotlin/core/Structures.kt @@ -134,3 +134,16 @@ interface Contract { */ val legalContractReference: SecureHash } + +/** A contract factory knows how to lazily load and instantiate contract objects. */ +interface ContractFactory { + /** + * Loads, instantiates and returns a contract object from its class bytecodes, given the hash of that bytecode. + * + * @throws UnknownContractException if the hash doesn't map to any known contract. + * @throws ClassCastException if the hash mapped to a contract, but it was not of type T + */ + operator fun get(hash: SecureHash): T +} + +class UnknownContractException : Exception() \ No newline at end of file diff --git a/src/main/kotlin/core/TransactionVerification.kt b/src/main/kotlin/core/TransactionVerification.kt index 4ba13566d0..69e5f4079a 100644 --- a/src/main/kotlin/core/TransactionVerification.kt +++ b/src/main/kotlin/core/TransactionVerification.kt @@ -27,7 +27,7 @@ class TransactionGroup(val transactions: Set, val nonVerified /** * Verifies the group and returns the set of resolved transactions. */ - fun verify(programMap: Map): Set { + fun verify(programMap: ContractFactory): Set { // Check that every input can be resolved to an output. // Check that no output is referenced by more than one input. // Cycles should be impossible due to the use of hashes as pointers. @@ -73,12 +73,12 @@ data class TransactionForVerification(val inStates: List, * @throws IllegalStateException if a state refers to an unknown contract. */ @Throws(TransactionVerificationException::class, IllegalStateException::class) - fun verify(programMap: Map) { + fun verify(programMap: ContractFactory) { // For each input and output state, locate the program to run. Then execute the verification function. If any // throws an exception, the entire transaction is invalid. val programHashes = (inStates.map { it.programRef } + outStates.map { it.programRef }).toSet() for (hash in programHashes) { - val program = programMap[hash] ?: throw IllegalStateException("Unknown program hash $hash") + val program: Contract = programMap[hash] try { program.verify(this) } catch(e: Throwable) { diff --git a/src/test/kotlin/contracts/CommercialPaperTests.kt b/src/test/kotlin/contracts/CommercialPaperTests.kt index e8b212553d..40b0df4fe7 100644 --- a/src/test/kotlin/contracts/CommercialPaperTests.kt +++ b/src/test/kotlin/contracts/CommercialPaperTests.kt @@ -216,11 +216,11 @@ class CommercialPaperTestsGeneric { val validRedemption = makeRedeemTX(TEST_TX_TIME + 31.days) val e = assertFailsWith(TransactionVerificationException::class) { - TransactionGroup(setOf(issueTX, moveTX, tooEarlyRedemption), setOf(corpWalletTX, alicesWalletTX)).verify(TEST_PROGRAM_MAP) + TransactionGroup(setOf(issueTX, moveTX, tooEarlyRedemption), setOf(corpWalletTX, alicesWalletTX)).verify(MockContractFactory) } assertTrue(e.cause!!.message!!.contains("paper must have matured")) - TransactionGroup(setOf(issueTX, moveTX, validRedemption), setOf(corpWalletTX, alicesWalletTX)).verify(TEST_PROGRAM_MAP) + TransactionGroup(setOf(issueTX, moveTX, validRedemption), setOf(corpWalletTX, alicesWalletTX)).verify(MockContractFactory) } // Generate a trade lifecycle with various parameters. diff --git a/src/test/kotlin/contracts/CrowdFundTests.kt b/src/test/kotlin/contracts/CrowdFundTests.kt index bf8a4528ef..d8292537b6 100644 --- a/src/test/kotlin/contracts/CrowdFundTests.kt +++ b/src/test/kotlin/contracts/CrowdFundTests.kt @@ -157,11 +157,11 @@ class CrowdFundTests { val validClose = makeFundedTX(TEST_TX_TIME + 8.days) val e = assertFailsWith(TransactionVerificationException::class) { - TransactionGroup(setOf(registerTX, pledgeTX, tooEarlyClose), setOf(miniCorpWalletTx, aliceWalletTX)).verify(TEST_PROGRAM_MAP) + TransactionGroup(setOf(registerTX, pledgeTX, tooEarlyClose), setOf(miniCorpWalletTx, aliceWalletTX)).verify(MockContractFactory) } assertTrue(e.cause!!.message!!.contains("the closing date has past")) // This verification passes - TransactionGroup(setOf(registerTX, pledgeTX, validClose), setOf(aliceWalletTX)).verify(TEST_PROGRAM_MAP) + TransactionGroup(setOf(registerTX, pledgeTX, validClose), setOf(aliceWalletTX)).verify(MockContractFactory) } } \ No newline at end of file diff --git a/src/test/kotlin/core/MockServices.kt b/src/test/kotlin/core/MockServices.kt index a65b6e104f..18854cd2dd 100644 --- a/src/test/kotlin/core/MockServices.kt +++ b/src/test/kotlin/core/MockServices.kt @@ -9,6 +9,7 @@ package core import core.crypto.DigitalSignature +import core.crypto.SecureHash import core.crypto.generateKeyPair import core.crypto.signWithECDSA import core.messaging.MessagingService @@ -18,6 +19,7 @@ import core.node.TimestampingError import core.serialization.SerializedBytes import core.serialization.deserialize import core.testutils.TEST_KEYS_TO_CORP_MAP +import core.testutils.TEST_PROGRAM_MAP import core.testutils.TEST_TX_TIME import java.security.KeyPair import java.security.PrivateKey @@ -77,6 +79,14 @@ class MockStorageService : StorageService { } } +object MockContractFactory : ContractFactory { + override operator fun get(hash: SecureHash): T { + val clazz = TEST_PROGRAM_MAP[hash] ?: throw UnknownContractException() + @Suppress("UNCHECKED_CAST") + return clazz.newInstance() as T + } +} + class MockServices( val wallet: WalletService? = null, val keyManagement: KeyManagementService? = null, diff --git a/src/test/kotlin/core/testutils/TestUtils.kt b/src/test/kotlin/core/testutils/TestUtils.kt index adc53c694c..1dccf00ef3 100644 --- a/src/test/kotlin/core/testutils/TestUtils.kt +++ b/src/test/kotlin/core/testutils/TestUtils.kt @@ -52,13 +52,13 @@ val TEST_KEYS_TO_CORP_MAP: Map = mapOf( val TEST_TX_TIME = Instant.parse("2015-04-17T12:00:00.00Z") // In a real system this would be a persistent map of hash to bytecode and we'd instantiate the object as needed inside -// a sandbox. For now we just instantiate right at the start of the program. -val TEST_PROGRAM_MAP: Map = mapOf( - CASH_PROGRAM_ID to Cash(), - CP_PROGRAM_ID to CommercialPaper(), - JavaCommercialPaper.JCP_PROGRAM_ID to JavaCommercialPaper(), - CROWDFUND_PROGRAM_ID to CrowdFund(), - DUMMY_PROGRAM_ID to DummyContract +// a sandbox. For unit tests we just have a hard-coded list. +val TEST_PROGRAM_MAP: Map> = mapOf( + CASH_PROGRAM_ID to Cash::class.java, + CP_PROGRAM_ID to CommercialPaper::class.java, + JavaCommercialPaper.JCP_PROGRAM_ID to JavaCommercialPaper::class.java, + CROWDFUND_PROGRAM_ID to CrowdFund::class.java, + DUMMY_PROGRAM_ID to DummyContract::class.java ) //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -135,7 +135,7 @@ open class TransactionForTest : AbstractTransactionForTest() { protected fun run(time: Instant) { val cmds = commandsToAuthenticatedObjects() val tx = TransactionForVerification(inStates, outStates.map { it.state }, cmds, SecureHash.randomSHA256()) - tx.verify(TEST_PROGRAM_MAP) + tx.verify(MockContractFactory) } fun accepts(time: Instant = TEST_TX_TIME) = run(time) @@ -297,7 +297,7 @@ class TransactionGroupDSL(private val stateType: Class) { fun verify() { val group = toTransactionGroup() try { - group.verify(TEST_PROGRAM_MAP) + group.verify(MockContractFactory) } catch (e: TransactionVerificationException) { // Let the developer know the index of the transaction that failed. val ltx: LedgerTransaction = txns.find { it.hash == e.tx.origHash }!!