From a6b2a3159dbaaabaacccde16c468fb00ff9fb4f7 Mon Sep 17 00:00:00 2001 From: Nikolett Nagy <61757742+nikinagy@users.noreply.github.com> Date: Thu, 13 Aug 2020 10:04:53 +0100 Subject: [PATCH] CORDA-3879 - query with OR combinator returns too many results (#6456) * fix suggestion and tests * detekt suppress * making sure the forced join works with IndirectStatePersistable and removing unnecessary joinPredicates from parse with sorting * remove joinPredicates and add tests * rename sorting * revert deleting joinPredicates and modify the force join to use `OR` instead of `AND` * add system property switch --- .../vault/HibernateQueryCriteriaParser.kt | 40 ++++- .../node/services/vault/VaultQueryJoinTest.kt | 157 ++++++++++++++++++ 2 files changed, 188 insertions(+), 9 deletions(-) create mode 100644 node/src/test/kotlin/net/corda/node/services/vault/VaultQueryJoinTest.kt diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt index c26c414019..26e7c195d5 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt @@ -35,7 +35,6 @@ import java.util.* import javax.persistence.Tuple import javax.persistence.criteria.* - abstract class AbstractQueryCriteriaParser, in P: BaseQueryCriteriaParser, in S: BaseSort> : BaseQueryCriteriaParser { abstract val criteriaBuilder: CriteriaBuilder @@ -277,6 +276,7 @@ class HibernateQueryCriteriaParser(val contractStateType: Class) : AbstractQueryCriteriaParser(), IQueryCriteriaParser { private companion object { private val log = contextLogger() + private val disableCorda3879 = System.getProperty("net.corda.vault.query.disable.corda3879")?.toBoolean() ?: false } // incrementally build list of join predicates @@ -550,7 +550,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class("stateRef"), vaultLinearStatesRoot.get("stateRef")) predicateSet.add(joinPredicate) @@ -613,8 +612,8 @@ class HibernateQueryCriteriaParser(val contractStateType: Class("stateRef"), entityRoot.get>("compositeKey").get("stateRef")) - } else { + criteriaBuilder.equal(vaultStates.get("stateRef"), entityRoot.get>("compositeKey").get("stateRef")) + } else { criteriaBuilder.equal(vaultStates.get("stateRef"), entityRoot.get("stateRef")) } predicateSet.add(joinPredicate) @@ -633,6 +632,7 @@ class HibernateQueryCriteriaParser(val contractStateType: Class { val predicateSet = criteria.visit(this) @@ -647,12 +647,37 @@ class HibernateQueryCriteriaParser(val contractStateType: Class { + val returnSet = mutableSetOf() + + rootEntities.values.forEach { + if (it != vaultStates) { + if(IndirectStatePersistable::class.java.isAssignableFrom(it.javaType)) { + returnSet.add(criteriaBuilder.equal(vaultStates.get("stateRef"), it.get>("compositeKey").get("stateRef"))) + } else { + returnSet.add(criteriaBuilder.equal(vaultStates.get("stateRef"), it.get("stateRef"))) + } + } + } + + return returnSet + } + override fun parseCriteria(criteria: CommonQueryCriteria): Collection { log.trace { "Parsing CommonQueryCriteria: $criteria" } @@ -849,8 +874,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class("stateRef"), entityRoot.get("stateRef")) - joinPredicates.add(joinPredicate) entityRoot } when (direction) { @@ -869,7 +892,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class() + private const val numObjectsInLedger = DEFAULT_PAGE_SIZE + 1 + + @BeforeClass + @JvmStatic + fun setup() { + repeat(numObjectsInLedger) { index -> + createdStateRefs.add(addSimpleObjectToLedger(DummyData(index))) + } + + System.setProperty("net.corda.vault.query.disable.corda3879", "false"); + } + + private fun addSimpleObjectToLedger(dummyObject: DummyData): StateRef { + val tx = TransactionBuilder(notaryNode.info.legalIdentities.first()) + tx.addOutputState( + DummyState(dummyObject, listOf(aliceNode.info.identityFromX500Name(ALICE_NAME))) + ) + tx.addCommand(DummyContract.Commands.AddDummy(), aliceNode.info.legalIdentitiesAndCerts.first().owningKey) + tx.verify(serviceHubHandle) + val stx = serviceHubHandle.signInitialTransaction(tx) + serviceHubHandle.recordTransactions(listOf(stx)) + return StateRef(stx.id, 0) + } + } + + private val queryToCheckId = builder { + val conditionToCheckId = + DummySchema.DummyState::id + .equal(0) + QueryCriteria.VaultCustomQueryCriteria(conditionToCheckId, Vault.StateStatus.UNCONSUMED) + } + + private val queryToCheckStateRef = + QueryCriteria.VaultQueryCriteria(Vault.StateStatus.UNCONSUMED, stateRefs = listOf(createdStateRefs[numObjectsInLedger-1])) + + @Test(timeout = 300_000) + fun `filter query with OR operator`() { + val results = serviceHubHandle.vaultService.queryBy( + queryToCheckId.or(queryToCheckStateRef) + ) + assertEquals(2, results.states.size) + assertEquals(2, results.statesMetadata.size) + } + + @Test(timeout = 300_000) + fun `filter query with sorting`() { + val sorting = Sort(listOf(Sort.SortColumn(SortAttribute.Custom(DummySchema.DummyState::class.java, "stateRef"), Sort.Direction.DESC))) + + val results = serviceHubHandle.vaultService.queryBy( + queryToCheckStateRef, sorting = sorting + ) + + assertEquals(1, results.states.size) + assertEquals(1, results.statesMetadata.size) + } + + @Test(timeout = 300_000) + fun `filter query with OR operator and sorting`() { + val sorting = Sort(listOf(Sort.SortColumn(SortAttribute.Custom(DummySchema.DummyState::class.java, "stateRef"), Sort.Direction.DESC))) + + val results = serviceHubHandle.vaultService.queryBy( + queryToCheckId.or(queryToCheckStateRef), sorting = sorting + ) + + assertEquals(2, results.states.size) + assertEquals(2, results.statesMetadata.size) + } +} + +object DummyStatesV + +@Suppress("MagicNumber") // SQL column length +@CordaSerializable +object DummySchema : MappedSchema(schemaFamily = DummyStatesV.javaClass, version = 1, mappedTypes = listOf(DummyState::class.java)){ + + @Entity + @Table(name = "dummy_states", indexes = [Index(name = "dummy_id_index", columnList = "id")]) + class DummyState ( + @Column(name = "id", length = 4, nullable = false) + var id: Int + ) : PersistentState() +} + +@CordaSerializable +data class DummyData( + val id: Int +) + +@BelongsToContract(DummyContract::class) +data class DummyState(val dummyData: DummyData, override val participants: List) : + ContractState, QueryableState { + override fun supportedSchemas(): Iterable = listOf(DummySchema) + + + override fun generateMappedObject(schema: MappedSchema) = + when (schema) { + is DummySchema -> DummySchema.DummyState( + dummyData.id + ) + else -> throw IllegalArgumentException("Unsupported Schema") + } +} + +class DummyContract : Contract { + override fun verify(tx: LedgerTransaction) { } + interface Commands : CommandData { + class AddDummy : Commands + } +} \ No newline at end of file