From 5ceb61606a7c36bb1f59213c09cffccb57e65f55 Mon Sep 17 00:00:00 2001 From: Dan Newton Date: Mon, 11 Jun 2018 17:53:31 +0100 Subject: [PATCH] VaultTrack returns undesired states #3276 (#3336) * filter by contract state in _trackBy * write tests to check that _trackBy is filtering the states correct and tidy up filtering functions * remove un needed function * add change log message for filtering unrelated ContractStates from trackBy --- docs/source/changelog.rst | 3 + .../node/services/vault/NodeVaultService.kt | 11 ++- .../node/services/vault/VaultQueryTests.kt | 74 ++++++++++++++++++- .../testing/internal/vault/VaultFiller.kt | 37 +++++++++- 4 files changed, 119 insertions(+), 6 deletions(-) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 22df9ee1b8..f673e74ee8 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -7,6 +7,9 @@ release, see :doc:`upgrade-notes`. Unreleased ========== +* Fixed an issue where ``trackBy`` was returning ``ContractStates`` from a transaction that were not being tracked. The + unrelated ``ContractStates`` will now be filtered out from the returned ``Vault.Update``. + * Introducing the flow hospital - a component of the node that manages flows that have errored and whether they should be retried from their previous checkpoints or have their errors propagate. Currently it will respond to any error that occurs during the resolution of a received transaction as part of ``FinalityFlow``. In such a scenerio the receiving diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 06ffc886b2..0826ef094b 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -489,12 +489,21 @@ class NodeVaultService( return database.transaction { mutex.locked { val snapshotResults = _queryBy(criteria, paging, sorting, contractStateType) - val updates: Observable> = uncheckedCast(_updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractStateType, snapshotResults.stateTypes) }) + val updates: Observable> = uncheckedCast(_updatesPublisher.bufferUntilSubscribed() + .filter { it.containsType(contractStateType, snapshotResults.stateTypes) } + .map { filterContractStates(it, contractStateType) }) DataFeed(snapshotResults, updates) } } } + private fun filterContractStates(update: Vault.Update, contractStateType: Class) = + update.copy(consumed = filterByContractState(contractStateType, update.consumed), + produced = filterByContractState(contractStateType, update.produced)) + + private fun filterByContractState(contractStateType: Class, stateAndRefs: Set>) = + stateAndRefs.filter { contractStateType.isAssignableFrom(it.state.data.javaClass) }.toSet() + private fun getSession() = database.currentOrNew().session /** * Derive list from existing vault states and then incrementally update using vault observables diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 3bc59abe3f..980f79b29d 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -28,10 +28,7 @@ import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.testing.core.* import net.corda.testing.internal.TEST_TX_TIME import net.corda.testing.internal.rigorousMock -import net.corda.testing.internal.vault.DUMMY_LINEAR_CONTRACT_PROGRAM_ID -import net.corda.testing.internal.vault.DummyLinearContract -import net.corda.testing.internal.vault.DummyLinearStateSchemaV1 -import net.corda.testing.internal.vault.VaultFiller +import net.corda.testing.internal.vault.* import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices.Companion.makeTestDatabaseAndMockServices import net.corda.testing.node.makeTestIdentityService @@ -2282,4 +2279,73 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { ) } } + + @Test + fun `track by only returns updates of tracked type`() { + val updates = database.transaction { + val (snapshot, updates) = vaultService.trackBy() + assertThat(snapshot.states).hasSize(0) + val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states + this.session.flush() + vaultFiller.consumeLinearStates(states.toList()) + updates + } + + updates.expectEvents { + sequence( + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 10) {} + require(produced.filter { DummyDealContract.State::class.java.isAssignableFrom(it.state.data::class.java) }.size == 10) {} + } + ) + } + } + + @Test + fun `track by of super class only returns updates of sub classes of tracked type`() { + val updates = database.transaction { + val (snapshot, updates) = vaultService.trackBy() + assertThat(snapshot.states).hasSize(0) + val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states + this.session.flush() + vaultFiller.consumeLinearStates(states.toList()) + updates + } + + updates.expectEvents { + sequence( + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 10) {} + require(produced.filter { DealState::class.java.isAssignableFrom(it.state.data::class.java) }.size == 10) {} + } + ) + } + } + + @Test + fun `track by of contract state interface returns updates of all states`() { + val updates = database.transaction { + val (snapshot, updates) = vaultService.trackBy() + assertThat(snapshot.states).hasSize(0) + val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states + this.session.flush() + vaultFiller.consumeLinearStates(states.toList()) + updates + } + + updates.expectEvents { + sequence( + expect { (consumed, produced, flowId) -> + require(flowId == null) {} + require(consumed.isEmpty()) {} + require(produced.size == 20) {} + require(produced.filter { ContractState::class.java.isAssignableFrom(it.state.data::class.java) }.size == 20) {} + } + ) + } + } } \ No newline at end of file diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt index 6308223bbb..a3cd85d564 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt @@ -127,6 +127,42 @@ class VaultFiller @JvmOverloads constructor( return Vault(states) } + @JvmOverloads + fun fillWithSomeTestLinearAndDealStates(numberToCreate: Int, + externalId: String? = null, + participants: List = emptyList(), + linearString: String = "", + linearNumber: Long = 0L, + linearBoolean: Boolean = false, + linearTimestamp: Instant = now()): Vault { + val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey + val me = AnonymousParty(myKey) + val issuerKey = defaultNotary.keyPair + val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) + val transactions: List = (1..numberToCreate).map { + val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { + // Issue a Linear state + addOutputState(DummyLinearContract.State( + linearId = UniqueIdentifier(externalId), + participants = participants.plus(me), + linearString = linearString, + linearNumber = linearNumber, + linearBoolean = linearBoolean, + linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID) + // Issue a Deal state + addOutputState(DummyDealContract.State(ref = "test ref", participants = participants.plus(me)), DUMMY_DEAL_PROGRAM_ID) + addCommand(dummyCommand()) + } + return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) + } + services.recordTransactions(transactions) + // Get all the StateAndRefs of all the generated transactions. + val states = transactions.flatMap { stx -> + stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } + } + return Vault(states) + } + @JvmOverloads fun fillWithSomeTestCash(howMuch: Amount, issuerServices: ServiceHub, @@ -167,7 +203,6 @@ class VaultFiller @JvmOverloads constructor( return Vault(states) } - /** * Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey. */