From 9b36df607e885dbdbc4c1b6f9fee6e77156d211d Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Mon, 4 Jul 2016 17:13:55 +0100 Subject: [PATCH] core: Remove binding of State type in test dsl --- .../core/testing/LedgerDslInterpreter.kt | 44 ++-- .../com/r3corda/core/testing/TestDsl.kt | 188 +++++++++++------- .../com/r3corda/core/testing/TestUtils.kt | 14 +- .../core/testing/TransactionDslInterpreter.kt | 33 +-- 4 files changed, 163 insertions(+), 116 deletions(-) diff --git a/core/src/main/kotlin/com/r3corda/core/testing/LedgerDslInterpreter.kt b/core/src/main/kotlin/com/r3corda/core/testing/LedgerDslInterpreter.kt index c8b7668e18..b39b39c19c 100644 --- a/core/src/main/kotlin/com/r3corda/core/testing/LedgerDslInterpreter.kt +++ b/core/src/main/kotlin/com/r3corda/core/testing/LedgerDslInterpreter.kt @@ -1,18 +1,21 @@ package com.r3corda.core.testing -import com.r3corda.core.contracts.Attachment -import com.r3corda.core.contracts.ContractState +import com.r3corda.core.contracts.* import com.r3corda.core.crypto.SecureHash -import com.r3corda.core.node.services.IdentityService -import com.r3corda.core.node.services.StorageService -import com.r3corda.core.node.services.testing.MockStorageService -interface LedgerDslInterpreter> { - fun transaction(dsl: TransactionDsl.() -> Unit): Unit - fun nonVerifiedTransaction(dsl: TransactionDsl.() -> Unit): Unit - fun tweak(dsl: LedgerDsl>.() -> Unit) +interface OutputStateLookup { + fun retrieveOutputStateAndRef(clazz: Class, label: String): StateAndRef +} + + + +interface LedgerDslInterpreter : + OutputStateLookup { + fun transaction(transactionLabel: String?, dsl: TransactionDsl.() -> Unit): WireTransaction + fun nonVerifiedTransaction(transactionLabel: String?, dsl: TransactionDsl.() -> Unit): WireTransaction + fun tweak(dsl: LedgerDsl>.() -> Unit) fun attachment(attachment: Attachment): SecureHash - fun _verifies(identityService: IdentityService, storageService: StorageService) + fun verifies() } /** @@ -21,15 +24,18 @@ interface LedgerDslInterpreter, - out LedgerInterpreter: LedgerDslInterpreter + out TransactionInterpreter: TransactionDslInterpreter, + out LedgerInterpreter: LedgerDslInterpreter > (val interpreter: LedgerInterpreter) - : LedgerDslInterpreter> by interpreter { + : LedgerDslInterpreter by interpreter { - @JvmOverloads - fun verifies( - identityService: IdentityService = MOCK_IDENTITY_SERVICE, - storageService: StorageService = MockStorageService() - ) = _verifies(identityService, storageService) + fun transaction(dsl: TransactionDsl.() -> Unit) = transaction(null, dsl) + fun nonVerifiedTransaction(dsl: TransactionDsl.() -> Unit) = + nonVerifiedTransaction(null, dsl) + + inline fun String.outputStateAndRef(): StateAndRef = + retrieveOutputStateAndRef(State::class.java, this) + inline fun String.output(): TransactionState = + outputStateAndRef().state + fun String.outputRef(): StateRef = outputStateAndRef().ref } diff --git a/core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt b/core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt index f508ba58fb..e1b6e70f0a 100644 --- a/core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt +++ b/core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt @@ -1,23 +1,34 @@ package com.r3corda.core.testing import com.r3corda.core.contracts.* +import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.SecureHash +import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.node.services.IdentityService import com.r3corda.core.node.services.StorageService +import com.r3corda.core.node.services.testing.MockStorageService +import com.r3corda.core.serialization.serialize +import java.security.KeyPair import java.security.PublicKey import java.util.* -inline fun ledger( - dsl: LedgerDsl, TestLedgerDslInterpreter>.() -> Unit) = - dsl(LedgerDsl(TestLedgerDslInterpreter.create())) +inline fun ledger( + identityService: IdentityService = MOCK_IDENTITY_SERVICE, + storageService: StorageService = MockStorageService(), + dsl: LedgerDsl.() -> Unit +): LedgerDsl { + val ledgerDsl = LedgerDsl(TestLedgerDslInterpreter(identityService, storageService)) + dsl(ledgerDsl) + return ledgerDsl +} @Deprecated( message = "ledger doesn't nest, use tweak", replaceWith = ReplaceWith("tweak"), level = DeprecationLevel.ERROR) -fun TransactionDslInterpreter.ledger( - dsl: LedgerDsl, TestLedgerDslInterpreter>.() -> Unit) { +fun TransactionDslInterpreter.ledger( + dsl: LedgerDsl.() -> Unit) { this.toString() dsl.toString() } @@ -26,8 +37,8 @@ fun TransactionDslInterpreter.ledger( message = "ledger doesn't nest, use tweak", replaceWith = ReplaceWith("tweak"), level = DeprecationLevel.ERROR) -fun LedgerDslInterpreter>.ledger( - dsl: LedgerDsl, TestLedgerDslInterpreter>.() -> Unit) { +fun LedgerDslInterpreter.ledger( + dsl: LedgerDsl.() -> Unit) { this.toString() dsl.toString() } @@ -36,17 +47,16 @@ fun LedgerDslInterpreter( - private val ledgerInterpreter: TestLedgerDslInterpreter, +data class TestTransactionDslInterpreter( + private val ledgerInterpreter: TestLedgerDslInterpreter, private val inputStateRefs: ArrayList = arrayListOf(), internal val outputStates: ArrayList = arrayListOf(), private val attachments: ArrayList = arrayListOf(), private val commands: ArrayList = arrayListOf(), private val signers: LinkedHashSet = LinkedHashSet(), private val transactionType: TransactionType = TransactionType.General() -) : TransactionDslInterpreter { - - private fun copy(): TestTransactionDslInterpreter = +) : TransactionDslInterpreter, OutputStateLookup { + private fun copy(): TestTransactionDslInterpreter = TestTransactionDslInterpreter( ledgerInterpreter = ledgerInterpreter.copy(), inputStateRefs = ArrayList(inputStateRefs), @@ -68,18 +78,18 @@ data class TestTransactionDslInterpreter( ) override fun input(stateLabel: String) { - val notary = stateLabel.output.notary.owningKey - signers.add(notary) - inputStateRefs.add(stateLabel.outputRef) + val stateAndRef = retrieveOutputStateAndRef(ContractState::class.java, stateLabel) + signers.add(stateAndRef.state.notary.owningKey) + inputStateRefs.add(stateAndRef.ref) } override fun input(stateRef: StateRef) { - val notary = ledgerInterpreter.resolveStateRef(stateRef).notary + val notary = ledgerInterpreter.resolveStateRef(stateRef).notary signers.add(notary.owningKey) inputStateRefs.add(stateRef) } - override fun output(label: String?, notary: Party, contractState: State) { + override fun output(label: String?, notary: Party, contractState: ContractState) { outputStates.add(LabeledOutput(label, TransactionState(contractState, notary))) } @@ -92,14 +102,14 @@ data class TestTransactionDslInterpreter( commands.add(Command(commandData, signers)) } - override fun _verifies(identityService: IdentityService) { - val resolvedTransaction = ledgerInterpreter.resolveWireTransaction(toWireTransaction(), identityService) + override fun verifies() { + val resolvedTransaction = ledgerInterpreter.resolveWireTransaction(toWireTransaction()) resolvedTransaction.verify() } - override fun failsWith(expectedMessage: String?, identityService: IdentityService) { + override fun failsWith(expectedMessage: String?) { val exceptionThrown = try { - _verifies(identityService) + verifies() false } catch (exception: Exception) { if (expectedMessage != null) { @@ -110,7 +120,7 @@ data class TestTransactionDslInterpreter( ) } else if (!exceptionMessage.toLowerCase().contains(expectedMessage.toLowerCase())) { throw AssertionError( - "Expected exception containing '$expectedMessage' but raised exception was '$exceptionMessage'" + "Expected exception containing '$expectedMessage' but raised exception was '$exception'" ) } } @@ -122,58 +132,60 @@ data class TestTransactionDslInterpreter( } } - override fun tweak(dsl: TransactionDsl>.() -> Unit) = + override fun tweak(dsl: TransactionDsl.() -> Unit) = dsl(TransactionDsl(copy())) - override fun retrieveOutputStateAndRef(label: String): StateAndRef? = - ledgerInterpreter.labelToOutputStateAndRefs[label] + override fun retrieveOutputStateAndRef(clazz: Class, label: String) = ledgerInterpreter.retrieveOutputStateAndRef(clazz, label) } -class AttachmentResolutionException(val attachmentId: SecureHash) : +class AttachmentResolutionException(attachmentId: SecureHash) : Exception("Attachment with id $attachmentId not found") -data class TestLedgerDslInterpreter private constructor ( - internal val stateClazz: Class, - internal val labelToOutputStateAndRefs: HashMap> = HashMap(), +data class TestLedgerDslInterpreter private constructor ( + private val identityService: IdentityService, + private val storageService: StorageService, + internal val labelToOutputStateAndRefs: HashMap> = HashMap(), private val transactionWithLocations: HashMap = HashMap(), - private val nonVerifiedTransactionWithLocations: HashMap = HashMap(), - private val attachments: HashMap = HashMap() -) : LedgerDslInterpreter> { + private val nonVerifiedTransactionWithLocations: HashMap = HashMap() +) : LedgerDslInterpreter { // We specify [labelToOutputStateAndRefs] just so that Kotlin picks the primary constructor instead of cycling - constructor(stateClazz: Class) : this(stateClazz, labelToOutputStateAndRefs = HashMap()) + constructor(identityService: IdentityService, storageService: StorageService) : this( + identityService, storageService, labelToOutputStateAndRefs = HashMap() + ) companion object { - /** - * Convenience factory to avoid having to pass in the Class - */ - inline fun create() = TestLedgerDslInterpreter(State::class.java) - private fun getCallerLocation(offset: Int): String { val stackTraceElement = Thread.currentThread().stackTrace[3 + offset] return stackTraceElement.toString() } } - private data class WireTransactionWithLocation(val transaction: WireTransaction, val location: String) - private class VerifiesFailed(transactionLocation: String, cause: Throwable) : + internal data class WireTransactionWithLocation( + val label: String?, + val transaction: WireTransaction, + val location: String + ) + class VerifiesFailed(transactionLocation: String, cause: Throwable) : Exception("Transaction defined at ($transactionLocation) didn't verify: $cause", cause) + class TypeMismatch(requested: Class<*>, actual: Class<*>) : + Exception("Actual type $actual is not a subtype of requested type $requested") - internal fun copy(): TestLedgerDslInterpreter = + internal fun copy(): TestLedgerDslInterpreter = TestLedgerDslInterpreter( - stateClazz = stateClazz, + identityService, + storageService, labelToOutputStateAndRefs = HashMap(labelToOutputStateAndRefs), transactionWithLocations = HashMap(transactionWithLocations), - nonVerifiedTransactionWithLocations = HashMap(nonVerifiedTransactionWithLocations), - attachments = HashMap(attachments) + nonVerifiedTransactionWithLocations = HashMap(nonVerifiedTransactionWithLocations) ) - fun resolveWireTransaction(wireTransaction: WireTransaction, identityService: IdentityService): TransactionForVerification { + internal fun resolveWireTransaction(wireTransaction: WireTransaction): TransactionForVerification { return wireTransaction.run { val authenticatedCommands = commands.map { AuthenticatedObject(it.signers, it.signers.mapNotNull { identityService.partyFromKey(it) }, it.value) } - val resolvedInputStates = inputs.map { resolveStateRef(it) } + val resolvedInputStates = inputs.map { resolveStateRef(it) } val resolvedAttachments = attachments.map { resolveAttachment(it) } TransactionForVerification( inputs = resolvedInputStates, @@ -188,30 +200,30 @@ data class TestLedgerDslInterpreter private constructor ( } } - fun resolveStateRef(stateRef: StateRef): TransactionState { + internal inline fun resolveStateRef(stateRef: StateRef): TransactionState { val transactionWithLocation = transactionWithLocations[stateRef.txhash] ?: nonVerifiedTransactionWithLocations[stateRef.txhash] ?: throw TransactionResolutionException(stateRef.txhash) val output = transactionWithLocation.transaction.outputs[stateRef.index] - return if (stateClazz.isInstance(output.data)) @Suppress("UNCHECKED_CAST") { + return if (State::class.java.isAssignableFrom(output.data.javaClass)) @Suppress("UNCHECKED_CAST") { output as TransactionState } else { - throw IllegalArgumentException("Referenced state is of another type than requested") + throw TypeMismatch(requested = State::class.java, actual = output.data.javaClass) } } - fun resolveAttachment(attachmentId: SecureHash): Attachment = - attachments[attachmentId] ?: throw AttachmentResolutionException(attachmentId) + internal fun resolveAttachment(attachmentId: SecureHash): Attachment = + storageService.attachments.openAttachment(attachmentId) ?: throw AttachmentResolutionException(attachmentId) - private fun interpretTransactionDsl(dsl: TransactionDsl>.() -> Unit): - TestTransactionDslInterpreter { + private fun interpretTransactionDsl(dsl: TransactionDsl.() -> Unit): + TestTransactionDslInterpreter { val transactionInterpreter = TestTransactionDslInterpreter(this) dsl(TransactionDsl(transactionInterpreter)) return transactionInterpreter } - private fun toTransactionGroup(identityService: IdentityService, storageService: StorageService): TransactionGroup { + fun toTransactionGroup(): TransactionGroup { val ledgerTransactions = transactionWithLocations.map { it.value.transaction.toLedgerTransaction(identityService, storageService.attachments) } @@ -221,10 +233,23 @@ data class TestLedgerDslInterpreter private constructor ( return TransactionGroup(ledgerTransactions.toSet(), nonVerifiedLedgerTransactions.toSet()) } + fun transactionName(transactionHash: SecureHash): String? { + val transactionWithLocation = transactionWithLocations[transactionHash] + return if (transactionWithLocation != null) { + transactionWithLocation.label ?: "TX[${transactionWithLocation.location}]" + } else { + null + } + } + + fun outputToLabel(state: ContractState): String? = + labelToOutputStateAndRefs.filter { it.value.state.data == state }.keys.firstOrNull() + private fun recordTransactionWithTransactionMap( - dsl: TransactionDsl>.() -> Unit, + transactionLabel: String?, + dsl: TransactionDsl.() -> Unit, transactionMap: HashMap = HashMap() - ) { + ): WireTransaction { val transactionLocation = getCallerLocation(3) val transactionInterpreter = interpretTransactionDsl(dsl) // Create the WireTransaction @@ -237,38 +262,67 @@ data class TestLedgerDslInterpreter private constructor ( } transactionMap[wireTransaction.serialized.hash] = - WireTransactionWithLocation(wireTransaction, transactionLocation) + WireTransactionWithLocation(transactionLabel, wireTransaction, transactionLocation) + return wireTransaction } - override fun transaction(dsl: TransactionDsl>.() -> Unit) = - recordTransactionWithTransactionMap(dsl, transactionWithLocations) + override fun transaction(transactionLabel: String?, dsl: TransactionDsl.() -> Unit) = + recordTransactionWithTransactionMap(transactionLabel, dsl, transactionWithLocations) - override fun nonVerifiedTransaction(dsl: TransactionDsl>.() -> Unit) = - recordTransactionWithTransactionMap(dsl, nonVerifiedTransactionWithLocations) + override fun nonVerifiedTransaction(transactionLabel: String?, dsl: TransactionDsl.() -> Unit) = + recordTransactionWithTransactionMap(transactionLabel, dsl, nonVerifiedTransactionWithLocations) override fun tweak( - dsl: LedgerDsl, - LedgerDslInterpreter>>.() -> Unit) = + dsl: LedgerDsl>.() -> Unit) = dsl(LedgerDsl(copy())) override fun attachment(attachment: Attachment): SecureHash { - attachments[attachment.id] = attachment + storageService.attachments.importAttachment(attachment.open()) return attachment.id } - override fun _verifies(identityService: IdentityService, storageService: StorageService) { - val transactionGroup = toTransactionGroup(identityService, storageService) + override fun verifies() { + val transactionGroup = toTransactionGroup() try { transactionGroup.verify() } catch (exception: TransactionVerificationException) { throw VerifiesFailed(transactionWithLocations[exception.tx.origHash]?.location ?: "", exception) } } + + override fun retrieveOutputStateAndRef(clazz: Class, label: String): StateAndRef { + val stateAndRef = labelToOutputStateAndRefs[label] + if (stateAndRef == null) { + throw IllegalArgumentException("State with label '$label' was not found") + } else if (!clazz.isAssignableFrom(stateAndRef.state.data.javaClass)) { + throw TypeMismatch(requested = clazz, actual = stateAndRef.state.data.javaClass) + } else { + @Suppress("UNCHECKED_CAST") + return stateAndRef as StateAndRef + } + } +} + +fun signAll(transactionsToSign: List, vararg extraKeys: KeyPair): List { + return transactionsToSign.map { wtx -> + val allPubKeys = wtx.signers.toMutableSet() + val bits = wtx.serialize() + require(bits == wtx.serialized) + val signatures = ArrayList() + for (key in ALL_TEST_KEYS + extraKeys) { + if (allPubKeys.contains(key.public)) { + signatures += key.signWithECDSA(bits) + allPubKeys -= key.public + } + } + SignedTransaction(bits, signatures) + } } fun main(args: Array) { - ledger { + ledger { nonVerifiedTransaction { output("hello") { DummyLinearState() } } diff --git a/core/src/main/kotlin/com/r3corda/core/testing/TestUtils.kt b/core/src/main/kotlin/com/r3corda/core/testing/TestUtils.kt index d5148803e6..9266fe1069 100644 --- a/core/src/main/kotlin/com/r3corda/core/testing/TestUtils.kt +++ b/core/src/main/kotlin/com/r3corda/core/testing/TestUtils.kt @@ -6,7 +6,6 @@ import com.google.common.base.Throwables import com.google.common.net.HostAndPort import com.r3corda.core.contracts.* import com.r3corda.core.crypto.* -import com.r3corda.core.node.services.IdentityService import com.r3corda.core.node.services.testing.MockIdentityService import com.r3corda.core.node.services.testing.MockStorageService import com.r3corda.core.seconds @@ -219,22 +218,21 @@ open class TransactionForTest : AbstractTransactionForTest() { } fun input(s: ContractState) = input { s } - protected fun runCommandsAndVerify(time: Instant) { + protected fun runCommandsAndVerify() { val cmds = commandsToAuthenticatedObjects() val tx = TransactionForVerification(inStates, outStates.map { it.state }, emptyList(), cmds, SecureHash.Companion.randomSHA256(), signers.toList(), type) tx.verify() } - @JvmOverloads - fun accepts(time: Instant = TEST_TX_TIME): LastLineShouldTestForAcceptOrFailure { - runCommandsAndVerify(time) + fun accepts(): LastLineShouldTestForAcceptOrFailure { + runCommandsAndVerify() return LastLineShouldTestForAcceptOrFailure.Token } @JvmOverloads - fun rejects(withMessage: String? = null, time: Instant = TEST_TX_TIME): LastLineShouldTestForAcceptOrFailure { + fun rejects(withMessage: String? = null): LastLineShouldTestForAcceptOrFailure { val r = try { - runCommandsAndVerify(time) + runCommandsAndVerify() false } catch (e: Exception) { val m = e.message @@ -300,7 +298,7 @@ open class TransactionForTest : AbstractTransactionForTest() { } } -class TransactionGroupDSL(private val stateType: Class) { +class TransactionGroupDSL(private val stateType: Class) { open inner class WireTransactionDSL : AbstractTransactionForTest() { private val inStates = ArrayList() diff --git a/core/src/main/kotlin/com/r3corda/core/testing/TransactionDslInterpreter.kt b/core/src/main/kotlin/com/r3corda/core/testing/TransactionDslInterpreter.kt index ac8eba0adf..2822f830e1 100644 --- a/core/src/main/kotlin/com/r3corda/core/testing/TransactionDslInterpreter.kt +++ b/core/src/main/kotlin/com/r3corda/core/testing/TransactionDslInterpreter.kt @@ -15,50 +15,39 @@ import java.time.Instant * TODO: Move the [State] binding to the primitives' level to allow different State types, use reflection to check types * dynamically, come up with a substitute for primitives relying on early bind */ -interface TransactionDslInterpreter { +interface TransactionDslInterpreter : OutputStateLookup { fun input(stateLabel: String) fun input(stateRef: StateRef) - fun output(label: String?, notary: Party, contractState: State) + fun output(label: String?, notary: Party, contractState: ContractState) fun attachment(attachmentId: SecureHash) fun _command(signers: List, commandData: CommandData) - fun _verifies(identityService: IdentityService) - fun failsWith(expectedMessage: String?, identityService: IdentityService) - fun tweak(dsl: TransactionDsl>.() -> Unit) - fun retrieveOutputStateAndRef(label: String): StateAndRef? - - val String.outputStateAndRef: StateAndRef - get() = retrieveOutputStateAndRef(this) ?: throw IllegalArgumentException("State with label '$this' was not found") - val String.output: TransactionState - get() = outputStateAndRef.state - val String.outputRef: StateRef - get() = outputStateAndRef.ref + fun verifies() + fun failsWith(expectedMessage: String?) + fun tweak(dsl: TransactionDsl.() -> Unit) } class TransactionDsl< - State: ContractState, - out TransactionInterpreter: TransactionDslInterpreter + out TransactionInterpreter: TransactionDslInterpreter > (val interpreter: TransactionInterpreter) - : TransactionDslInterpreter by interpreter { + : TransactionDslInterpreter by interpreter { // Convenience functions - fun output(label: String? = null, notary: Party = DUMMY_NOTARY, contractStateClosure: () -> State) = + fun output(label: String? = null, notary: Party = DUMMY_NOTARY, contractStateClosure: () -> ContractState) = output(label, notary, contractStateClosure()) @JvmOverloads - fun output(label: String? = null, contractState: State) = output(label, DUMMY_NOTARY, contractState) + fun output(label: String? = null, contractState: ContractState) = output(label, DUMMY_NOTARY, contractState) fun command(vararg signers: PublicKey, commandDataClosure: () -> CommandData) = _command(listOf(*signers), commandDataClosure()) fun command(signer: PublicKey, commandData: CommandData) = _command(listOf(signer), commandData) - fun verifies(identityService: IdentityService = MOCK_IDENTITY_SERVICE) = _verifies(identityService) - @JvmOverloads fun timestamp(time: Instant, notary: PublicKey = DUMMY_NOTARY.owningKey) = timestamp(TimestampCommand(time, 30.seconds), notary) @JvmOverloads fun timestamp(data: TimestampCommand, notary: PublicKey = DUMMY_NOTARY.owningKey) = command(notary, data) - fun fails(identityService: IdentityService = MOCK_IDENTITY_SERVICE) = failsWith(null, identityService) - infix fun `fails with`(msg: String) = failsWith(msg, MOCK_IDENTITY_SERVICE) + fun fails() = failsWith(null) + infix fun `fails with`(msg: String) = failsWith(msg) }