core: Add LastLineShouldBeVerifiesOrFails, fix attachment primitive, Java interop

This commit is contained in:
Andras Slemmer 2016-07-05 12:05:19 +01:00
parent bdda3d239a
commit a27f195b4f
4 changed files with 134 additions and 73 deletions

View File

@ -2,19 +2,24 @@ package com.r3corda.core.testing
import com.r3corda.core.contracts.* import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import java.io.InputStream
interface OutputStateLookup { interface OutputStateLookup {
fun <State: ContractState> retrieveOutputStateAndRef(clazz: Class<State>, label: String): StateAndRef<State> fun <State: ContractState> retrieveOutputStateAndRef(clazz: Class<State>, label: String): StateAndRef<State>
} }
interface LedgerDslInterpreter<Return, out TransactionInterpreter: TransactionDslInterpreter<Return>> :
interface LedgerDslInterpreter<out TransactionInterpreter: TransactionDslInterpreter> :
OutputStateLookup { OutputStateLookup {
fun transaction(transactionLabel: String?, dsl: TransactionDsl<TransactionInterpreter>.() -> Unit): WireTransaction fun transaction(
fun nonVerifiedTransaction(transactionLabel: String?, dsl: TransactionDsl<TransactionInterpreter>.() -> Unit): WireTransaction transactionLabel: String?,
fun tweak(dsl: LedgerDsl<TransactionInterpreter, LedgerDslInterpreter<TransactionInterpreter>>.() -> Unit) dsl: TransactionDsl<Return, TransactionInterpreter>.() -> Return
fun attachment(attachment: Attachment): SecureHash ): WireTransaction
fun nonVerifiedTransaction(
transactionLabel: String?,
dsl: TransactionDsl<Return, TransactionInterpreter>.() -> Unit
): WireTransaction
fun tweak(dsl: LedgerDsl<Return, TransactionInterpreter, LedgerDslInterpreter<Return, TransactionInterpreter>>.() -> Unit)
fun attachment(attachment: InputStream): SecureHash
fun verifies() fun verifies()
} }
@ -24,13 +29,15 @@ interface LedgerDslInterpreter<out TransactionInterpreter: TransactionDslInterpr
* covariance of the TransactionInterpreter parameter * covariance of the TransactionInterpreter parameter
*/ */
class LedgerDsl< class LedgerDsl<
out TransactionInterpreter: TransactionDslInterpreter, Return,
out LedgerInterpreter: LedgerDslInterpreter<TransactionInterpreter> out TransactionInterpreter: TransactionDslInterpreter<Return>,
> (val interpreter: LedgerInterpreter) out LedgerInterpreter: LedgerDslInterpreter<Return, TransactionInterpreter>
: LedgerDslInterpreter<TransactionDslInterpreter> by interpreter { > (val interpreter: LedgerInterpreter
) : LedgerDslInterpreter<Return, TransactionDslInterpreter<Return>> by interpreter {
fun transaction(dsl: TransactionDsl<TransactionDslInterpreter>.() -> Unit) = transaction(null, dsl) fun transaction(dsl: TransactionDsl<Return, TransactionDslInterpreter<Return>>.() -> Return) =
fun nonVerifiedTransaction(dsl: TransactionDsl<TransactionDslInterpreter>.() -> Unit) = transaction(null, dsl)
fun nonVerifiedTransaction(dsl: TransactionDsl<Return, TransactionDslInterpreter<Return>>.() -> Unit) =
nonVerifiedTransaction(null, dsl) nonVerifiedTransaction(null, dsl)
inline fun <reified State: ContractState> String.outputStateAndRef(): StateAndRef<State> = inline fun <reified State: ContractState> String.outputStateAndRef(): StateAndRef<State> =
@ -38,12 +45,4 @@ class LedgerDsl<
inline fun <reified State: ContractState> String.output(): TransactionState<State> = inline fun <reified State: ContractState> String.output(): TransactionState<State> =
outputStateAndRef<State>().state outputStateAndRef<State>().state
fun String.outputRef(): StateRef = outputStateAndRef<ContractState>().ref fun String.outputRef(): StateRef = outputStateAndRef<ContractState>().ref
fun TransactionDslInterpreter.input(state: ContractState) {
val transaction = nonVerifiedTransaction {
output { state }
}
input(transaction.outRef<ContractState>(0).ref)
}
fun TransactionDslInterpreter.input(stateClosure: () -> ContractState) = input(stateClosure())
} }

View File

@ -9,26 +9,23 @@ import com.r3corda.core.node.services.IdentityService
import com.r3corda.core.node.services.StorageService import com.r3corda.core.node.services.StorageService
import com.r3corda.core.node.services.testing.MockStorageService import com.r3corda.core.node.services.testing.MockStorageService
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import java.io.InputStream
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.util.* import java.util.*
inline fun ledger( fun ledger(
identityService: IdentityService = MOCK_IDENTITY_SERVICE, identityService: IdentityService = MOCK_IDENTITY_SERVICE,
storageService: StorageService = MockStorageService(), storageService: StorageService = MockStorageService(),
dsl: LedgerDsl<TestTransactionDslInterpreter, TestLedgerDslInterpreter>.() -> Unit dsl: LedgerDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter, TestLedgerDslInterpreter>.() -> Unit
): LedgerDsl<TestTransactionDslInterpreter, TestLedgerDslInterpreter> { ) = JavaTestHelpers.ledger(identityService, storageService, dsl)
val ledgerDsl = LedgerDsl(TestLedgerDslInterpreter(identityService, storageService))
dsl(ledgerDsl)
return ledgerDsl
}
@Deprecated( @Deprecated(
message = "ledger doesn't nest, use tweak", message = "ledger doesn't nest, use tweak",
replaceWith = ReplaceWith("tweak"), replaceWith = ReplaceWith("tweak"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR)
fun TransactionDslInterpreter.ledger( fun TransactionDslInterpreter<LastLineShouldTestForVerifiesOrFails>.ledger(
dsl: LedgerDsl<TestTransactionDslInterpreter, TestLedgerDslInterpreter>.() -> Unit) { dsl: LedgerDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter, TestLedgerDslInterpreter>.() -> Unit) {
this.toString() this.toString()
dsl.toString() dsl.toString()
} }
@ -37,25 +34,34 @@ fun TransactionDslInterpreter.ledger(
message = "ledger doesn't nest, use tweak", message = "ledger doesn't nest, use tweak",
replaceWith = ReplaceWith("tweak"), replaceWith = ReplaceWith("tweak"),
level = DeprecationLevel.ERROR) level = DeprecationLevel.ERROR)
fun LedgerDslInterpreter<TransactionDslInterpreter>.ledger( fun LedgerDslInterpreter<LastLineShouldTestForVerifiesOrFails, TransactionDslInterpreter<LastLineShouldTestForVerifiesOrFails>>.ledger(
dsl: LedgerDsl<TestTransactionDslInterpreter, TestLedgerDslInterpreter>.() -> Unit) { dsl: LedgerDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter, TestLedgerDslInterpreter>.() -> Unit) {
this.toString() this.toString()
dsl.toString() dsl.toString()
} }
/** If you jumped here from a compiler error make sure the last line of your test tests for a transaction verify or fail
* This is a dummy type that can only be instantiated by functions in this module. This way we can ensure that all tests
* will have as the last line either an accept or a failure test. The name is deliberately long to help make sense of
* the triggered diagnostic
*/
sealed class LastLineShouldTestForVerifiesOrFails {
internal object Token: LastLineShouldTestForVerifiesOrFails()
}
/** /**
* This interpreter builds a transaction, and [TransactionDsl.verifies] that the resolved transaction is correct. Note * This interpreter builds a transaction, and [TransactionDsl.verifies] that the resolved transaction is correct. Note
* that transactions corresponding to input states are not verified. Use [LedgerDsl.verifies] for that. * that transactions corresponding to input states are not verified. Use [LedgerDsl.verifies] for that.
*/ */
data class TestTransactionDslInterpreter( data class TestTransactionDslInterpreter(
private val ledgerInterpreter: TestLedgerDslInterpreter, override val ledgerInterpreter: TestLedgerDslInterpreter,
private val inputStateRefs: ArrayList<StateRef> = arrayListOf(), private val inputStateRefs: ArrayList<StateRef> = arrayListOf(),
internal val outputStates: ArrayList<LabeledOutput> = arrayListOf(), internal val outputStates: ArrayList<LabeledOutput> = arrayListOf(),
private val attachments: ArrayList<SecureHash> = arrayListOf(), private val attachments: ArrayList<SecureHash> = arrayListOf(),
private val commands: ArrayList<Command> = arrayListOf(), private val commands: ArrayList<Command> = arrayListOf(),
private val signers: LinkedHashSet<PublicKey> = LinkedHashSet(), private val signers: LinkedHashSet<PublicKey> = LinkedHashSet(),
private val transactionType: TransactionType = TransactionType.General() private val transactionType: TransactionType = TransactionType.General()
) : TransactionDslInterpreter, OutputStateLookup { ) : TransactionDslInterpreter<LastLineShouldTestForVerifiesOrFails>, OutputStateLookup by ledgerInterpreter {
private fun copy(): TestTransactionDslInterpreter = private fun copy(): TestTransactionDslInterpreter =
TestTransactionDslInterpreter( TestTransactionDslInterpreter(
ledgerInterpreter = ledgerInterpreter, ledgerInterpreter = ledgerInterpreter,
@ -83,7 +89,7 @@ data class TestTransactionDslInterpreter(
inputStateRefs.add(stateRef) inputStateRefs.add(stateRef)
} }
override fun output(label: String?, notary: Party, contractState: ContractState) { override fun _output(label: String?, notary: Party, contractState: ContractState) {
outputStates.add(LabeledOutput(label, TransactionState(contractState, notary))) outputStates.add(LabeledOutput(label, TransactionState(contractState, notary)))
} }
@ -96,12 +102,13 @@ data class TestTransactionDslInterpreter(
commands.add(Command(commandData, signers)) commands.add(Command(commandData, signers))
} }
override fun verifies() { override fun verifies(): LastLineShouldTestForVerifiesOrFails {
val resolvedTransaction = ledgerInterpreter.resolveWireTransaction(toWireTransaction()) val resolvedTransaction = ledgerInterpreter.resolveWireTransaction(toWireTransaction())
resolvedTransaction.verify() resolvedTransaction.verify()
return LastLineShouldTestForVerifiesOrFails.Token
} }
override fun failsWith(expectedMessage: String?) { override fun failsWith(expectedMessage: String?): LastLineShouldTestForVerifiesOrFails {
val exceptionThrown = try { val exceptionThrown = try {
this.verifies() this.verifies()
false false
@ -124,12 +131,16 @@ data class TestTransactionDslInterpreter(
if (!exceptionThrown) { if (!exceptionThrown) {
throw AssertionError("Expected exception but didn't get one") throw AssertionError("Expected exception but didn't get one")
} }
return LastLineShouldTestForVerifiesOrFails.Token
} }
override fun tweak(dsl: TransactionDsl<TransactionDslInterpreter>.() -> Unit) = override fun tweak(
dsl(TransactionDsl(copy())) dsl: TransactionDsl<
LastLineShouldTestForVerifiesOrFails,
override fun <State: ContractState> retrieveOutputStateAndRef(clazz: Class<State>, label: String) = ledgerInterpreter.retrieveOutputStateAndRef(clazz, label) TransactionDslInterpreter<LastLineShouldTestForVerifiesOrFails>
>.() -> LastLineShouldTestForVerifiesOrFails
) = dsl(TransactionDsl(copy()))
} }
class AttachmentResolutionException(attachmentId: SecureHash) : class AttachmentResolutionException(attachmentId: SecureHash) :
@ -141,7 +152,7 @@ data class TestLedgerDslInterpreter private constructor (
internal val labelToOutputStateAndRefs: HashMap<String, StateAndRef<ContractState>> = HashMap(), internal val labelToOutputStateAndRefs: HashMap<String, StateAndRef<ContractState>> = HashMap(),
private val transactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap(), private val transactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap(),
private val nonVerifiedTransactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap() private val nonVerifiedTransactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap()
) : LedgerDslInterpreter<TestTransactionDslInterpreter> { ) : LedgerDslInterpreter<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter> {
val wireTransactions: List<WireTransaction> get() = transactionWithLocations.values.map { it.transaction } val wireTransactions: List<WireTransaction> get() = transactionWithLocations.values.map { it.transaction }
@ -212,8 +223,9 @@ data class TestLedgerDslInterpreter private constructor (
internal fun resolveAttachment(attachmentId: SecureHash): Attachment = internal fun resolveAttachment(attachmentId: SecureHash): Attachment =
storageService.attachments.openAttachment(attachmentId) ?: throw AttachmentResolutionException(attachmentId) storageService.attachments.openAttachment(attachmentId) ?: throw AttachmentResolutionException(attachmentId)
private fun interpretTransactionDsl(dsl: TransactionDsl<TestTransactionDslInterpreter>.() -> Unit): private fun <Return> interpretTransactionDsl(
TestTransactionDslInterpreter { dsl: TransactionDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter>.() -> Return
): TestTransactionDslInterpreter {
val transactionInterpreter = TestTransactionDslInterpreter(this) val transactionInterpreter = TestTransactionDslInterpreter(this)
dsl(TransactionDsl(transactionInterpreter)) dsl(TransactionDsl(transactionInterpreter))
return transactionInterpreter return transactionInterpreter
@ -241,9 +253,9 @@ data class TestLedgerDslInterpreter private constructor (
fun outputToLabel(state: ContractState): String? = fun outputToLabel(state: ContractState): String? =
labelToOutputStateAndRefs.filter { it.value.state.data == state }.keys.firstOrNull() labelToOutputStateAndRefs.filter { it.value.state.data == state }.keys.firstOrNull()
private fun recordTransactionWithTransactionMap( private fun <Return> recordTransactionWithTransactionMap(
transactionLabel: String?, transactionLabel: String?,
dsl: TransactionDsl<TestTransactionDslInterpreter>.() -> Unit, dsl: TransactionDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter>.() -> Return,
transactionMap: HashMap<SecureHash, WireTransactionWithLocation> = HashMap() transactionMap: HashMap<SecureHash, WireTransactionWithLocation> = HashMap()
): WireTransaction { ): WireTransaction {
val transactionLocation = getCallerLocation(3) val transactionLocation = getCallerLocation(3)
@ -263,20 +275,24 @@ data class TestLedgerDslInterpreter private constructor (
return wireTransaction return wireTransaction
} }
override fun transaction(transactionLabel: String?, dsl: TransactionDsl<TestTransactionDslInterpreter>.() -> Unit) = override fun transaction(
recordTransactionWithTransactionMap(transactionLabel, dsl, transactionWithLocations) transactionLabel: String?,
dsl: TransactionDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter>.() -> LastLineShouldTestForVerifiesOrFails
) = recordTransactionWithTransactionMap(transactionLabel, dsl, transactionWithLocations)
override fun nonVerifiedTransaction(transactionLabel: String?, dsl: TransactionDsl<TestTransactionDslInterpreter>.() -> Unit) = override fun nonVerifiedTransaction(
transactionLabel: String?,
dsl: TransactionDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter>.() -> Unit
) =
recordTransactionWithTransactionMap(transactionLabel, dsl, nonVerifiedTransactionWithLocations) recordTransactionWithTransactionMap(transactionLabel, dsl, nonVerifiedTransactionWithLocations)
override fun tweak( override fun tweak(
dsl: LedgerDsl<TestTransactionDslInterpreter, dsl: LedgerDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter,
LedgerDslInterpreter<TestTransactionDslInterpreter>>.() -> Unit) = LedgerDslInterpreter<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter>>.() -> Unit) =
dsl(LedgerDsl(copy())) dsl(LedgerDsl(copy()))
override fun attachment(attachment: Attachment): SecureHash { override fun attachment(attachment: InputStream): SecureHash {
storageService.attachments.importAttachment(attachment.open()) return storageService.attachments.importAttachment(attachment)
return attachment.id
} }
override fun verifies() { override fun verifies() {

View File

@ -6,6 +6,8 @@ import com.google.common.base.Throwables
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import com.r3corda.core.contracts.* import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.* import com.r3corda.core.crypto.*
import com.r3corda.core.node.services.IdentityService
import com.r3corda.core.node.services.StorageService
import com.r3corda.core.node.services.testing.MockIdentityService import com.r3corda.core.node.services.testing.MockIdentityService
import com.r3corda.core.node.services.testing.MockStorageService import com.r3corda.core.node.services.testing.MockStorageService
import com.r3corda.core.seconds import com.r3corda.core.seconds
@ -94,9 +96,16 @@ object JavaTestHelpers {
@JvmStatic fun generateStateRef() = StateRef(SecureHash.randomSHA256(), 0) @JvmStatic fun generateStateRef() = StateRef(SecureHash.randomSHA256(), 0)
@JvmStatic fun transaction(body: TransactionForTest.() -> LastLineShouldTestForAcceptOrFailure): LastLineShouldTestForAcceptOrFailure { @JvmStatic @JvmOverloads fun ledger(
return body(TransactionForTest()) identityService: IdentityService = MOCK_IDENTITY_SERVICE,
storageService: StorageService = MockStorageService(),
dsl: LedgerDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter, TestLedgerDslInterpreter>.() -> Unit
): LedgerDsl<LastLineShouldTestForVerifiesOrFails, TestTransactionDslInterpreter, TestLedgerDslInterpreter> {
val ledgerDsl = LedgerDsl(TestLedgerDslInterpreter(identityService, storageService))
dsl(ledgerDsl)
return ledgerDsl
} }
} }
val TEST_TX_TIME = JavaTestHelpers.TEST_TX_TIME val TEST_TX_TIME = JavaTestHelpers.TEST_TX_TIME

View File

@ -3,41 +3,78 @@ package com.r3corda.core.testing
import com.r3corda.core.contracts.* import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.node.services.IdentityService
import com.r3corda.core.seconds import com.r3corda.core.seconds
import java.security.PublicKey import java.security.PublicKey
import java.time.Instant import java.time.Instant
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Defines a simple DSL for building pseudo-transactions (not the same as the wire protocol) for testing purposes.
//
// Define a transaction like this:
//
// ledger {
// transaction {
// input { someExpression }
// output { someExpression }
// command { someExpression }
//
// tweak {
// ... same thing but works with a copy of the parent, can add inputs/outputs/commands just within this scope.
// }
//
// contract.verifies() -> verify() should pass
// contract `fails with` "some substring of the error message"
// }
// }
//
/** /**
* [State] is bound at the top level. This allows the definition of e.g. [String.output], however it also means that we * The [TransactionDslInterpreter] defines the interface DSL interpreters should satisfy. No
* cannot mix different types of states in the same transaction. * overloading/default valuing should be done here, only the basic functions that are required to implement everything.
* TODO: Move the [State] binding to the primitives' level to allow different State types, use reflection to check types * Same goes for functions requiring reflection e.g. [OutputStateLookup.retrieveOutputStateAndRef]
* dynamically, come up with a substitute for primitives relying on early bind * Put convenience functions in [TransactionDsl] instead. There are some cases where the overloads would clash with the
* Interpreter interface, in these cases define a "backing" function in the interface instead (e.g. [_command]).
*
* This way the responsibility of providing a nice frontend DSL and the implementation(s) are separated
*/ */
interface TransactionDslInterpreter : OutputStateLookup { interface TransactionDslInterpreter<Return> : OutputStateLookup {
val ledgerInterpreter: LedgerDslInterpreter<Return, TransactionDslInterpreter<Return>>
fun input(stateRef: StateRef) fun input(stateRef: StateRef)
fun output(label: String?, notary: Party, contractState: ContractState) fun _output(label: String?, notary: Party, contractState: ContractState)
fun attachment(attachmentId: SecureHash) fun attachment(attachmentId: SecureHash)
fun _command(signers: List<PublicKey>, commandData: CommandData) fun _command(signers: List<PublicKey>, commandData: CommandData)
fun verifies() fun verifies(): Return
fun failsWith(expectedMessage: String?) fun failsWith(expectedMessage: String?): Return
fun tweak(dsl: TransactionDsl<TransactionDslInterpreter>.() -> Unit) fun tweak(
dsl: TransactionDsl<Return, TransactionDslInterpreter<Return>>.() -> Return
): Return
} }
class TransactionDsl< class TransactionDsl<
out TransactionInterpreter: TransactionDslInterpreter Return,
out TransactionInterpreter: TransactionDslInterpreter<Return>
> (val interpreter: TransactionInterpreter) > (val interpreter: TransactionInterpreter)
: TransactionDslInterpreter by interpreter { : TransactionDslInterpreter<Return> by interpreter {
fun input(stateLabel: String) = input(retrieveOutputStateAndRef(ContractState::class.java, stateLabel).ref) fun input(stateLabel: String) = input(retrieveOutputStateAndRef(ContractState::class.java, stateLabel).ref)
/**
* Adds the passed in state as a non-verified transaction output to the ledger and adds that as an input
*/
fun input(state: ContractState) {
val transaction = ledgerInterpreter.nonVerifiedTransaction(null) {
output { state }
}
input(transaction.outRef<ContractState>(0).ref)
}
fun input(stateClosure: () -> ContractState) = input(stateClosure())
// Convenience functions
fun output(label: String? = null, notary: Party = DUMMY_NOTARY, contractStateClosure: () -> ContractState) =
output(label, notary, contractStateClosure())
@JvmOverloads @JvmOverloads
fun output(label: String? = null, contractState: ContractState) = output(label, DUMMY_NOTARY, contractState) fun output(label: String? = null, notary: Party = DUMMY_NOTARY, contractStateClosure: () -> ContractState) =
_output(label, notary, contractStateClosure())
@JvmOverloads
fun output(label: String? = null, contractState: ContractState) =
_output(label, DUMMY_NOTARY, contractState)
fun command(vararg signers: PublicKey, commandDataClosure: () -> CommandData) = fun command(vararg signers: PublicKey, commandDataClosure: () -> CommandData) =
_command(listOf(*signers), commandDataClosure()) _command(listOf(*signers), commandDataClosure())