mirror of
https://github.com/corda/corda.git
synced 2025-01-31 08:25:50 +00:00
core: Add convenience input(), remove TransactionGroupDSL
This commit is contained in:
parent
cde315aca9
commit
cb47e00feb
@ -38,4 +38,12 @@ class LedgerDsl<
|
|||||||
inline fun <reified State: ContractState> String.output(): TransactionState<State> =
|
inline fun <reified State: ContractState> String.output(): TransactionState<State> =
|
||||||
outputStateAndRef<State>().state
|
outputStateAndRef<State>().state
|
||||||
fun String.outputRef(): StateRef = outputStateAndRef<ContractState>().ref
|
fun String.outputRef(): StateRef = outputStateAndRef<ContractState>().ref
|
||||||
|
|
||||||
|
fun TransactionDslInterpreter.input(state: ContractState) {
|
||||||
|
val transaction = nonVerifiedTransaction {
|
||||||
|
output { state }
|
||||||
|
}
|
||||||
|
input(transaction.outRef<ContractState>(0).ref)
|
||||||
|
}
|
||||||
|
fun TransactionDslInterpreter.input(stateClosure: () -> ContractState) = input(stateClosure())
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,7 @@ data class TestTransactionDslInterpreter(
|
|||||||
) : TransactionDslInterpreter, OutputStateLookup {
|
) : TransactionDslInterpreter, OutputStateLookup {
|
||||||
private fun copy(): TestTransactionDslInterpreter =
|
private fun copy(): TestTransactionDslInterpreter =
|
||||||
TestTransactionDslInterpreter(
|
TestTransactionDslInterpreter(
|
||||||
ledgerInterpreter = ledgerInterpreter.copy(),
|
ledgerInterpreter = ledgerInterpreter,
|
||||||
inputStateRefs = ArrayList(inputStateRefs),
|
inputStateRefs = ArrayList(inputStateRefs),
|
||||||
outputStates = ArrayList(outputStates),
|
outputStates = ArrayList(outputStates),
|
||||||
attachments = ArrayList(attachments),
|
attachments = ArrayList(attachments),
|
||||||
@ -77,12 +77,6 @@ data class TestTransactionDslInterpreter(
|
|||||||
type = transactionType
|
type = transactionType
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun input(stateLabel: String) {
|
|
||||||
val stateAndRef = retrieveOutputStateAndRef(ContractState::class.java, stateLabel)
|
|
||||||
signers.add(stateAndRef.state.notary.owningKey)
|
|
||||||
inputStateRefs.add(stateAndRef.ref)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun input(stateRef: StateRef) {
|
override fun input(stateRef: StateRef) {
|
||||||
val notary = ledgerInterpreter.resolveStateRef<ContractState>(stateRef).notary
|
val notary = ledgerInterpreter.resolveStateRef<ContractState>(stateRef).notary
|
||||||
signers.add(notary.owningKey)
|
signers.add(notary.owningKey)
|
||||||
@ -109,7 +103,7 @@ data class TestTransactionDslInterpreter(
|
|||||||
|
|
||||||
override fun failsWith(expectedMessage: String?) {
|
override fun failsWith(expectedMessage: String?) {
|
||||||
val exceptionThrown = try {
|
val exceptionThrown = try {
|
||||||
verifies()
|
this.verifies()
|
||||||
false
|
false
|
||||||
} catch (exception: Exception) {
|
} catch (exception: Exception) {
|
||||||
if (expectedMessage != null) {
|
if (expectedMessage != null) {
|
||||||
@ -149,6 +143,8 @@ data class TestLedgerDslInterpreter private constructor (
|
|||||||
private val nonVerifiedTransactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap()
|
private val nonVerifiedTransactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap()
|
||||||
) : LedgerDslInterpreter<TestTransactionDslInterpreter> {
|
) : LedgerDslInterpreter<TestTransactionDslInterpreter> {
|
||||||
|
|
||||||
|
val wireTransactions: List<WireTransaction> get() = transactionWithLocations.values.map { it.transaction }
|
||||||
|
|
||||||
// We specify [labelToOutputStateAndRefs] just so that Kotlin picks the primary constructor instead of cycling
|
// We specify [labelToOutputStateAndRefs] just so that Kotlin picks the primary constructor instead of cycling
|
||||||
constructor(identityService: IdentityService, storageService: StorageService) : this(
|
constructor(identityService: IdentityService, storageService: StorageService) : this(
|
||||||
identityService, storageService, labelToOutputStateAndRefs = HashMap()
|
identityService, storageService, labelToOutputStateAndRefs = HashMap()
|
||||||
@ -344,6 +340,6 @@ fun main(args: Array<String>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
verifies()
|
this.verifies()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -297,152 +297,3 @@ open class TransactionForTest : AbstractTransactionForTest() {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class TransactionGroupDSL<out T : ContractState>(private val stateType: Class<T>) {
|
|
||||||
open inner class WireTransactionDSL : AbstractTransactionForTest() {
|
|
||||||
private val inStates = ArrayList<StateRef>()
|
|
||||||
|
|
||||||
fun input(label: String) {
|
|
||||||
val notaryKey = label.output.notary.owningKey
|
|
||||||
signers.add(notaryKey)
|
|
||||||
inStates.add(label.outputRef)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun toWireTransaction() = WireTransaction(inStates, attachments, outStates.map { it.state }, commands, signers.toList(), type)
|
|
||||||
}
|
|
||||||
|
|
||||||
val String.output: TransactionState<T>
|
|
||||||
get() = labelToOutputs[this] ?: throw IllegalArgumentException("State with label '$this' was not found")
|
|
||||||
val String.outputRef: StateRef get() = labelToRefs[this] ?: throw IllegalArgumentException("Unknown label \"$this\"")
|
|
||||||
|
|
||||||
fun <C : ContractState> lookup(label: String): StateAndRef<C> {
|
|
||||||
val output = label.output
|
|
||||||
val newOutput = TransactionState(output.data as C, output.notary)
|
|
||||||
return StateAndRef(newOutput, label.outputRef)
|
|
||||||
}
|
|
||||||
|
|
||||||
private inner class InternalWireTransactionDSL : WireTransactionDSL() {
|
|
||||||
fun finaliseAndInsertLabels(): WireTransaction {
|
|
||||||
val wtx = toWireTransaction()
|
|
||||||
for ((index, labelledState) in outStates.withIndex()) {
|
|
||||||
if (labelledState.label != null) {
|
|
||||||
labelToRefs[labelledState.label] = StateRef(wtx.id, index)
|
|
||||||
if (stateType.isInstance(labelledState.state.data)) {
|
|
||||||
labelToOutputs[labelledState.label] = labelledState.state as TransactionState<T>
|
|
||||||
}
|
|
||||||
outputsToLabels[labelledState.state] = labelledState.label
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return wtx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private val rootTxns = ArrayList<WireTransaction>()
|
|
||||||
private val labelToRefs = HashMap<String, StateRef>()
|
|
||||||
private val labelToOutputs = HashMap<String, TransactionState<T>>()
|
|
||||||
private val outputsToLabels = HashMap<TransactionState<*>, String>()
|
|
||||||
|
|
||||||
fun labelForState(output: TransactionState<*>): String? = outputsToLabels[output]
|
|
||||||
|
|
||||||
inner class Roots {
|
|
||||||
fun transaction(vararg outputStates: LabeledOutput): Roots {
|
|
||||||
val outs = outputStates.map { it.state }
|
|
||||||
val wtx = WireTransaction(emptyList(), emptyList(), outs, emptyList(), emptyList(), TransactionType.General())
|
|
||||||
for ((index, state) in outputStates.withIndex()) {
|
|
||||||
val label = state.label!!
|
|
||||||
labelToRefs[label] = StateRef(wtx.id, index)
|
|
||||||
outputsToLabels[state.state] = label
|
|
||||||
labelToOutputs[label] = state.state as TransactionState<T>
|
|
||||||
}
|
|
||||||
rootTxns.add(wtx)
|
|
||||||
return this
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Note: Don't delete, this is intended to trigger compiler diagnostic when the DSL primitive is used in the wrong place
|
|
||||||
*/
|
|
||||||
@Deprecated("Does not nest ", level = DeprecationLevel.ERROR)
|
|
||||||
fun roots(body: Roots.() -> Unit) {
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Note: Don't delete, this is intended to trigger compiler diagnostic when the DSL primitive is used in the wrong place
|
|
||||||
*/
|
|
||||||
@Deprecated("Use the vararg form of transaction inside roots", level = DeprecationLevel.ERROR)
|
|
||||||
fun transaction(body: WireTransactionDSL.() -> Unit) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun roots(body: Roots.() -> Unit) = Roots().apply { body() }
|
|
||||||
|
|
||||||
val txns = ArrayList<WireTransaction>()
|
|
||||||
private val txnToLabelMap = HashMap<SecureHash, String>()
|
|
||||||
|
|
||||||
@JvmOverloads
|
|
||||||
fun transaction(label: String? = null, body: WireTransactionDSL.() -> Unit): WireTransaction {
|
|
||||||
val forTest = InternalWireTransactionDSL()
|
|
||||||
forTest.body()
|
|
||||||
val wtx = forTest.finaliseAndInsertLabels()
|
|
||||||
txns.add(wtx)
|
|
||||||
if (label != null)
|
|
||||||
txnToLabelMap[wtx.id] = label
|
|
||||||
return wtx
|
|
||||||
}
|
|
||||||
|
|
||||||
fun labelForTransaction(tx: WireTransaction): String? = txnToLabelMap[tx.id]
|
|
||||||
fun labelForTransaction(tx: LedgerTransaction): String? = txnToLabelMap[tx.id]
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Note: Don't delete, this is intended to trigger compiler diagnostic when the DSL primitive is used in the wrong place
|
|
||||||
*/
|
|
||||||
@Deprecated("Does not nest ", level = DeprecationLevel.ERROR)
|
|
||||||
fun transactionGroup(body: TransactionGroupDSL<T>.() -> Unit) {
|
|
||||||
}
|
|
||||||
|
|
||||||
fun toTransactionGroup() = TransactionGroup(
|
|
||||||
txns.map { it.toLedgerTransaction(MOCK_IDENTITY_SERVICE, MockStorageService().attachments) }.toSet(),
|
|
||||||
rootTxns.map { it.toLedgerTransaction(MOCK_IDENTITY_SERVICE, MockStorageService().attachments) }.toSet()
|
|
||||||
)
|
|
||||||
|
|
||||||
class Failed(val index: Int, cause: Throwable) : Exception("Transaction $index didn't verify", cause)
|
|
||||||
|
|
||||||
fun verify() {
|
|
||||||
val group = toTransactionGroup()
|
|
||||||
try {
|
|
||||||
group.verify()
|
|
||||||
} catch (e: TransactionVerificationException) {
|
|
||||||
// Let the developer know the index of the transaction that failed.
|
|
||||||
val wtx: WireTransaction = txns.find { it.id == e.tx.origHash }!!
|
|
||||||
throw Failed(txns.indexOf(wtx) + 1, e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fun expectFailureOfTx(index: Int, message: String): Exception {
|
|
||||||
val e = assertFailsWith(Failed::class) {
|
|
||||||
verify()
|
|
||||||
}
|
|
||||||
assertEquals(index, e.index)
|
|
||||||
if (!(e.cause?.message ?: "") .contains(message))
|
|
||||||
throw AssertionError("Exception should have said '$message' but was actually: ${e.cause?.message}", e.cause)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
fun signAll(txnsToSign: List<WireTransaction> = txns, vararg extraKeys: KeyPair): List<SignedTransaction> {
|
|
||||||
return txnsToSign.map { wtx ->
|
|
||||||
val allPubKeys = wtx.signers.toMutableSet()
|
|
||||||
val bits = wtx.serialize()
|
|
||||||
require(bits == wtx.serialized)
|
|
||||||
val sigs = ArrayList<DigitalSignature.WithKey>()
|
|
||||||
for (key in ALL_TEST_KEYS + extraKeys) {
|
|
||||||
if (allPubKeys.contains(key.public)) {
|
|
||||||
sigs += key.signWithECDSA(bits)
|
|
||||||
allPubKeys -= key.public
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SignedTransaction(bits, sigs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline fun <reified T : ContractState> transactionGroupFor(body: TransactionGroupDSL<T>.() -> Unit) = TransactionGroupDSL<T>(T::class.java).apply { this.body() }
|
|
||||||
fun transactionGroup(body: TransactionGroupDSL<ContractState>.() -> Unit) = TransactionGroupDSL(ContractState::class.java).apply { this.body() }
|
|
||||||
|
@ -16,7 +16,6 @@ import java.time.Instant
|
|||||||
* dynamically, come up with a substitute for primitives relying on early bind
|
* dynamically, come up with a substitute for primitives relying on early bind
|
||||||
*/
|
*/
|
||||||
interface TransactionDslInterpreter : OutputStateLookup {
|
interface TransactionDslInterpreter : OutputStateLookup {
|
||||||
fun input(stateLabel: String)
|
|
||||||
fun input(stateRef: StateRef)
|
fun input(stateRef: StateRef)
|
||||||
fun output(label: String?, notary: Party, contractState: ContractState)
|
fun output(label: String?, notary: Party, contractState: ContractState)
|
||||||
fun attachment(attachmentId: SecureHash)
|
fun attachment(attachmentId: SecureHash)
|
||||||
@ -32,6 +31,8 @@ class TransactionDsl<
|
|||||||
> (val interpreter: TransactionInterpreter)
|
> (val interpreter: TransactionInterpreter)
|
||||||
: TransactionDslInterpreter by interpreter {
|
: TransactionDslInterpreter by interpreter {
|
||||||
|
|
||||||
|
fun input(stateLabel: String) = input(retrieveOutputStateAndRef(ContractState::class.java, stateLabel).ref)
|
||||||
|
|
||||||
// Convenience functions
|
// Convenience functions
|
||||||
fun output(label: String? = null, notary: Party = DUMMY_NOTARY, contractStateClosure: () -> ContractState) =
|
fun output(label: String? = null, notary: Party = DUMMY_NOTARY, contractStateClosure: () -> ContractState) =
|
||||||
output(label, notary, contractStateClosure())
|
output(label, notary, contractStateClosure())
|
||||||
|
@ -120,19 +120,14 @@ class TransactionGroupTests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We have to do this manually without the DSL because transactionGroup { } won't let us create a tx that
|
|
||||||
// points nowhere.
|
|
||||||
val input = StateAndRef(A_THOUSAND_POUNDS `with notary` DUMMY_NOTARY, generateStateRef())
|
val input = StateAndRef(A_THOUSAND_POUNDS `with notary` DUMMY_NOTARY, generateStateRef())
|
||||||
tg.txns += TransactionType.General.Builder().apply {
|
tg.apply {
|
||||||
addInputState(input)
|
transaction {
|
||||||
addOutputState(A_THOUSAND_POUNDS `with notary` DUMMY_NOTARY)
|
assertFailsWith(TransactionResolutionException::class) {
|
||||||
addCommand(TestCash.Commands.Move(), BOB_PUBKEY)
|
input(input.ref)
|
||||||
}.toWireTransaction()
|
}
|
||||||
|
}
|
||||||
val e = assertFailsWith(TransactionResolutionException::class) {
|
|
||||||
tg.verify()
|
|
||||||
}
|
}
|
||||||
assertEquals(e.hash, input.ref.txhash)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
x
Reference in New Issue
Block a user