From 3c8368e25d00db84ac42ebcd4d3b560817aa1196 Mon Sep 17 00:00:00 2001
From: josecoll <jose.coll@r3cev.com>
Date: Thu, 3 Aug 2017 10:46:45 +0100
Subject: [PATCH] Applied contract type filtering to ALL Query Criteria types 
 (#1159)

* Applied contract type filtering to ALL Query Criteria types (was only applied in general)

* Do not filter when all contract state types specified.
---
 .../vault/HibernateQueryCriteriaParser.kt     | 35 ++++++--
 .../node/services/vault/VaultQueryTests.kt    | 90 +++++++++++++++++++
 2 files changed, 117 insertions(+), 8 deletions(-)

diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt
index 6f383f81c2..e704eaac3c 100644
--- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt
+++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt
@@ -43,14 +43,9 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
         val predicateSet = mutableSetOf<Predicate>()
 
         // contract State Types
-        val combinedContractTypeTypes = criteria.contractStateTypes?.plus(contractType) ?: setOf(contractType)
-        combinedContractTypeTypes.filter { it.name != ContractState::class.java.name }.let {
-            val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() }
-            val concrete = it.filter { !it.isInterface }.map { it.name }
-            val all = interfaces.plus(concrete)
-            if (all.isNotEmpty())
-                predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(all)))
-        }
+        val contractTypes = deriveContractTypes(criteria.contractStateTypes)
+        if (contractTypes.isNotEmpty())
+            predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
 
         // soft locking
         if (!criteria.includeSoftlockedStates)
@@ -83,6 +78,15 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
         return predicateSet
     }
 
+    private fun deriveContractTypes(contractStateTypes: Set<Class<out ContractState>>? = null): List<String> {
+        val combinedContractStateTypes = contractStateTypes?.plus(contractType) ?: setOf(contractType)
+        combinedContractStateTypes.filter { it.name != ContractState::class.java.name }.let {
+            val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() }
+            val concrete = it.filter { !it.isInterface }.map { it.name }
+            return interfaces.plus(concrete)
+        }
+    }
+
     private fun columnPredicateToPredicate(column: Path<out Any?>, columnPredicate: ColumnPredicate<*>): Predicate {
         return when (columnPredicate) {
             is ColumnPredicate.EqualityComparison -> {
@@ -216,6 +220,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
         val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), vaultFungibleStates.get<PersistentStateRef>("stateRef"))
         predicateSet.add(joinPredicate)
 
+        // contract State Types
+        val contractTypes = deriveContractTypes()
+        if (contractTypes.isNotEmpty())
+            predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
+
         // owner
         criteria.owner?.let {
             val ownerKeys = criteria.owner as List<AbstractParty>
@@ -265,6 +274,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
         val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), vaultLinearStates.get<PersistentStateRef>("stateRef"))
         joinPredicates.add(joinPredicate)
 
+        // contract State Types
+        val contractTypes = deriveContractTypes()
+        if (contractTypes.isNotEmpty())
+            predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
+
         // linear ids
         criteria.linearId?.let {
             val uniqueIdentifiers = criteria.linearId as List<UniqueIdentifier>
@@ -304,6 +318,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
             val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<PersistentStateRef>("stateRef"))
             joinPredicates.add(joinPredicate)
 
+            // contract State Types
+            val contractTypes = deriveContractTypes()
+            if (contractTypes.isNotEmpty())
+                predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
+
             // resolve general criteria expressions
             parseExpression(entityRoot, criteria.expression, predicateSet)
         }
diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt
index c5accd024d..9fe7922c9a 100644
--- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt
+++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt
@@ -807,6 +807,96 @@ class VaultQueryTests : TestDependencyInjectionBase() {
         }
     }
 
+    @Test
+    fun `aggregate functions count by contract type`() {
+        database.transaction {
+            // create new states
+            services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 10, 10, Random(0L))
+            services.fillWithSomeTestLinearStates(1, "XYZ")
+            services.fillWithSomeTestLinearStates(2, "JKL")
+            services.fillWithSomeTestLinearStates(3, "ABC")
+            services.fillWithSomeTestDeals(listOf("123", "456", "789"))
+
+            // count fungible assets
+            val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
+            val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count)
+            val fungibleStateCount = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
+            assertThat(fungibleStateCount).isEqualTo(10L)
+
+            // count linear states
+            val linearStateCount = vaultQuerySvc.queryBy<LinearState>(countCriteria).otherResults.single() as Long
+            assertThat(linearStateCount).isEqualTo(9L)
+
+            // count deal states
+            val dealStateCount = vaultQuerySvc.queryBy<DealState>(countCriteria).otherResults.single() as Long
+            assertThat(dealStateCount).isEqualTo(3L)
+        }
+    }
+
+    @Test
+    fun `aggregate functions count by contract type and state status`() {
+        database.transaction {
+            // create new states
+            services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 10, 10, Random(0L))
+            val linearStatesXYZ = services.fillWithSomeTestLinearStates(1, "XYZ")
+            val linearStatesJKL = services.fillWithSomeTestLinearStates(2, "JKL")
+            services.fillWithSomeTestLinearStates(3, "ABC")
+            val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789"))
+
+            // ALL states
+
+            // count fungible assets
+            val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
+            val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
+            val fungibleStateCount = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
+            assertThat(fungibleStateCount).isEqualTo(10L)
+
+            // count linear states
+            val linearStateCount = vaultQuerySvc.queryBy<LinearState>(countCriteria).otherResults.single() as Long
+            assertThat(linearStateCount).isEqualTo(9L)
+
+            // count deal states
+            val dealStateCount = vaultQuerySvc.queryBy<DealState>(countCriteria).otherResults.single() as Long
+            assertThat(dealStateCount).isEqualTo(3L)
+
+            // consume some states
+            services.consumeLinearStates(linearStatesXYZ.states.toList())
+            services.consumeLinearStates(linearStatesJKL.states.toList())
+            services.consumeDeals(dealStates.states.filter { it.state.data.ref == "456" })
+            services.consumeCash(50.DOLLARS)
+
+            // UNCONSUMED states (default)
+
+            // count fungible assets
+            val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED)
+            val fungibleStateCountUnconsumed = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long
+            assertThat(fungibleStateCountUnconsumed).isEqualTo(5L)
+
+            // count linear states
+            val linearStateCountUnconsumed = vaultQuerySvc.queryBy<LinearState>(countCriteriaUnconsumed).otherResults.single() as Long
+            assertThat(linearStateCountUnconsumed).isEqualTo(5L)
+
+            // count deal states
+            val dealStateCountUnconsumed = vaultQuerySvc.queryBy<DealState>(countCriteriaUnconsumed).otherResults.single() as Long
+            assertThat(dealStateCountUnconsumed).isEqualTo(2L)
+
+            // CONSUMED states
+
+            // count fungible assets
+            val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED)
+            val fungibleStateCountConsumed = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long
+            assertThat(fungibleStateCountConsumed).isEqualTo(6L)
+
+            // count linear states
+            val linearStateCountConsumed = vaultQuerySvc.queryBy<LinearState>(countCriteriaConsumed).otherResults.single() as Long
+            assertThat(linearStateCountConsumed).isEqualTo(4L)
+
+            // count deal states
+            val dealStateCountConsumed = vaultQuerySvc.queryBy<DealState>(countCriteriaConsumed).otherResults.single() as Long
+            assertThat(dealStateCountConsumed).isEqualTo(1L)
+        }
+    }
+
     private val TODAY = LocalDate.now().atStartOfDay().toInstant(ZoneOffset.UTC)
 
     @Test