diff --git a/tests/contracts/CommercialPaperTests.kt b/tests/contracts/CommercialPaperTests.kt index be52ed6a3f..11a82b9967 100644 --- a/tests/contracts/CommercialPaperTests.kt +++ b/tests/contracts/CommercialPaperTests.kt @@ -88,9 +88,9 @@ class CommercialPaperTests { // Generate a trade lifecycle with various parameters. private fun trade(redemptionTime: Instant = TEST_TX_TIME + 8.days, aliceGetsBack: Amount = 1000.DOLLARS, - destroyPaperAtRedemption: Boolean = true): TransactionGroupForTest { + destroyPaperAtRedemption: Boolean = true): TransactionGroupForTest { val someProfits = 1200.DOLLARS - return transactionGroup { + return transactionGroupFor() { roots { transaction(900.DOLLARS.CASH `owned by` ALICE label "alice's $900") transaction(someProfits.CASH `owned by` MEGA_CORP_KEY label "some profits") @@ -108,7 +108,7 @@ class CommercialPaperTests { input("paper") input("alice's $900") output { 900.DOLLARS.CASH `owned by` MEGA_CORP_KEY } - output("alice's paper") { PAPER_1 `owned by` ALICE } + output("alice's paper") { "paper".output `owned by` ALICE } arg(ALICE) { Cash.Commands.Move } arg(MEGA_CORP_KEY) { CommercialPaper.Commands.Move } } @@ -122,7 +122,7 @@ class CommercialPaperTests { output { aliceGetsBack.CASH `owned by` ALICE } output { (someProfits - aliceGetsBack).CASH `owned by` MEGA_CORP_KEY } if (!destroyPaperAtRedemption) - output { PAPER_1 `owned by` ALICE } + output { "paper".output } arg(MEGA_CORP_KEY) { Cash.Commands.Move } arg(ALICE) { CommercialPaper.Commands.Redeem } diff --git a/tests/core/testutils/TestUtils.kt b/tests/core/testutils/TestUtils.kt index ec0513023f..1b2ba02a09 100644 --- a/tests/core/testutils/TestUtils.kt +++ b/tests/core/testutils/TestUtils.kt @@ -169,7 +169,7 @@ open class TransactionForTest : AbstractTransactionForTest() { fun transaction(body: TransactionForTest.() -> Unit) = TransactionForTest().apply { body() } -class TransactionGroupForTest { +class TransactionGroupForTest(private val stateType: Class) { open inner class LedgerTransactionForTest : AbstractTransactionForTest() { private val inStates = ArrayList() @@ -177,18 +177,26 @@ class TransactionGroupForTest { inStates.add(labelToRefs[label] ?: throw IllegalArgumentException("Unknown label \"$label\"")) } + fun toLedgerTransaction(time: Instant): LedgerTransaction { val wireCmds = commands.map { WireCommand(it.value, it.signers) } return WireTransaction(inStates, outStates.map { it.state }, wireCmds).toLedgerTransaction(time, TEST_KEYS_TO_CORP_MAP) } } + val String.output: T get() = labelToOutputs[this] ?: throw IllegalArgumentException("State with label '$this' was not found") + private inner class InternalLedgerTransactionForTest : LedgerTransactionForTest() { fun finaliseAndInsertLabels(time: Instant): LedgerTransaction { val ltx = toLedgerTransaction(time) - for ((index, state) in outStates.withIndex()) { - if (state.label != null) - labelToRefs[state.label] = ContractStateRef(ltx.hash, index) + for ((index, labelledState) in outStates.withIndex()) { + if (labelledState.label != null) { + labelToRefs[labelledState.label] = ContractStateRef(ltx.hash, index) + if (stateType.isInstance(labelledState.state)) { + @Suppress("UNCHECKED_CAST") + labelToOutputs[labelledState.label] = labelledState.state as T + } + } } return ltx } @@ -196,6 +204,7 @@ class TransactionGroupForTest { private val rootTxns = ArrayList() private val labelToRefs = HashMap() + private val labelToOutputs = HashMap() inner class Roots { fun transaction(vararg outputStates: LabeledOutput) { val outs = outputStates.map { it.state } @@ -223,7 +232,7 @@ class TransactionGroupForTest { } @Deprecated("Does not nest ", level = DeprecationLevel.ERROR) - fun transactionGroup(body: TransactionGroupForTest.() -> Unit) {} + fun transactionGroup(body: TransactionGroupForTest.() -> Unit) {} fun toTransactionGroup() = TransactionGroup(txns.map { it }.toSet(), rootTxns.toSet()) @@ -251,4 +260,5 @@ class TransactionGroupForTest { } } -fun transactionGroup(body: TransactionGroupForTest.() -> Unit) = TransactionGroupForTest().apply { this.body() } +inline fun transactionGroupFor(body: TransactionGroupForTest.() -> Unit) = TransactionGroupForTest(T::class.java).apply { this.body() } +fun transactionGroup(body: TransactionGroupForTest.() -> Unit) = TransactionGroupForTest(ContractState::class.java).apply { this.body() }