mirror of
https://github.com/corda/corda.git
synced 2025-04-28 15:02:59 +00:00
core: transaction/ledger DSL interfaces and implementation for tests
This commit is contained in:
parent
5c0e7fbbf2
commit
bf4272b64a
@ -0,0 +1,35 @@
|
||||
package com.r3corda.core.testing
|
||||
|
||||
import com.r3corda.core.contracts.Attachment
|
||||
import com.r3corda.core.contracts.ContractState
|
||||
import com.r3corda.core.crypto.SecureHash
|
||||
import com.r3corda.core.node.services.IdentityService
|
||||
import com.r3corda.core.node.services.StorageService
|
||||
import com.r3corda.core.node.services.testing.MockStorageService
|
||||
|
||||
interface LedgerDslInterpreter<State: ContractState, out TransactionInterpreter: TransactionDslInterpreter<State>> {
|
||||
fun transaction(dsl: TransactionDsl<State, TransactionInterpreter>.() -> Unit): Unit
|
||||
fun nonVerifiedTransaction(dsl: TransactionDsl<State, TransactionInterpreter>.() -> Unit): Unit
|
||||
fun tweak(dsl: LedgerDsl<State, TransactionInterpreter, LedgerDslInterpreter<State, TransactionInterpreter>>.() -> Unit)
|
||||
fun attachment(attachment: Attachment): SecureHash
|
||||
fun _verifies(identityService: IdentityService, storageService: StorageService)
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the class the top-level primitives deal with. It delegates all other primitives to the contained interpreter.
|
||||
* This way we have a decoupling of the DSL "AST" and the interpretation(s) of it. Note how the delegation forces
|
||||
* covariance of the TransactionInterpreter parameter
|
||||
*/
|
||||
class LedgerDsl<
|
||||
State: ContractState,
|
||||
out TransactionInterpreter: TransactionDslInterpreter<State>,
|
||||
out LedgerInterpreter: LedgerDslInterpreter<State, TransactionInterpreter>
|
||||
> (val interpreter: LedgerInterpreter)
|
||||
: LedgerDslInterpreter<State, TransactionDslInterpreter<State>> by interpreter {
|
||||
|
||||
@JvmOverloads
|
||||
fun verifies(
|
||||
identityService: IdentityService = MOCK_IDENTITY_SERVICE,
|
||||
storageService: StorageService = MockStorageService()
|
||||
) = _verifies(identityService, storageService)
|
||||
}
|
295
core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt
Normal file
295
core/src/main/kotlin/com/r3corda/core/testing/TestDsl.kt
Normal file
@ -0,0 +1,295 @@
|
||||
package com.r3corda.core.testing
|
||||
|
||||
import com.r3corda.core.contracts.*
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.crypto.SecureHash
|
||||
import com.r3corda.core.node.services.IdentityService
|
||||
import com.r3corda.core.node.services.StorageService
|
||||
import java.security.PublicKey
|
||||
import java.util.*
|
||||
|
||||
inline fun <reified State: ContractState> ledger(
|
||||
dsl: LedgerDsl<State, TestTransactionDslInterpreter<State>, TestLedgerDslInterpreter<State>>.() -> Unit) =
|
||||
dsl(LedgerDsl(TestLedgerDslInterpreter.create()))
|
||||
|
||||
@Deprecated(
|
||||
message = "ledger doesn't nest, use tweak",
|
||||
replaceWith = ReplaceWith("tweak"),
|
||||
level = DeprecationLevel.ERROR)
|
||||
fun <State: ContractState> TransactionDslInterpreter<State>.ledger(
|
||||
dsl: LedgerDsl<State, TestTransactionDslInterpreter<State>, TestLedgerDslInterpreter<State>>.() -> Unit) {
|
||||
this.toString()
|
||||
dsl.toString()
|
||||
}
|
||||
|
||||
@Deprecated(
|
||||
message = "ledger doesn't nest, use tweak",
|
||||
replaceWith = ReplaceWith("tweak"),
|
||||
level = DeprecationLevel.ERROR)
|
||||
fun <State: ContractState> LedgerDslInterpreter<State, TransactionDslInterpreter<State>>.ledger(
|
||||
dsl: LedgerDsl<State, TestTransactionDslInterpreter<State>, TestLedgerDslInterpreter<State>>.() -> Unit) {
|
||||
this.toString()
|
||||
dsl.toString()
|
||||
}
|
||||
|
||||
/**
|
||||
* This interpreter builds a transaction, and [TransactionDsl.verifies] that the resolved transaction is correct. Note
|
||||
* that transactions corresponding to input states are not verified. Use [LedgerDsl.verifies] for that.
|
||||
*/
|
||||
data class TestTransactionDslInterpreter<State: ContractState>(
|
||||
private val ledgerInterpreter: TestLedgerDslInterpreter<State>,
|
||||
private val inputStateRefs: ArrayList<StateRef> = arrayListOf(),
|
||||
internal val outputStates: ArrayList<LabeledOutput> = arrayListOf(),
|
||||
private val attachments: ArrayList<SecureHash> = arrayListOf(),
|
||||
private val commands: ArrayList<Command> = arrayListOf(),
|
||||
private val signers: LinkedHashSet<PublicKey> = LinkedHashSet(),
|
||||
private val transactionType: TransactionType = TransactionType.General()
|
||||
) : TransactionDslInterpreter<State> {
|
||||
|
||||
private fun copy(): TestTransactionDslInterpreter<State> =
|
||||
TestTransactionDslInterpreter(
|
||||
ledgerInterpreter = ledgerInterpreter.copy(),
|
||||
inputStateRefs = ArrayList(inputStateRefs),
|
||||
outputStates = ArrayList(outputStates),
|
||||
attachments = ArrayList(attachments),
|
||||
commands = ArrayList(commands),
|
||||
signers = LinkedHashSet(signers),
|
||||
transactionType = transactionType
|
||||
)
|
||||
|
||||
internal fun toWireTransaction(): WireTransaction =
|
||||
WireTransaction(
|
||||
inputs = inputStateRefs,
|
||||
outputs = outputStates.map { it.state },
|
||||
attachments = attachments,
|
||||
commands = commands,
|
||||
signers = signers.toList(),
|
||||
type = transactionType
|
||||
)
|
||||
|
||||
override fun input(stateLabel: String) {
|
||||
val notary = stateLabel.output.notary.owningKey
|
||||
signers.add(notary)
|
||||
inputStateRefs.add(stateLabel.outputRef)
|
||||
}
|
||||
|
||||
override fun input(stateRef: StateRef) {
|
||||
val notary = ledgerInterpreter.resolveStateRef(stateRef).notary
|
||||
signers.add(notary.owningKey)
|
||||
inputStateRefs.add(stateRef)
|
||||
}
|
||||
|
||||
override fun output(label: String?, notary: Party, contractState: State) {
|
||||
outputStates.add(LabeledOutput(label, TransactionState(contractState, notary)))
|
||||
}
|
||||
|
||||
override fun attachment(attachmentId: SecureHash) {
|
||||
attachments.add(attachmentId)
|
||||
}
|
||||
|
||||
override fun _command(signers: List<PublicKey>, commandData: CommandData) {
|
||||
this.signers.addAll(signers)
|
||||
commands.add(Command(commandData, signers))
|
||||
}
|
||||
|
||||
override fun _verifies(identityService: IdentityService) {
|
||||
val resolvedTransaction = ledgerInterpreter.resolveWireTransaction(toWireTransaction(), identityService)
|
||||
resolvedTransaction.verify()
|
||||
}
|
||||
|
||||
override fun failsWith(expectedMessage: String?, identityService: IdentityService) {
|
||||
val exceptionThrown = try {
|
||||
_verifies(identityService)
|
||||
false
|
||||
} catch (exception: Exception) {
|
||||
if (expectedMessage != null) {
|
||||
val exceptionMessage = exception.message
|
||||
if (exceptionMessage == null) {
|
||||
throw AssertionError(
|
||||
"Expected exception containing '$expectedMessage' but raised exception had no message"
|
||||
)
|
||||
} else if (!exceptionMessage.toLowerCase().contains(expectedMessage.toLowerCase())) {
|
||||
throw AssertionError(
|
||||
"Expected exception containing '$expectedMessage' but raised exception was '$exceptionMessage'"
|
||||
)
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
if (!exceptionThrown) {
|
||||
throw AssertionError("Expected exception but didn't get one")
|
||||
}
|
||||
}
|
||||
|
||||
override fun tweak(dsl: TransactionDsl<State, TransactionDslInterpreter<State>>.() -> Unit) =
|
||||
dsl(TransactionDsl(copy()))
|
||||
|
||||
override fun retrieveOutputStateAndRef(label: String): StateAndRef<State>? =
|
||||
ledgerInterpreter.labelToOutputStateAndRefs[label]
|
||||
}
|
||||
|
||||
class AttachmentResolutionException(val attachmentId: SecureHash) :
|
||||
Exception("Attachment with id $attachmentId not found")
|
||||
|
||||
data class TestLedgerDslInterpreter<State: ContractState> private constructor (
|
||||
internal val stateClazz: Class<State>,
|
||||
internal val labelToOutputStateAndRefs: HashMap<String, StateAndRef<State>> = HashMap(),
|
||||
private val transactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap(),
|
||||
private val nonVerifiedTransactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap(),
|
||||
private val attachments: HashMap<SecureHash, Attachment> = HashMap()
|
||||
) : LedgerDslInterpreter<State, TestTransactionDslInterpreter<State>> {
|
||||
|
||||
// We specify [labelToOutputStateAndRefs] just so that Kotlin picks the primary constructor instead of cycling
|
||||
constructor(stateClazz: Class<State>) : this(stateClazz, labelToOutputStateAndRefs = HashMap())
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Convenience factory to avoid having to pass in the Class
|
||||
*/
|
||||
inline fun <reified State: ContractState> create() = TestLedgerDslInterpreter(State::class.java)
|
||||
|
||||
private fun getCallerLocation(offset: Int): String {
|
||||
val stackTraceElement = Thread.currentThread().stackTrace[3 + offset]
|
||||
return stackTraceElement.toString()
|
||||
}
|
||||
}
|
||||
|
||||
private data class WireTransactionWithLocation(val transaction: WireTransaction, val location: String)
|
||||
private class VerifiesFailed(transactionLocation: String, cause: Throwable) :
|
||||
Exception("Transaction defined at ($transactionLocation) didn't verify: $cause", cause)
|
||||
|
||||
internal fun copy(): TestLedgerDslInterpreter<State> =
|
||||
TestLedgerDslInterpreter(
|
||||
stateClazz = stateClazz,
|
||||
labelToOutputStateAndRefs = HashMap(labelToOutputStateAndRefs),
|
||||
transactionWithLocations = HashMap(transactionWithLocations),
|
||||
nonVerifiedTransactionWithLocations = HashMap(nonVerifiedTransactionWithLocations),
|
||||
attachments = HashMap(attachments)
|
||||
)
|
||||
|
||||
fun resolveWireTransaction(wireTransaction: WireTransaction, identityService: IdentityService): TransactionForVerification {
|
||||
return wireTransaction.run {
|
||||
val authenticatedCommands = commands.map {
|
||||
AuthenticatedObject(it.signers, it.signers.mapNotNull { identityService.partyFromKey(it) }, it.value)
|
||||
}
|
||||
val resolvedInputStates = inputs.map { resolveStateRef(it) }
|
||||
val resolvedAttachments = attachments.map { resolveAttachment(it) }
|
||||
TransactionForVerification(
|
||||
inputs = resolvedInputStates,
|
||||
outputs = outputs,
|
||||
commands = authenticatedCommands,
|
||||
origHash = wireTransaction.serialized.hash,
|
||||
attachments = resolvedAttachments,
|
||||
signers = signers.toList(),
|
||||
type = type
|
||||
)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fun resolveStateRef(stateRef: StateRef): TransactionState<State> {
|
||||
val transactionWithLocation =
|
||||
transactionWithLocations[stateRef.txhash] ?:
|
||||
nonVerifiedTransactionWithLocations[stateRef.txhash] ?:
|
||||
throw TransactionResolutionException(stateRef.txhash)
|
||||
val output = transactionWithLocation.transaction.outputs[stateRef.index]
|
||||
return if (stateClazz.isInstance(output.data)) @Suppress("UNCHECKED_CAST") {
|
||||
output as TransactionState<State>
|
||||
} else {
|
||||
throw IllegalArgumentException("Referenced state is of another type than requested")
|
||||
}
|
||||
}
|
||||
|
||||
fun resolveAttachment(attachmentId: SecureHash): Attachment =
|
||||
attachments[attachmentId] ?: throw AttachmentResolutionException(attachmentId)
|
||||
|
||||
private fun interpretTransactionDsl(dsl: TransactionDsl<State, TestTransactionDslInterpreter<State>>.() -> Unit):
|
||||
TestTransactionDslInterpreter<State> {
|
||||
val transactionInterpreter = TestTransactionDslInterpreter(this)
|
||||
dsl(TransactionDsl(transactionInterpreter))
|
||||
return transactionInterpreter
|
||||
}
|
||||
|
||||
private fun toTransactionGroup(identityService: IdentityService, storageService: StorageService): TransactionGroup {
|
||||
val ledgerTransactions = transactionWithLocations.map {
|
||||
it.value.transaction.toLedgerTransaction(identityService, storageService.attachments)
|
||||
}
|
||||
val nonVerifiedLedgerTransactions = nonVerifiedTransactionWithLocations.map {
|
||||
it.value.transaction.toLedgerTransaction(identityService, storageService.attachments)
|
||||
}
|
||||
return TransactionGroup(ledgerTransactions.toSet(), nonVerifiedLedgerTransactions.toSet())
|
||||
}
|
||||
|
||||
private fun recordTransactionWithTransactionMap(
|
||||
dsl: TransactionDsl<State, TestTransactionDslInterpreter<State>>.() -> Unit,
|
||||
transactionMap: HashMap<SecureHash, WireTransactionWithLocation> = HashMap()
|
||||
) {
|
||||
val transactionLocation = getCallerLocation(3)
|
||||
val transactionInterpreter = interpretTransactionDsl(dsl)
|
||||
// Create the WireTransaction
|
||||
val wireTransaction = transactionInterpreter.toWireTransaction()
|
||||
// Record the output states
|
||||
transactionInterpreter.outputStates.forEachIndexed { index, labeledOutput ->
|
||||
if (labeledOutput.label != null) {
|
||||
labelToOutputStateAndRefs[labeledOutput.label] = wireTransaction.outRef(index)
|
||||
}
|
||||
}
|
||||
|
||||
transactionMap[wireTransaction.serialized.hash] =
|
||||
WireTransactionWithLocation(wireTransaction, transactionLocation)
|
||||
|
||||
}
|
||||
|
||||
override fun transaction(dsl: TransactionDsl<State, TestTransactionDslInterpreter<State>>.() -> Unit) =
|
||||
recordTransactionWithTransactionMap(dsl, transactionWithLocations)
|
||||
|
||||
override fun nonVerifiedTransaction(dsl: TransactionDsl<State, TestTransactionDslInterpreter<State>>.() -> Unit) =
|
||||
recordTransactionWithTransactionMap(dsl, nonVerifiedTransactionWithLocations)
|
||||
|
||||
override fun tweak(
|
||||
dsl: LedgerDsl<State, TestTransactionDslInterpreter<State>,
|
||||
LedgerDslInterpreter<State, TestTransactionDslInterpreter<State>>>.() -> Unit) =
|
||||
dsl(LedgerDsl(copy()))
|
||||
|
||||
override fun attachment(attachment: Attachment): SecureHash {
|
||||
attachments[attachment.id] = attachment
|
||||
return attachment.id
|
||||
}
|
||||
|
||||
override fun _verifies(identityService: IdentityService, storageService: StorageService) {
|
||||
val transactionGroup = toTransactionGroup(identityService, storageService)
|
||||
try {
|
||||
transactionGroup.verify()
|
||||
} catch (exception: TransactionVerificationException) {
|
||||
throw VerifiesFailed(transactionWithLocations[exception.tx.origHash]?.location ?: "<unknown>", exception)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun main(args: Array<String>) {
|
||||
ledger<ContractState> {
|
||||
nonVerifiedTransaction {
|
||||
output("hello") { DummyLinearState() }
|
||||
}
|
||||
|
||||
transaction {
|
||||
input("hello")
|
||||
tweak {
|
||||
timestamp(TEST_TX_TIME, MEGA_CORP_PUBKEY)
|
||||
fails()
|
||||
}
|
||||
}
|
||||
|
||||
tweak {
|
||||
|
||||
transaction {
|
||||
input("hello")
|
||||
timestamp(TEST_TX_TIME, MEGA_CORP_PUBKEY)
|
||||
fails()
|
||||
}
|
||||
}
|
||||
|
||||
verifies()
|
||||
}
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
package com.r3corda.core.testing
|
||||
|
||||
import com.r3corda.core.contracts.*
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.crypto.SecureHash
|
||||
import com.r3corda.core.node.services.IdentityService
|
||||
import com.r3corda.core.seconds
|
||||
import java.security.PublicKey
|
||||
import java.time.Instant
|
||||
|
||||
|
||||
/**
|
||||
* [State] is bound at the top level. This allows the definition of e.g. [String.output], however it also means that we
|
||||
* cannot mix different types of states in the same transaction.
|
||||
* TODO: Move the [State] binding to the primitives' level to allow different State types, use reflection to check types
|
||||
* dynamically, come up with a substitute for primitives relying on early bind
|
||||
*/
|
||||
interface TransactionDslInterpreter<State: ContractState> {
|
||||
fun input(stateLabel: String)
|
||||
fun input(stateRef: StateRef)
|
||||
fun output(label: String?, notary: Party, contractState: State)
|
||||
fun attachment(attachmentId: SecureHash)
|
||||
fun _command(signers: List<PublicKey>, commandData: CommandData)
|
||||
fun _verifies(identityService: IdentityService)
|
||||
fun failsWith(expectedMessage: String?, identityService: IdentityService)
|
||||
fun tweak(dsl: TransactionDsl<State, TransactionDslInterpreter<State>>.() -> Unit)
|
||||
fun retrieveOutputStateAndRef(label: String): StateAndRef<State>?
|
||||
|
||||
val String.outputStateAndRef: StateAndRef<State>
|
||||
get() = retrieveOutputStateAndRef(this) ?: throw IllegalArgumentException("State with label '$this' was not found")
|
||||
val String.output: TransactionState<State>
|
||||
get() = outputStateAndRef.state
|
||||
val String.outputRef: StateRef
|
||||
get() = outputStateAndRef.ref
|
||||
}
|
||||
|
||||
|
||||
class TransactionDsl<
|
||||
State: ContractState,
|
||||
out TransactionInterpreter: TransactionDslInterpreter<State>
|
||||
> (val interpreter: TransactionInterpreter)
|
||||
: TransactionDslInterpreter<State> by interpreter {
|
||||
|
||||
// Convenience functions
|
||||
fun output(label: String? = null, notary: Party = DUMMY_NOTARY, contractStateClosure: () -> State) =
|
||||
output(label, notary, contractStateClosure())
|
||||
@JvmOverloads
|
||||
fun output(label: String? = null, contractState: State) = output(label, DUMMY_NOTARY, contractState)
|
||||
|
||||
fun command(vararg signers: PublicKey, commandDataClosure: () -> CommandData) =
|
||||
_command(listOf(*signers), commandDataClosure())
|
||||
fun command(signer: PublicKey, commandData: CommandData) = _command(listOf(signer), commandData)
|
||||
|
||||
fun verifies(identityService: IdentityService = MOCK_IDENTITY_SERVICE) = _verifies(identityService)
|
||||
|
||||
@JvmOverloads
|
||||
fun timestamp(time: Instant, notary: PublicKey = DUMMY_NOTARY.owningKey) =
|
||||
timestamp(TimestampCommand(time, 30.seconds), notary)
|
||||
@JvmOverloads
|
||||
fun timestamp(data: TimestampCommand, notary: PublicKey = DUMMY_NOTARY.owningKey) = command(notary, data)
|
||||
|
||||
fun fails(identityService: IdentityService = MOCK_IDENTITY_SERVICE) = failsWith(null, identityService)
|
||||
infix fun `fails with`(msg: String) = failsWith(msg, MOCK_IDENTITY_SERVICE)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user