mirror of
https://github.com/corda/corda.git
synced 2025-04-07 11:27:01 +00:00
Applied contract type filtering to ALL Query Criteria types (#1159)
* Applied contract type filtering to ALL Query Criteria types (was only applied in general) * Do not filter when all contract state types specified.
This commit is contained in:
parent
78ccbe7d57
commit
3c8368e25d
@ -43,14 +43,9 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
|
||||
val predicateSet = mutableSetOf<Predicate>()
|
||||
|
||||
// contract State Types
|
||||
val combinedContractTypeTypes = criteria.contractStateTypes?.plus(contractType) ?: setOf(contractType)
|
||||
combinedContractTypeTypes.filter { it.name != ContractState::class.java.name }.let {
|
||||
val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() }
|
||||
val concrete = it.filter { !it.isInterface }.map { it.name }
|
||||
val all = interfaces.plus(concrete)
|
||||
if (all.isNotEmpty())
|
||||
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(all)))
|
||||
}
|
||||
val contractTypes = deriveContractTypes(criteria.contractStateTypes)
|
||||
if (contractTypes.isNotEmpty())
|
||||
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
|
||||
|
||||
// soft locking
|
||||
if (!criteria.includeSoftlockedStates)
|
||||
@ -83,6 +78,15 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
|
||||
return predicateSet
|
||||
}
|
||||
|
||||
private fun deriveContractTypes(contractStateTypes: Set<Class<out ContractState>>? = null): List<String> {
|
||||
val combinedContractStateTypes = contractStateTypes?.plus(contractType) ?: setOf(contractType)
|
||||
combinedContractStateTypes.filter { it.name != ContractState::class.java.name }.let {
|
||||
val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() }
|
||||
val concrete = it.filter { !it.isInterface }.map { it.name }
|
||||
return interfaces.plus(concrete)
|
||||
}
|
||||
}
|
||||
|
||||
private fun columnPredicateToPredicate(column: Path<out Any?>, columnPredicate: ColumnPredicate<*>): Predicate {
|
||||
return when (columnPredicate) {
|
||||
is ColumnPredicate.EqualityComparison -> {
|
||||
@ -216,6 +220,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
|
||||
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), vaultFungibleStates.get<PersistentStateRef>("stateRef"))
|
||||
predicateSet.add(joinPredicate)
|
||||
|
||||
// contract State Types
|
||||
val contractTypes = deriveContractTypes()
|
||||
if (contractTypes.isNotEmpty())
|
||||
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
|
||||
|
||||
// owner
|
||||
criteria.owner?.let {
|
||||
val ownerKeys = criteria.owner as List<AbstractParty>
|
||||
@ -265,6 +274,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
|
||||
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), vaultLinearStates.get<PersistentStateRef>("stateRef"))
|
||||
joinPredicates.add(joinPredicate)
|
||||
|
||||
// contract State Types
|
||||
val contractTypes = deriveContractTypes()
|
||||
if (contractTypes.isNotEmpty())
|
||||
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
|
||||
|
||||
// linear ids
|
||||
criteria.linearId?.let {
|
||||
val uniqueIdentifiers = criteria.linearId as List<UniqueIdentifier>
|
||||
@ -304,6 +318,11 @@ class HibernateQueryCriteriaParser(val contractType: Class<out ContractState>,
|
||||
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<PersistentStateRef>("stateRef"))
|
||||
joinPredicates.add(joinPredicate)
|
||||
|
||||
// contract State Types
|
||||
val contractTypes = deriveContractTypes()
|
||||
if (contractTypes.isNotEmpty())
|
||||
predicateSet.add(criteriaBuilder.and(vaultStates.get<String>("contractStateClassName").`in`(contractTypes)))
|
||||
|
||||
// resolve general criteria expressions
|
||||
parseExpression(entityRoot, criteria.expression, predicateSet)
|
||||
}
|
||||
|
@ -807,6 +807,96 @@ class VaultQueryTests : TestDependencyInjectionBase() {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `aggregate functions count by contract type`() {
|
||||
database.transaction {
|
||||
// create new states
|
||||
services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 10, 10, Random(0L))
|
||||
services.fillWithSomeTestLinearStates(1, "XYZ")
|
||||
services.fillWithSomeTestLinearStates(2, "JKL")
|
||||
services.fillWithSomeTestLinearStates(3, "ABC")
|
||||
services.fillWithSomeTestDeals(listOf("123", "456", "789"))
|
||||
|
||||
// count fungible assets
|
||||
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
|
||||
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count)
|
||||
val fungibleStateCount = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
|
||||
assertThat(fungibleStateCount).isEqualTo(10L)
|
||||
|
||||
// count linear states
|
||||
val linearStateCount = vaultQuerySvc.queryBy<LinearState>(countCriteria).otherResults.single() as Long
|
||||
assertThat(linearStateCount).isEqualTo(9L)
|
||||
|
||||
// count deal states
|
||||
val dealStateCount = vaultQuerySvc.queryBy<DealState>(countCriteria).otherResults.single() as Long
|
||||
assertThat(dealStateCount).isEqualTo(3L)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `aggregate functions count by contract type and state status`() {
|
||||
database.transaction {
|
||||
// create new states
|
||||
services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 10, 10, Random(0L))
|
||||
val linearStatesXYZ = services.fillWithSomeTestLinearStates(1, "XYZ")
|
||||
val linearStatesJKL = services.fillWithSomeTestLinearStates(2, "JKL")
|
||||
services.fillWithSomeTestLinearStates(3, "ABC")
|
||||
val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789"))
|
||||
|
||||
// ALL states
|
||||
|
||||
// count fungible assets
|
||||
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
|
||||
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
|
||||
val fungibleStateCount = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
|
||||
assertThat(fungibleStateCount).isEqualTo(10L)
|
||||
|
||||
// count linear states
|
||||
val linearStateCount = vaultQuerySvc.queryBy<LinearState>(countCriteria).otherResults.single() as Long
|
||||
assertThat(linearStateCount).isEqualTo(9L)
|
||||
|
||||
// count deal states
|
||||
val dealStateCount = vaultQuerySvc.queryBy<DealState>(countCriteria).otherResults.single() as Long
|
||||
assertThat(dealStateCount).isEqualTo(3L)
|
||||
|
||||
// consume some states
|
||||
services.consumeLinearStates(linearStatesXYZ.states.toList())
|
||||
services.consumeLinearStates(linearStatesJKL.states.toList())
|
||||
services.consumeDeals(dealStates.states.filter { it.state.data.ref == "456" })
|
||||
services.consumeCash(50.DOLLARS)
|
||||
|
||||
// UNCONSUMED states (default)
|
||||
|
||||
// count fungible assets
|
||||
val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED)
|
||||
val fungibleStateCountUnconsumed = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long
|
||||
assertThat(fungibleStateCountUnconsumed).isEqualTo(5L)
|
||||
|
||||
// count linear states
|
||||
val linearStateCountUnconsumed = vaultQuerySvc.queryBy<LinearState>(countCriteriaUnconsumed).otherResults.single() as Long
|
||||
assertThat(linearStateCountUnconsumed).isEqualTo(5L)
|
||||
|
||||
// count deal states
|
||||
val dealStateCountUnconsumed = vaultQuerySvc.queryBy<DealState>(countCriteriaUnconsumed).otherResults.single() as Long
|
||||
assertThat(dealStateCountUnconsumed).isEqualTo(2L)
|
||||
|
||||
// CONSUMED states
|
||||
|
||||
// count fungible assets
|
||||
val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED)
|
||||
val fungibleStateCountConsumed = vaultQuerySvc.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long
|
||||
assertThat(fungibleStateCountConsumed).isEqualTo(6L)
|
||||
|
||||
// count linear states
|
||||
val linearStateCountConsumed = vaultQuerySvc.queryBy<LinearState>(countCriteriaConsumed).otherResults.single() as Long
|
||||
assertThat(linearStateCountConsumed).isEqualTo(4L)
|
||||
|
||||
// count deal states
|
||||
val dealStateCountConsumed = vaultQuerySvc.queryBy<DealState>(countCriteriaConsumed).otherResults.single() as Long
|
||||
assertThat(dealStateCountConsumed).isEqualTo(1L)
|
||||
}
|
||||
}
|
||||
|
||||
private val TODAY = LocalDate.now().atStartOfDay().toInstant(ZoneOffset.UTC)
|
||||
|
||||
@Test
|
||||
|
Loading…
x
Reference in New Issue
Block a user