From c0650213284c1e382f681033d8ac675962a76d39 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Thu, 18 May 2023 11:33:05 +0100 Subject: [PATCH] ENT-8827: The ordering of vault query results is clobbered by ServiceHub.loadStates --- .../kotlin/net/corda/core/node/ServiceHub.kt | 2 +- .../internal/ServicesForResolutionImpl.kt | 25 +- .../node/messaging/TwoPartyTradeFlowTests.kt | 4 +- .../statemachine/FlowSoftLocksTests.kt | 1 - .../node/services/vault/VaultQueryTests.kt | 80 ++++--- .../testing/internal/vault/VaultFiller.kt | 217 ++++++++---------- 6 files changed, 168 insertions(+), 161 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt index d63b63edf4..612e341a6f 100644 --- a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt +++ b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt @@ -64,7 +64,7 @@ interface ServicesForResolution { /** * Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState]. * - * @throws TransactionResolutionException if [stateRef] points to a non-existent transaction. + * @throws TransactionResolutionException if any of the [stateRefs] point to a non-existent transaction. */ // TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load // as the existing transaction store will become encrypted at some point diff --git a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt index f5836c0cc5..06e46992d4 100644 --- a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt @@ -2,6 +2,7 @@ package net.corda.node.internal import net.corda.core.contracts.* import net.corda.core.cordapp.CordappProvider +import net.corda.core.crypto.SecureHash import net.corda.core.internal.SerializedStateAndRef import net.corda.core.node.NetworkParameters import net.corda.core.node.ServicesForResolution @@ -9,8 +10,10 @@ import net.corda.core.node.services.AttachmentStorage import net.corda.core.node.services.IdentityService import net.corda.core.node.services.NetworkParametersService import net.corda.core.node.services.TransactionStorage +import net.corda.core.transactions.BaseTransaction import net.corda.core.transactions.ContractUpgradeWireTransaction import net.corda.core.transactions.NotaryChangeWireTransaction +import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent @@ -26,25 +29,23 @@ data class ServicesForResolutionImpl( @Throws(TransactionResolutionException::class) override fun loadState(stateRef: StateRef): TransactionState<*> { - val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash) - return stx.resolveBaseTransaction(this).outputs[stateRef.index] + return toBaseTransaction(stateRef.txhash).outputs[stateRef.index] } @Throws(TransactionResolutionException::class) override fun loadStates(stateRefs: Set): Set> { - return stateRefs.groupBy { it.txhash }.flatMap { - val stx = validatedTransactions.getTransaction(it.key) ?: throw TransactionResolutionException(it.key) - val baseTx = stx.resolveBaseTransaction(this) - it.value.map { ref -> StateAndRef(baseTx.outputs[ref.index], ref) } - }.toSet() + val baseTxs = HashMap() + return stateRefs.mapTo(LinkedHashSet()) { stateRef -> + val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction) + StateAndRef(baseTx.outputs[stateRef.index], stateRef) + } } @Throws(TransactionResolutionException::class, AttachmentResolutionException::class) override fun loadContractAttachment(stateRef: StateRef): Attachment { // We may need to recursively chase transactions if there are notary changes. fun inner(stateRef: StateRef, forContractClassName: String?): Attachment { - val ctx = validatedTransactions.getTransaction(stateRef.txhash)?.coreTransaction - ?: throw TransactionResolutionException(stateRef.txhash) + val ctx = getSignedTransaction(stateRef.txhash).coreTransaction when (ctx) { is WireTransaction -> { val transactionState = ctx.outRef(stateRef.index).state @@ -69,4 +70,10 @@ data class ServicesForResolutionImpl( } return inner(stateRef, null) } + + private fun toBaseTransaction(txhash: SecureHash): BaseTransaction = getSignedTransaction(txhash).resolveBaseTransaction(this) + + private fun getSignedTransaction(txhash: SecureHash): SignedTransaction { + return validatedTransactions.getTransaction(txhash) ?: throw TransactionResolutionException(txhash) + } } diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 0c49ee44ac..28a5f3b973 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -124,7 +124,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { bobNode.internals.disableDBCloseOnStop() bobNode.database.transaction { - VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, cashIssuer) + VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, cashIssuer, atMostThisManyStates = 10) } val alicesFakePaper = aliceNode.database.transaction { @@ -233,7 +233,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { val issuer = bank.ref(1, 2, 3) bobNode.database.transaction { - VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, issuer) + VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, issuer, atMostThisManyStates = 10) } val alicesFakePaper = aliceNode.database.transaction { fillUpForSeller(false, issuer, alice, diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt index 6f0fa3278c..1930e7ffd8 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt @@ -244,7 +244,6 @@ class FlowSoftLocksTests { 100.DOLLARS, bankNode.services, thisManyStates, - thisManyStates, cashIssuer ) } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 09964b6602..1b139ab022 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -20,14 +20,13 @@ import net.corda.finance.* import net.corda.finance.contracts.CommercialPaper import net.corda.finance.contracts.Commodity import net.corda.finance.contracts.DealState -import net.corda.finance.workflows.asset.selection.AbstractCashSelection import net.corda.finance.contracts.asset.Cash import net.corda.finance.schemas.CashSchemaV1 -import net.corda.finance.schemas.CashSchemaV1.PersistentCashState import net.corda.finance.schemas.CommercialPaperSchemaV1 import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV3 import net.corda.finance.workflows.CommercialPaperUtils +import net.corda.finance.workflows.asset.selection.AbstractCashSelection import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseTransaction @@ -197,8 +196,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } protected fun consumeCash(amount: Amount) = vaultFiller.consumeCash(amount, CHARLIE) - private fun setUpDb(_database: CordaPersistence, delay: Long = 0) { - _database.transaction { + + private fun setUpDb(database: CordaPersistence, delay: Long = 0) { + database.transaction { // create new states vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER) val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ") @@ -444,7 +444,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates) } - (1..3).forEach { + repeat(3) { val newAllStates = vaultService.queryBy(sorting = sorting, criteria = criteria).states assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates) @@ -485,7 +485,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates) } - (1..3).forEach { + repeat(3) { val newAllStates = vaultService.queryBy(sorting = sorting, criteria = criteria).states assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates) @@ -638,7 +638,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } val sorted = results.states.sortedBy { it.ref.toString() } assertThat(results.states).isEqualTo(sorted) - assertThat(results.states).allSatisfy { !consumed.contains(it.ref.txhash) } + assertThat(results.states).allSatisfy { assertThat(consumed).doesNotContain(it.ref.txhash) } } } @@ -1537,7 +1537,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")) // count fungible assets val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) + val countCriteria = VaultCustomQueryCriteria(count) val fungibleStateCount = vaultService.queryBy>(countCriteria).otherResults.single() as Long assertThat(fungibleStateCount).isEqualTo(10L) @@ -1563,7 +1563,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } // count fungible assets - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) + val countCriteria = VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) val fungibleStateCount = vaultService.queryBy>(countCriteria).otherResults.single() as Long assertThat(fungibleStateCount).isEqualTo(10L) @@ -1583,7 +1583,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // UNCONSUMED states (default) // count fungible assets - val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) + val countCriteriaUnconsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) val fungibleStateCountUnconsumed = vaultService.queryBy>(countCriteriaUnconsumed).otherResults.single() as Long assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size) @@ -1598,7 +1598,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // CONSUMED states // count fungible assets - val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) + val countCriteriaConsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) val fungibleStateCountConsumed = vaultService.queryBy>(countCriteriaConsumed).otherResults.single() as Long assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size) @@ -1622,7 +1622,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val start = TODAY val end = TODAY.plus(30, ChronoUnit.DAYS) val recordedBetweenExpression = TimeCondition( - QueryCriteria.TimeInstantType.RECORDED, + TimeInstantType.RECORDED, ColumnPredicate.Between(start, end)) val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression) val results = vaultService.queryBy(criteria) @@ -1632,7 +1632,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // Future val startFuture = TODAY.plus(1, ChronoUnit.DAYS) val recordedBetweenExpressionFuture = TimeCondition( - QueryCriteria.TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) + TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture) assertThat(vaultService.queryBy(criteriaFuture).states).isEmpty() } @@ -1648,7 +1648,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { consumeCash(100.DOLLARS) val asOfDateTime = TODAY val consumedAfterExpression = TimeCondition( - QueryCriteria.TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) + TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED, timeCondition = consumedAfterExpression) val results = vaultService.queryBy(criteria) @@ -1705,6 +1705,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } // pagination: invalid page size + @Suppress("INTEGER_OVERFLOW") @Test(timeout=300_000) fun `invalid page size`() { expectedEx.expect(VaultQueryException::class.java) @@ -1712,8 +1713,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { database.transaction { vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER) - @Suppress("EXPECTED_CONDITION") - val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, @Suppress("INTEGER_OVERFLOW") Integer.MAX_VALUE + 1) // overflow = -2147483648 + val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, Integer.MAX_VALUE + 1) // overflow = -2147483648 val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) vaultService.queryBy(criteria, paging = pagingSpec) } @@ -1781,9 +1781,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { println("$index : $any") } assertThat(results.otherResults.size).isEqualTo(402) - val instants = results.otherResults.filter { it is Instant }.map { it as Instant } + val instants = results.otherResults.filterIsInstance() assertThat(instants).isSorted - val longs = results.otherResults.filter { it is Long }.map { it as Long } + val longs = results.otherResults.filterIsInstance() assertThat(longs.size).isEqualTo(201) assertThat(instants.size).isEqualTo(201) assertThat(longs.sum()).isEqualTo(20100L) @@ -1911,8 +1911,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties { fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() { database.transaction { val uid = UniqueIdentifier("999") - vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, uniqueIdentifier = uid) - vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, externalId = "1234") + vaultFiller.fillWithSomeTestLinearStates(txCount = 1, uniqueIdentifier = uid) + vaultFiller.fillWithSomeTestLinearStates(txCount = 1, externalId = "1234") val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id)) val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234")) @@ -2061,6 +2061,26 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } } + @Test(timeout = 300_000) + fun `unconsumed states which are globally unordered across multiple transactions sorted by custom attribute`() { + val linearNumbers = Array(2) { LongArray(2) } + // Make sure states from the same transaction are not given consecutive linear numbers. + linearNumbers[0][0] = 1L + linearNumbers[0][1] = 3L + linearNumbers[1][0] = 2L + linearNumbers[1][1] = 4L + + val results = database.transaction { + vaultFiller.fillWithTestStates(txCount = 2, statesPerTx = 2) { participantsToUse, txIndex, stateIndex -> + DummyLinearContract.State(participants = participantsToUse, linearNumber = linearNumbers[txIndex][stateIndex]) + } + + val sortColumn = Sort.SortColumn(SortAttribute.Custom(DummyLinearStateSchemaV1.PersistentDummyLinearState::class.java, "linearNumber")) + vaultService.queryBy(VaultQueryCriteria(), sorting = Sort(setOf(sortColumn))) + } + assertThat(results.states.map { it.state.data.linearNumber }).isEqualTo(listOf(1L, 2L, 3L, 4L)) + } + @Test(timeout=300_000) fun `return consumed linear states for a given linear id`() { database.transaction { @@ -2390,7 +2410,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { services.recordTransactions(commercialPaper2) val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) } - val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) + val criteria1 = VaultCustomQueryCriteria(ccyIndex) val result = vaultService.queryBy(criteria1) @@ -2433,9 +2453,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days) val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L) - val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) - val criteria2 = QueryCriteria.VaultCustomQueryCriteria(maturityIndex) - val criteria3 = QueryCriteria.VaultCustomQueryCriteria(faceValueIndex) + val criteria1 = VaultCustomQueryCriteria(ccyIndex) + val criteria2 = VaultCustomQueryCriteria(maturityIndex) + val criteria3 = VaultCustomQueryCriteria(faceValueIndex) vaultService.queryBy(criteria1.and(criteria3).and(criteria2)) } @@ -2458,8 +2478,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL) val results = builder { - val currencyIndex = PersistentCashState::currency.equal(USD.currencyCode) - val quantityIndex = PersistentCashState::pennies.greaterThanOrEqual(10L) + val currencyIndex = CashSchemaV1.PersistentCashState::currency.equal(USD.currencyCode) + val quantityIndex = CashSchemaV1.PersistentCashState::pennies.greaterThanOrEqual(10L) val customCriteria1 = VaultCustomQueryCriteria(currencyIndex) val customCriteria2 = VaultCustomQueryCriteria(quantityIndex) @@ -2710,7 +2730,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // Enrich and override QueryCriteria with additional default attributes (such as soft locks) val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich - softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), + softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), status = Vault.StateStatus.UNCONSUMED) // override // Sorting val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF) @@ -3056,7 +3076,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } @@ -3079,7 +3099,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } @@ -3102,7 +3122,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt index 467b54ea22..f2775e1878 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt @@ -1,6 +1,20 @@ +@file:Suppress("LongParameterList") + package net.corda.testing.internal.vault -import net.corda.core.contracts.* +import net.corda.core.contracts.Amount +import net.corda.core.contracts.AttachmentConstraint +import net.corda.core.contracts.AutomaticPlaceholderConstraint +import net.corda.core.contracts.BelongsToContract +import net.corda.core.contracts.CommandAndState +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.FungibleAsset +import net.corda.core.contracts.Issued +import net.corda.core.contracts.LinearState +import net.corda.core.contracts.PartyAndReference +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.UniqueIdentifier import net.corda.core.crypto.Crypto import net.corda.core.crypto.SignatureMetadata import net.corda.core.identity.AbstractParty @@ -19,9 +33,7 @@ import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Obligation import net.corda.finance.contracts.asset.OnLedgerAsset import net.corda.finance.workflows.asset.CashUtils -import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyState -import net.corda.testing.core.DummyCommandData import net.corda.testing.core.TestIdentity import net.corda.testing.core.dummyCommand import net.corda.testing.core.singleIdentity @@ -32,6 +44,7 @@ import java.time.Duration import java.time.Instant import java.time.Instant.now import java.util.* +import kotlin.math.floor /** * The service hub should provide at least a key management service and a storage service. @@ -46,7 +59,7 @@ class VaultFiller @JvmOverloads constructor( private val rngFactory: () -> Random = { Random(0L) }) { companion object { fun calculateRandomlySizedAmounts(howMuch: Amount, min: Int, max: Int, rng: Random): LongArray { - val numSlots = min + Math.floor(rng.nextDouble() * (max - min)).toInt() + val numSlots = min + floor(rng.nextDouble() * (max - min)).toInt() val baseSize = howMuch.quantity / numSlots check(baseSize > 0) { baseSize } @@ -79,31 +92,18 @@ class VaultFiller @JvmOverloads constructor( issuerServices: ServiceHub = services, participants: List = emptyList(), includeMe: Boolean = true): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val participantsToUse = if (includeMe) participants.plus(me) else participants - - val transactions: List = dealIds.map { - // Issue a deal state - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - addOutputState(DummyDealContract.State(ref = it, participants = participantsToUse), DUMMY_DEAL_PROGRAM_ID) - addCommand(dummyCommand()) - } - val stx = issuerServices.signInitialTransaction(dummyIssue) - return@map services.addSignature(stx, defaultNotary.publicKey) + return fillWithTestStates( + txCount = dealIds.size, + participants = participants, + includeMe = includeMe, + services = issuerServices + ) { participantsToUse, txIndex, _ -> + DummyDealContract.State(ref = dealIds[txIndex], participants = participantsToUse) } - val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE - services.recordTransactions(statesToRecord, transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) } @JvmOverloads - fun fillWithSomeTestLinearStates(numberToCreate: Int, + fun fillWithSomeTestLinearStates(txCount: Int, externalId: String? = null, participants: List = emptyList(), uniqueIdentifier: UniqueIdentifier? = null, @@ -113,81 +113,41 @@ class VaultFiller @JvmOverloads constructor( linearTimestamp: Instant = now(), constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, includeMe: Boolean = true): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val issuerKey = defaultNotary.keyPair - val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) - val participantsToUse = if (includeMe) participants.plus(me) else participants - val transactions: List = (1..numberToCreate).map { - // Issue a Linear state - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - addOutputState(DummyLinearContract.State( - linearId = uniqueIdentifier ?: UniqueIdentifier(externalId), - participants = participantsToUse, - linearString = linearString, - linearNumber = linearNumber, - linearBoolean = linearBoolean, - linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID, - constraint = constraint) - addCommand(dummyCommand()) - } - return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) + return fillWithTestStates(txCount, 1, participants, constraint, includeMe) { participantsToUse, _, _ -> + DummyLinearContract.State( + linearId = uniqueIdentifier ?: UniqueIdentifier(externalId), + participants = participantsToUse, + linearString = linearString, + linearNumber = linearNumber, + linearBoolean = linearBoolean, + linearTimestamp = linearTimestamp + ) } - val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE - services.recordTransactions(statesToRecord, transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) } @JvmOverloads - fun fillWithSomeTestLinearAndDealStates(numberToCreate: Int, + fun fillWithSomeTestLinearAndDealStates(txCount: Int, externalId: String? = null, participants: List = emptyList(), linearString: String = "", linearNumber: Long = 0L, linearBoolean: Boolean = false, - linearTimestamp: Instant = now()): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val issuerKey = defaultNotary.keyPair - val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) - val transactions: List = (1..numberToCreate).map { - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - // Issue a Linear state - addOutputState(DummyLinearContract.State( + linearTimestamp: Instant = now()): Vault { + return fillWithTestStates(txCount, 2, participants) { participantsToUse, _, stateIndex -> + when (stateIndex) { + 0 -> DummyLinearContract.State( linearId = UniqueIdentifier(externalId), - participants = participants.plus(me), + participants = participantsToUse, linearString = linearString, linearNumber = linearNumber, linearBoolean = linearBoolean, - linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID) - // Issue a Deal state - addOutputState(DummyDealContract.State(ref = "test ref", participants = participants.plus(me)), DUMMY_DEAL_PROGRAM_ID) - addCommand(dummyCommand()) + linearTimestamp = linearTimestamp + ) + else -> DummyDealContract.State(ref = "test ref", participants = participantsToUse) } - return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) } - services.recordTransactions(transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - return Vault(states) } - @JvmOverloads - fun fillWithSomeTestCash(howMuch: Amount, - issuerServices: ServiceHub, - thisManyStates: Int, - issuedBy: PartyAndReference, - owner: AbstractParty? = null, - rng: Random? = null, - statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT) = fillWithSomeTestCash(howMuch, issuerServices, thisManyStates, thisManyStates, issuedBy, owner, rng, statesToRecord) - /** * Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them * to the vault. This is intended for unit tests. By default the cash is owned by the legal @@ -196,14 +156,15 @@ class VaultFiller @JvmOverloads constructor( * @param issuerServices service hub of the issuer node, which will be used to sign the transaction. * @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!). */ + @JvmOverloads fun fillWithSomeTestCash(howMuch: Amount, issuerServices: ServiceHub, atLeastThisManyStates: Int, - atMostThisManyStates: Int, issuedBy: PartyAndReference, owner: AbstractParty? = null, rng: Random? = null, - statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault { + statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT, + atMostThisManyStates: Int = atLeastThisManyStates): Vault { val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory()) // We will allocate one state to one transaction, for simplicities sake. val cash = Cash() @@ -212,39 +173,46 @@ class VaultFiller @JvmOverloads constructor( cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary) return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) } - services.recordTransactions(statesToRecord, transactions) - // Get all the StateRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) + return recordTransactions(transactions, statesToRecord) } /** * Records a dummy state in the Vault (useful for creating random states when testing vault queries) */ - fun fillWithDummyState(participants: List = listOf(services.myInfo.singleIdentity())) : Vault { - val outputState = TransactionState( - data = DummyState(Random().nextInt(), participants = participants), - contract = DummyContract.PROGRAM_ID, - notary = defaultNotary.party - ) - val participantKeys : List = participants.map { it.owningKey } - val builder = TransactionBuilder() - .addOutputState(outputState) - .addCommand(DummyCommandData, participantKeys) - val stxn = services.signInitialTransaction(builder) - services.recordTransactions(stxn) - return Vault(setOf(stxn.tx.outRef(0))) + fun fillWithDummyState(participants: List = listOf(services.myInfo.singleIdentity())): Vault { + return fillWithTestStates(participants = participants) { participantsToUse, _, _ -> + DummyState(Random().nextInt(), participants = participantsToUse) + } } - /** - * Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey. - */ - fun generateCommoditiesIssue(tx: TransactionBuilder, amount: Amount>, owner: AbstractParty, notary: Party) - = OnLedgerAsset.generateIssue(tx, TransactionState(CommodityState(amount, owner), Obligation.PROGRAM_ID, notary), Obligation.Commands.Issue()) - + fun fillWithTestStates(txCount: Int = 1, + statesPerTx: Int = 1, + participants: List = emptyList(), + constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, + includeMe: Boolean = true, + services: ServiceHub = this.services, + genOutputState: (participantsToUse: List, txIndex: Int, stateIndex: Int) -> T): Vault { + val issuerKey = defaultNotary.keyPair + val signatureMetadata = SignatureMetadata( + services.myInfo.platformVersion, + Crypto.findSignatureScheme(issuerKey.public).schemeNumberID + ) + val participantsToUse = if (includeMe) { + participants + AnonymousParty(this.services.myInfo.chooseIdentity().owningKey) + } else { + participants + } + val transactions = Array(txCount) { txIndex -> + val builder = TransactionBuilder(notary = defaultNotary.party) + repeat(statesPerTx) { stateIndex -> + builder.addOutputState(genOutputState(participantsToUse, txIndex, stateIndex), constraint) + } + builder.addCommand(dummyCommand()) + services.signInitialTransaction(builder).withAdditionalSignature(issuerKey, signatureMetadata) + } + val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE + return recordTransactions(transactions.asList(), statesToRecord) + } /** * @@ -257,13 +225,16 @@ class VaultFiller @JvmOverloads constructor( val me = AnonymousParty(myKey) val issuance = TransactionBuilder(null as Party?) - generateCommoditiesIssue(issuance, Amount(amount.quantity, Issued(issuedBy, amount.token)), me, altNotary) + OnLedgerAsset.generateIssue( + issuance, + TransactionState(CommodityState(Amount(amount.quantity, Issued(issuedBy, amount.token)), me), Obligation.PROGRAM_ID, altNotary), + Obligation.Commands.Issue() + ) val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) - services.recordTransactions(transaction) - return Vault(setOf(transaction.tx.outRef(0))) + return recordTransactions(listOf(transaction)) } - private fun consume(states: List>) { + fun consumeStates(states: Iterable>) { // Create a txn consuming different contract types states.forEach { val builder = TransactionBuilder(notary = altNotary).apply { @@ -300,10 +271,11 @@ class VaultFiller @JvmOverloads constructor( } } - fun consumeDeals(dealStates: List>) = consume(dealStates) - fun consumeLinearStates(linearStates: List>) = consume(linearStates) + fun consumeDeals(dealStates: List>) = consumeStates(dealStates) + fun consumeLinearStates(linearStates: List>) = consumeStates(linearStates) fun evolveLinearStates(linearStates: List>) = consumeAndProduce(linearStates) fun evolveLinearState(linearState: StateAndRef): StateAndRef = consumeAndProduce(linearState) + /** * Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios, * where nodes have a default identity. @@ -319,6 +291,16 @@ class VaultFiller @JvmOverloads constructor( services.recordTransactions(spendTx) return update.getOrThrow(Duration.ofSeconds(3)) } + + private fun recordTransactions(transactions: Iterable, + statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault { + services.recordTransactions(statesToRecord, transactions) + // Get all the StateAndRefs of all the generated transactions. + val states = transactions.flatMap { stx -> + stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } + } + return Vault(states) + } } @@ -344,4 +326,3 @@ data class CommodityState( override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner)) } -