From 3c8368e25d00db84ac42ebcd4d3b560817aa1196 Mon Sep 17 00:00:00 2001 From: josecoll <jose.coll@r3cev.com> Date: Thu, 3 Aug 2017 10:46:45 +0100 Subject: [PATCH] Applied contract type filtering to ALL Query Criteria types (#1159) * Applied contract type filtering to ALL Query Criteria types (was only applied in general) * Do not filter when all contract state types specified. --- .../vault/HibernateQueryCriteriaParser.kt | 35 ++++++-- .../node/services/vault/VaultQueryTests.kt | 90 +++++++++++++++++++ 2 files changed, 117 insertions(+), 8 deletions(-) diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt index 6f383f81c2..e704eaac3c 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt @@ -43,14 +43,9 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>, val predicateSet = mutableSetOf<Predicate>() // contract State Types - val combinedContractTypeTypes = criteria.contractStateTypes?.plus(contractType) ?: setOf(contractType) - combinedContractTypeTypes.filter { it.name != ContractState::class.java.name }.let { - val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() } - val concrete = it.filter { !it.isInterface }.map { it.name } - val all = interfaces.plus(concrete) - if (all.isNotEmpty()) - predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(all))) - } + val contractTypes = deriveContractTypes(criteria.contractStateTypes) + if (contractTypes.isNotEmpty()) + predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes))) // soft locking if (!criteria.includeSoftlockedStates) @@ -83,6 +78,15 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>, return predicateSet } + private fun deriveContractTypes(contractStateTypes: Set<Class<out ContractState>>? = null): List<String> { + val combinedContractStateTypes = contractStateTypes?.plus(contractType) ?: setOf(contractType) + combinedContractStateTypes.filter { it.name != ContractState::class.java.name }.let { + val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() } + val concrete = it.filter { !it.isInterface }.map { it.name } + return interfaces.plus(concrete) + } + } + private fun columnPredicateToPredicate(column: Path<out Any?>, columnPredicate: ColumnPredicate<*>): Predicate { return when (columnPredicate) { is ColumnPredicate.EqualityComparison -> { @@ -216,6 +220,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>, val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), vaultFungibleStates.get<PersistentStateRef>("stateRef")) predicateSet.add(joinPredicate) + // contract State Types + val contractTypes = deriveContractTypes() + if (contractTypes.isNotEmpty()) + predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes))) + // owner criteria.owner?.let { val ownerKeys = criteria.owner as List<AbstractParty> @@ -265,6 +274,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>, val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), vaultLinearStates.get<PersistentStateRef>("stateRef")) joinPredicates.add(joinPredicate) + // contract State Types + val contractTypes = deriveContractTypes() + if (contractTypes.isNotEmpty()) + predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes))) + // linear ids criteria.linearId?.let { val uniqueIdentifiers = criteria.linearId as List<UniqueIdentifier> @@ -304,6 +318,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>, val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<PersistentStateRef>("stateRef")) joinPredicates.add(joinPredicate) + // contract State Types + val contractTypes = deriveContractTypes() + if (contractTypes.isNotEmpty()) + predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes))) + // resolve general criteria expressions parseExpression(entityRoot, criteria.expression, predicateSet) } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index c5accd024d..9fe7922c9a 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -807,6 +807,96 @@ class VaultQueryTests : TestDependencyInjectionBase() { } } + @Test + fun `aggregate functions count by contract type`() { + database.transaction { + // create new states + services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 10, 10, Random(0L)) + services.fillWithSomeTestLinearStates(1, "XYZ") + services.fillWithSomeTestLinearStates(2, "JKL") + services.fillWithSomeTestLinearStates(3, "ABC") + services.fillWithSomeTestDeals(listOf("123", "456", "789")) + + // count fungible assets + val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } + val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) + val fungibleStateCount = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long + assertThat(fungibleStateCount).isEqualTo(10L) + + // count linear states + val linearStateCount = vaultQuerySvc.queryBy<LinearState>(countCriteria).otherResults.single() as Long + assertThat(linearStateCount).isEqualTo(9L) + + // count deal states + val dealStateCount = vaultQuerySvc.queryBy<DealState>(countCriteria).otherResults.single() as Long + assertThat(dealStateCount).isEqualTo(3L) + } + } + + @Test + fun `aggregate functions count by contract type and state status`() { + database.transaction { + // create new states + services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 10, 10, Random(0L)) + val linearStatesXYZ = services.fillWithSomeTestLinearStates(1, "XYZ") + val linearStatesJKL = services.fillWithSomeTestLinearStates(2, "JKL") + services.fillWithSomeTestLinearStates(3, "ABC") + val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")) + + // ALL states + + // count fungible assets + val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } + val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) + val fungibleStateCount = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long + assertThat(fungibleStateCount).isEqualTo(10L) + + // count linear states + val linearStateCount = vaultQuerySvc.queryBy<LinearState>(countCriteria).otherResults.single() as Long + assertThat(linearStateCount).isEqualTo(9L) + + // count deal states + val dealStateCount = vaultQuerySvc.queryBy<DealState>(countCriteria).otherResults.single() as Long + assertThat(dealStateCount).isEqualTo(3L) + + // consume some states + services.consumeLinearStates(linearStatesXYZ.states.toList()) + services.consumeLinearStates(linearStatesJKL.states.toList()) + services.consumeDeals(dealStates.states.filter { it.state.data.ref == "456" }) + services.consumeCash(50.DOLLARS) + + // UNCONSUMED states (default) + + // count fungible assets + val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) + val fungibleStateCountUnconsumed = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long + assertThat(fungibleStateCountUnconsumed).isEqualTo(5L) + + // count linear states + val linearStateCountUnconsumed = vaultQuerySvc.queryBy<LinearState>(countCriteriaUnconsumed).otherResults.single() as Long + assertThat(linearStateCountUnconsumed).isEqualTo(5L) + + // count deal states + val dealStateCountUnconsumed = vaultQuerySvc.queryBy<DealState>(countCriteriaUnconsumed).otherResults.single() as Long + assertThat(dealStateCountUnconsumed).isEqualTo(2L) + + // CONSUMED states + + // count fungible assets + val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) + val fungibleStateCountConsumed = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long + assertThat(fungibleStateCountConsumed).isEqualTo(6L) + + // count linear states + val linearStateCountConsumed = vaultQuerySvc.queryBy<LinearState>(countCriteriaConsumed).otherResults.single() as Long + assertThat(linearStateCountConsumed).isEqualTo(4L) + + // count deal states + val dealStateCountConsumed = vaultQuerySvc.queryBy<DealState>(countCriteriaConsumed).otherResults.single() as Long + assertThat(dealStateCountConsumed).isEqualTo(1L) + } + } + private val TODAY = LocalDate.now().atStartOfDay().toInstant(ZoneOffset.UTC) @Test