Applied contract type filtering to ALL Query Criteria types (#1159)

* Applied contract type filtering to ALL Query Criteria types (was only applied in general)

* Do not filter when all contract state types specified.
This commit is contained in:
josecoll 2017-08-03 10:46:45 +01:00 committed by GitHub
parent 78ccbe7d57
commit 3c8368e25d
2 changed files with 117 additions and 8 deletions

View File

@ -43,14 +43,9 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
val predicateSet = mutableSetOf<Predicate>()
// contract State Types
val combinedContractTypeTypes = criteria.contractStateTypes?.plus(contractType) ?: setOf(contractType)
combinedContractTypeTypes.filter { it.name != ContractState::class.java.name }.let {
val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() }
val concrete = it.filter { !it.isInterface }.map { it.name }
val all = interfaces.plus(concrete)
if (all.isNotEmpty())
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(all)))
}
val contractTypes = deriveContractTypes(criteria.contractStateTypes)
if (contractTypes.isNotEmpty())
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
// soft locking
if (!criteria.includeSoftlockedStates)
@ -83,6 +78,15 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
return predicateSet
}
private fun deriveContractTypes(contractStateTypes: Set<Class<out ContractState>>? = null): List<String> {
val combinedContractStateTypes = contractStateTypes?.plus(contractType) ?: setOf(contractType)
combinedContractStateTypes.filter { it.name != ContractState::class.java.name }.let {
val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() }
val concrete = it.filter { !it.isInterface }.map { it.name }
return interfaces.plus(concrete)
}
}
private fun columnPredicateToPredicate(column: Path<out Any?>, columnPredicate: ColumnPredicate<*>): Predicate {
return when (columnPredicate) {
is ColumnPredicate.EqualityComparison -> {
@ -216,6 +220,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), vaultFungibleStates.get<PersistentStateRef>("stateRef"))
predicateSet.add(joinPredicate)
// contract State Types
val contractTypes = deriveContractTypes()
if (contractTypes.isNotEmpty())
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
// owner
criteria.owner?.let {
val ownerKeys = criteria.owner as List<AbstractParty>
@ -265,6 +274,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), vaultLinearStates.get<PersistentStateRef>("stateRef"))
joinPredicates.add(joinPredicate)
// contract State Types
val contractTypes = deriveContractTypes()
if (contractTypes.isNotEmpty())
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
// linear ids
criteria.linearId?.let {
val uniqueIdentifiers = criteria.linearId as List<UniqueIdentifier>
@ -304,6 +318,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<PersistentStateRef>("stateRef"))
joinPredicates.add(joinPredicate)
// contract State Types
val contractTypes = deriveContractTypes()
if (contractTypes.isNotEmpty())
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
// resolve general criteria expressions
parseExpression(entityRoot, criteria.expression, predicateSet)
}

View File

@ -807,6 +807,96 @@ class VaultQueryTests : TestDependencyInjectionBase() {
}
}
@Test
fun `aggregate functions count by contract type`() {
database.transaction {
// create new states
services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 10, 10, Random(0L))
services.fillWithSomeTestLinearStates(1, "XYZ")
services.fillWithSomeTestLinearStates(2, "JKL")
services.fillWithSomeTestLinearStates(3, "ABC")
services.fillWithSomeTestDeals(listOf("123", "456", "789"))
// count fungible assets
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count)
val fungibleStateCount = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L)
// count linear states
val linearStateCount = vaultQuerySvc.queryBy<LinearState>(countCriteria).otherResults.single() as Long
assertThat(linearStateCount).isEqualTo(9L)
// count deal states
val dealStateCount = vaultQuerySvc.queryBy<DealState>(countCriteria).otherResults.single() as Long
assertThat(dealStateCount).isEqualTo(3L)
}
}
@Test
fun `aggregate functions count by contract type and state status`() {
database.transaction {
// create new states
services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 10, 10, Random(0L))
val linearStatesXYZ = services.fillWithSomeTestLinearStates(1, "XYZ")
val linearStatesJKL = services.fillWithSomeTestLinearStates(2, "JKL")
services.fillWithSomeTestLinearStates(3, "ABC")
val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789"))
// ALL states
// count fungible assets
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val fungibleStateCount = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L)
// count linear states
val linearStateCount = vaultQuerySvc.queryBy<LinearState>(countCriteria).otherResults.single() as Long
assertThat(linearStateCount).isEqualTo(9L)
// count deal states
val dealStateCount = vaultQuerySvc.queryBy<DealState>(countCriteria).otherResults.single() as Long
assertThat(dealStateCount).isEqualTo(3L)
// consume some states
services.consumeLinearStates(linearStatesXYZ.states.toList())
services.consumeLinearStates(linearStatesJKL.states.toList())
services.consumeDeals(dealStates.states.filter { it.state.data.ref == "456" })
services.consumeCash(50.DOLLARS)
// UNCONSUMED states (default)
// count fungible assets
val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED)
val fungibleStateCountUnconsumed = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long
assertThat(fungibleStateCountUnconsumed).isEqualTo(5L)
// count linear states
val linearStateCountUnconsumed = vaultQuerySvc.queryBy<LinearState>(countCriteriaUnconsumed).otherResults.single() as Long
assertThat(linearStateCountUnconsumed).isEqualTo(5L)
// count deal states
val dealStateCountUnconsumed = vaultQuerySvc.queryBy<DealState>(countCriteriaUnconsumed).otherResults.single() as Long
assertThat(dealStateCountUnconsumed).isEqualTo(2L)
// CONSUMED states
// count fungible assets
val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED)
val fungibleStateCountConsumed = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long
assertThat(fungibleStateCountConsumed).isEqualTo(6L)
// count linear states
val linearStateCountConsumed = vaultQuerySvc.queryBy<LinearState>(countCriteriaConsumed).otherResults.single() as Long
assertThat(linearStateCountConsumed).isEqualTo(4L)
// count deal states
val dealStateCountConsumed = vaultQuerySvc.queryBy<DealState>(countCriteriaConsumed).otherResults.single() as Long
assertThat(dealStateCountConsumed).isEqualTo(1L)
}
}
private val TODAY = LocalDate.now().atStartOfDay().toInstant(ZoneOffset.UTC)
@Test