From bf4272b64ad31d651e36e0fea1a23f9bc425fdb2 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Fri, 1 Jul 2016 16:09:58 +0100 Subject: [PATCH] core: transaction/ledger DSL interfaces and implementation for tests --- .../core/testing/LedgerDslInterpreter.kt | 35 +++ .../com/r3corda/core/testing/TestDsl.kt | 295 ++++++++++++++++++ .../core/testing/TransactionDslInterpreter.kt | 64 ++++ 3 files changed, 394 insertions(+) create mode 100644 core/src/main/kotlin/com/r3corda/core/testing/LedgerDslInterpreter.kt create mode 100644 core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt create mode 100644 core/src/main/kotlin/com/r3corda/core/testing/TransactionDslInterpreter.kt diff --git a/core/src/main/kotlin/com/r3corda/core/testing/LedgerDslInterpreter.kt b/core/src/main/kotlin/com/r3corda/core/testing/LedgerDslInterpreter.kt new file mode 100644 index 0000000000..c8b7668e18 --- /dev/null +++ b/core/src/main/kotlin/com/r3corda/core/testing/LedgerDslInterpreter.kt @@ -0,0 +1,35 @@ +package com.r3corda.core.testing + +import com.r3corda.core.contracts.Attachment +import com.r3corda.core.contracts.ContractState +import com.r3corda.core.crypto.SecureHash +import com.r3corda.core.node.services.IdentityService +import com.r3corda.core.node.services.StorageService +import com.r3corda.core.node.services.testing.MockStorageService + +interface LedgerDslInterpreter> { + fun transaction(dsl: TransactionDsl.() -> Unit): Unit + fun nonVerifiedTransaction(dsl: TransactionDsl.() -> Unit): Unit + fun tweak(dsl: LedgerDsl>.() -> Unit) + fun attachment(attachment: Attachment): SecureHash + fun _verifies(identityService: IdentityService, storageService: StorageService) +} + +/** + * This is the class the top-level primitives deal with. It delegates all other primitives to the contained interpreter. + * This way we have a decoupling of the DSL "AST" and the interpretation(s) of it. Note how the delegation forces + * covariance of the TransactionInterpreter parameter + */ +class LedgerDsl< + State: ContractState, + out TransactionInterpreter: TransactionDslInterpreter, + out LedgerInterpreter: LedgerDslInterpreter + > (val interpreter: LedgerInterpreter) + : LedgerDslInterpreter> by interpreter { + + @JvmOverloads + fun verifies( + identityService: IdentityService = MOCK_IDENTITY_SERVICE, + storageService: StorageService = MockStorageService() + ) = _verifies(identityService, storageService) +} diff --git a/core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt b/core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt new file mode 100644 index 0000000000..f508ba58fb --- /dev/null +++ b/core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt @@ -0,0 +1,295 @@ +package com.r3corda.core.testing + +import com.r3corda.core.contracts.* +import com.r3corda.core.crypto.Party +import com.r3corda.core.crypto.SecureHash +import com.r3corda.core.node.services.IdentityService +import com.r3corda.core.node.services.StorageService +import java.security.PublicKey +import java.util.* + +inline fun ledger( + dsl: LedgerDsl, TestLedgerDslInterpreter>.() -> Unit) = + dsl(LedgerDsl(TestLedgerDslInterpreter.create())) + +@Deprecated( + message = "ledger doesn't nest, use tweak", + replaceWith = ReplaceWith("tweak"), + level = DeprecationLevel.ERROR) +fun TransactionDslInterpreter.ledger( + dsl: LedgerDsl, TestLedgerDslInterpreter>.() -> Unit) { + this.toString() + dsl.toString() +} + +@Deprecated( + message = "ledger doesn't nest, use tweak", + replaceWith = ReplaceWith("tweak"), + level = DeprecationLevel.ERROR) +fun LedgerDslInterpreter>.ledger( + dsl: LedgerDsl, TestLedgerDslInterpreter>.() -> Unit) { + this.toString() + dsl.toString() +} + +/** + * This interpreter builds a transaction, and [TransactionDsl.verifies] that the resolved transaction is correct. Note + * that transactions corresponding to input states are not verified. Use [LedgerDsl.verifies] for that. + */ +data class TestTransactionDslInterpreter( + private val ledgerInterpreter: TestLedgerDslInterpreter, + private val inputStateRefs: ArrayList = arrayListOf(), + internal val outputStates: ArrayList = arrayListOf(), + private val attachments: ArrayList = arrayListOf(), + private val commands: ArrayList = arrayListOf(), + private val signers: LinkedHashSet = LinkedHashSet(), + private val transactionType: TransactionType = TransactionType.General() +) : TransactionDslInterpreter { + + private fun copy(): TestTransactionDslInterpreter = + TestTransactionDslInterpreter( + ledgerInterpreter = ledgerInterpreter.copy(), + inputStateRefs = ArrayList(inputStateRefs), + outputStates = ArrayList(outputStates), + attachments = ArrayList(attachments), + commands = ArrayList(commands), + signers = LinkedHashSet(signers), + transactionType = transactionType + ) + + internal fun toWireTransaction(): WireTransaction = + WireTransaction( + inputs = inputStateRefs, + outputs = outputStates.map { it.state }, + attachments = attachments, + commands = commands, + signers = signers.toList(), + type = transactionType + ) + + override fun input(stateLabel: String) { + val notary = stateLabel.output.notary.owningKey + signers.add(notary) + inputStateRefs.add(stateLabel.outputRef) + } + + override fun input(stateRef: StateRef) { + val notary = ledgerInterpreter.resolveStateRef(stateRef).notary + signers.add(notary.owningKey) + inputStateRefs.add(stateRef) + } + + override fun output(label: String?, notary: Party, contractState: State) { + outputStates.add(LabeledOutput(label, TransactionState(contractState, notary))) + } + + override fun attachment(attachmentId: SecureHash) { + attachments.add(attachmentId) + } + + override fun _command(signers: List, commandData: CommandData) { + this.signers.addAll(signers) + commands.add(Command(commandData, signers)) + } + + override fun _verifies(identityService: IdentityService) { + val resolvedTransaction = ledgerInterpreter.resolveWireTransaction(toWireTransaction(), identityService) + resolvedTransaction.verify() + } + + override fun failsWith(expectedMessage: String?, identityService: IdentityService) { + val exceptionThrown = try { + _verifies(identityService) + false + } catch (exception: Exception) { + if (expectedMessage != null) { + val exceptionMessage = exception.message + if (exceptionMessage == null) { + throw AssertionError( + "Expected exception containing '$expectedMessage' but raised exception had no message" + ) + } else if (!exceptionMessage.toLowerCase().contains(expectedMessage.toLowerCase())) { + throw AssertionError( + "Expected exception containing '$expectedMessage' but raised exception was '$exceptionMessage'" + ) + } + } + true + } + + if (!exceptionThrown) { + throw AssertionError("Expected exception but didn't get one") + } + } + + override fun tweak(dsl: TransactionDsl>.() -> Unit) = + dsl(TransactionDsl(copy())) + + override fun retrieveOutputStateAndRef(label: String): StateAndRef? = + ledgerInterpreter.labelToOutputStateAndRefs[label] +} + +class AttachmentResolutionException(val attachmentId: SecureHash) : + Exception("Attachment with id $attachmentId not found") + +data class TestLedgerDslInterpreter private constructor ( + internal val stateClazz: Class, + internal val labelToOutputStateAndRefs: HashMap> = HashMap(), + private val transactionWithLocations: HashMap = HashMap(), + private val nonVerifiedTransactionWithLocations: HashMap = HashMap(), + private val attachments: HashMap = HashMap() +) : LedgerDslInterpreter> { + + // We specify [labelToOutputStateAndRefs] just so that Kotlin picks the primary constructor instead of cycling + constructor(stateClazz: Class) : this(stateClazz, labelToOutputStateAndRefs = HashMap()) + + companion object { + /** + * Convenience factory to avoid having to pass in the Class + */ + inline fun create() = TestLedgerDslInterpreter(State::class.java) + + private fun getCallerLocation(offset: Int): String { + val stackTraceElement = Thread.currentThread().stackTrace[3 + offset] + return stackTraceElement.toString() + } + } + + private data class WireTransactionWithLocation(val transaction: WireTransaction, val location: String) + private class VerifiesFailed(transactionLocation: String, cause: Throwable) : + Exception("Transaction defined at ($transactionLocation) didn't verify: $cause", cause) + + internal fun copy(): TestLedgerDslInterpreter = + TestLedgerDslInterpreter( + stateClazz = stateClazz, + labelToOutputStateAndRefs = HashMap(labelToOutputStateAndRefs), + transactionWithLocations = HashMap(transactionWithLocations), + nonVerifiedTransactionWithLocations = HashMap(nonVerifiedTransactionWithLocations), + attachments = HashMap(attachments) + ) + + fun resolveWireTransaction(wireTransaction: WireTransaction, identityService: IdentityService): TransactionForVerification { + return wireTransaction.run { + val authenticatedCommands = commands.map { + AuthenticatedObject(it.signers, it.signers.mapNotNull { identityService.partyFromKey(it) }, it.value) + } + val resolvedInputStates = inputs.map { resolveStateRef(it) } + val resolvedAttachments = attachments.map { resolveAttachment(it) } + TransactionForVerification( + inputs = resolvedInputStates, + outputs = outputs, + commands = authenticatedCommands, + origHash = wireTransaction.serialized.hash, + attachments = resolvedAttachments, + signers = signers.toList(), + type = type + ) + + } + } + + fun resolveStateRef(stateRef: StateRef): TransactionState { + val transactionWithLocation = + transactionWithLocations[stateRef.txhash] ?: + nonVerifiedTransactionWithLocations[stateRef.txhash] ?: + throw TransactionResolutionException(stateRef.txhash) + val output = transactionWithLocation.transaction.outputs[stateRef.index] + return if (stateClazz.isInstance(output.data)) @Suppress("UNCHECKED_CAST") { + output as TransactionState + } else { + throw IllegalArgumentException("Referenced state is of another type than requested") + } + } + + fun resolveAttachment(attachmentId: SecureHash): Attachment = + attachments[attachmentId] ?: throw AttachmentResolutionException(attachmentId) + + private fun interpretTransactionDsl(dsl: TransactionDsl>.() -> Unit): + TestTransactionDslInterpreter { + val transactionInterpreter = TestTransactionDslInterpreter(this) + dsl(TransactionDsl(transactionInterpreter)) + return transactionInterpreter + } + + private fun toTransactionGroup(identityService: IdentityService, storageService: StorageService): TransactionGroup { + val ledgerTransactions = transactionWithLocations.map { + it.value.transaction.toLedgerTransaction(identityService, storageService.attachments) + } + val nonVerifiedLedgerTransactions = nonVerifiedTransactionWithLocations.map { + it.value.transaction.toLedgerTransaction(identityService, storageService.attachments) + } + return TransactionGroup(ledgerTransactions.toSet(), nonVerifiedLedgerTransactions.toSet()) + } + + private fun recordTransactionWithTransactionMap( + dsl: TransactionDsl>.() -> Unit, + transactionMap: HashMap = HashMap() + ) { + val transactionLocation = getCallerLocation(3) + val transactionInterpreter = interpretTransactionDsl(dsl) + // Create the WireTransaction + val wireTransaction = transactionInterpreter.toWireTransaction() + // Record the output states + transactionInterpreter.outputStates.forEachIndexed { index, labeledOutput -> + if (labeledOutput.label != null) { + labelToOutputStateAndRefs[labeledOutput.label] = wireTransaction.outRef(index) + } + } + + transactionMap[wireTransaction.serialized.hash] = + WireTransactionWithLocation(wireTransaction, transactionLocation) + + } + + override fun transaction(dsl: TransactionDsl>.() -> Unit) = + recordTransactionWithTransactionMap(dsl, transactionWithLocations) + + override fun nonVerifiedTransaction(dsl: TransactionDsl>.() -> Unit) = + recordTransactionWithTransactionMap(dsl, nonVerifiedTransactionWithLocations) + + override fun tweak( + dsl: LedgerDsl, + LedgerDslInterpreter>>.() -> Unit) = + dsl(LedgerDsl(copy())) + + override fun attachment(attachment: Attachment): SecureHash { + attachments[attachment.id] = attachment + return attachment.id + } + + override fun _verifies(identityService: IdentityService, storageService: StorageService) { + val transactionGroup = toTransactionGroup(identityService, storageService) + try { + transactionGroup.verify() + } catch (exception: TransactionVerificationException) { + throw VerifiesFailed(transactionWithLocations[exception.tx.origHash]?.location ?: "", exception) + } + } +} + +fun main(args: Array) { + ledger { + nonVerifiedTransaction { + output("hello") { DummyLinearState() } + } + + transaction { + input("hello") + tweak { + timestamp(TEST_TX_TIME, MEGA_CORP_PUBKEY) + fails() + } + } + + tweak { + + transaction { + input("hello") + timestamp(TEST_TX_TIME, MEGA_CORP_PUBKEY) + fails() + } + } + + verifies() + } +} diff --git a/core/src/main/kotlin/com/r3corda/core/testing/TransactionDslInterpreter.kt b/core/src/main/kotlin/com/r3corda/core/testing/TransactionDslInterpreter.kt new file mode 100644 index 0000000000..ac8eba0adf --- /dev/null +++ b/core/src/main/kotlin/com/r3corda/core/testing/TransactionDslInterpreter.kt @@ -0,0 +1,64 @@ +package com.r3corda.core.testing + +import com.r3corda.core.contracts.* +import com.r3corda.core.crypto.Party +import com.r3corda.core.crypto.SecureHash +import com.r3corda.core.node.services.IdentityService +import com.r3corda.core.seconds +import java.security.PublicKey +import java.time.Instant + + +/** + * [State] is bound at the top level. This allows the definition of e.g. [String.output], however it also means that we + * cannot mix different types of states in the same transaction. + * TODO: Move the [State] binding to the primitives' level to allow different State types, use reflection to check types + * dynamically, come up with a substitute for primitives relying on early bind + */ +interface TransactionDslInterpreter { + fun input(stateLabel: String) + fun input(stateRef: StateRef) + fun output(label: String?, notary: Party, contractState: State) + fun attachment(attachmentId: SecureHash) + fun _command(signers: List, commandData: CommandData) + fun _verifies(identityService: IdentityService) + fun failsWith(expectedMessage: String?, identityService: IdentityService) + fun tweak(dsl: TransactionDsl>.() -> Unit) + fun retrieveOutputStateAndRef(label: String): StateAndRef? + + val String.outputStateAndRef: StateAndRef + get() = retrieveOutputStateAndRef(this) ?: throw IllegalArgumentException("State with label '$this' was not found") + val String.output: TransactionState + get() = outputStateAndRef.state + val String.outputRef: StateRef + get() = outputStateAndRef.ref +} + + +class TransactionDsl< + State: ContractState, + out TransactionInterpreter: TransactionDslInterpreter + > (val interpreter: TransactionInterpreter) + : TransactionDslInterpreter by interpreter { + + // Convenience functions + fun output(label: String? = null, notary: Party = DUMMY_NOTARY, contractStateClosure: () -> State) = + output(label, notary, contractStateClosure()) + @JvmOverloads + fun output(label: String? = null, contractState: State) = output(label, DUMMY_NOTARY, contractState) + + fun command(vararg signers: PublicKey, commandDataClosure: () -> CommandData) = + _command(listOf(*signers), commandDataClosure()) + fun command(signer: PublicKey, commandData: CommandData) = _command(listOf(signer), commandData) + + fun verifies(identityService: IdentityService = MOCK_IDENTITY_SERVICE) = _verifies(identityService) + + @JvmOverloads + fun timestamp(time: Instant, notary: PublicKey = DUMMY_NOTARY.owningKey) = + timestamp(TimestampCommand(time, 30.seconds), notary) + @JvmOverloads + fun timestamp(data: TimestampCommand, notary: PublicKey = DUMMY_NOTARY.owningKey) = command(notary, data) + + fun fails(identityService: IdentityService = MOCK_IDENTITY_SERVICE) = failsWith(null, identityService) + infix fun `fails with`(msg: String) = failsWith(msg, MOCK_IDENTITY_SERVICE) +}