diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt index a601d90aab..04ba81567a 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt @@ -299,29 +299,36 @@ class HibernateQueryCriteriaParser(val contractStateType: Class criteriaBuilder.asc(aggregateExpression) - Sort.Direction.DESC -> criteriaBuilder.desc(aggregateExpression) - } - criteriaQuery.orderBy(orderCriteria) - } + // Some databases may not support aggregate expression in 'group by' clause e.g. 'group by sum(col)', + // Hibernate Criteria Builder can't produce alias 'group by col_alias', and the only solution is to use a positional parameter 'group by 1' + val orderByColumnPosition = aggregateExpressions.size + var shiftLeft = 0 // add optional group by clauses expression.groupByColumns?.let { columns -> val groupByExpressions = columns.map { _column -> val path = root.get(getColumnName(_column)) + val columnNumberBeforeRemoval = aggregateExpressions.size if (path is SingularAttributePath) //remove the same columns from different joins to match the single column in 'group by' only (from the last join) aggregateExpressions.removeAll { elem -> if (elem is SingularAttributePath) elem.attribute.javaMember == path.attribute.javaMember else false } + shiftLeft += columnNumberBeforeRemoval - aggregateExpressions.size //record how many times a duplicated column was removed (from the previous 'parseAggregateFunction' run) aggregateExpressions.add(path) path } criteriaQuery.groupBy(groupByExpressions) } + // optionally order by this aggregate function + expression.orderBy?.let { + val orderCriteria = + when (expression.orderBy!!) { + // when adding column position of 'group by' shift in case columns were removed + Sort.Direction.ASC -> criteriaBuilder.asc(criteriaBuilder.literal(orderByColumnPosition - shiftLeft)) + Sort.Direction.DESC -> criteriaBuilder.desc(criteriaBuilder.literal(orderByColumnPosition - shiftLeft)) + } + criteriaQuery.orderBy(orderCriteria) + } return aggregateExpression } else -> throw VaultQueryException("Not expecting $columnPredicate") diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 8704a7c871..06c680ad41 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -767,6 +767,72 @@ open class VaultQueryTests { } } + @Test + fun `aggregate functions with single group clause desc first column`() { + database.transaction { + listOf(100.DOLLARS, 200.DOLLARS, 300.DOLLARS, 400.POUNDS, 500.SWISS_FRANCS).zip(1..5).forEach { (howMuch, states) -> + vaultFiller.fillWithSomeTestCash(howMuch, notaryServices, states, DUMMY_CASH_ISSUER) + } + val sum = builder { CashSchemaV1.PersistentCashState::pennies.sum(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency), orderBy = Sort.Direction.DESC) } + val max = builder { CashSchemaV1.PersistentCashState::pennies.max(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } + val min = builder { CashSchemaV1.PersistentCashState::pennies.min(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } + + val results = vaultService.queryBy>(VaultCustomQueryCriteria(sum) + .and(VaultCustomQueryCriteria(max)) + .and(VaultCustomQueryCriteria(min))) + + assertThat(results.otherResults).hasSize(12) + + assertThat(results.otherResults.subList(0,4)).isEqualTo(listOf(60000L, 11298L, 8702L, "USD")) + assertThat(results.otherResults.subList(4,8)).isEqualTo(listOf(50000L, 10274L, 9481L, "CHF")) + assertThat(results.otherResults.subList(8,12)).isEqualTo(listOf(40000L, 10343L, 9351L, "GBP")) + } + } + + @Test + fun `aggregate functions with single group clause desc mid column`() { + database.transaction { + listOf(100.DOLLARS, 200.DOLLARS, 300.DOLLARS, 400.POUNDS, 500.SWISS_FRANCS).zip(1..5).forEach { (howMuch, states) -> + vaultFiller.fillWithSomeTestCash(howMuch, notaryServices, states, DUMMY_CASH_ISSUER) + } + val sum = builder { CashSchemaV1.PersistentCashState::pennies.sum(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } + val max = builder { CashSchemaV1.PersistentCashState::pennies.max(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency), orderBy = Sort.Direction.DESC) } + val min = builder { CashSchemaV1.PersistentCashState::pennies.min(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } + + val results = vaultService.queryBy>(VaultCustomQueryCriteria(sum) + .and(VaultCustomQueryCriteria(max)) + .and(VaultCustomQueryCriteria(min))) + + assertThat(results.otherResults).hasSize(12) + + assertThat(results.otherResults.subList(0,4)).isEqualTo(listOf(60000L, 11298L, 8702L, "USD")) + assertThat(results.otherResults.subList(4,8)).isEqualTo(listOf(40000L, 10343L, 9351L, "GBP")) + assertThat(results.otherResults.subList(8,12)).isEqualTo(listOf(50000L, 10274L, 9481L, "CHF")) + } + } + + @Test + fun `aggregate functions with single group clause desc last column`() { + database.transaction { + listOf(100.DOLLARS, 200.DOLLARS, 300.DOLLARS, 400.POUNDS, 500.SWISS_FRANCS).zip(1..5).forEach { (howMuch, states) -> + vaultFiller.fillWithSomeTestCash(howMuch, notaryServices, states, DUMMY_CASH_ISSUER) + } + val sum = builder { CashSchemaV1.PersistentCashState::pennies.sum(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } + val max = builder { CashSchemaV1.PersistentCashState::pennies.max(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } + val min = builder { CashSchemaV1.PersistentCashState::pennies.min(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency), orderBy = Sort.Direction.DESC) } + + val results = vaultService.queryBy>(VaultCustomQueryCriteria(sum) + .and(VaultCustomQueryCriteria(max)) + .and(VaultCustomQueryCriteria(min))) + + assertThat(results.otherResults).hasSize(12) + + assertThat(results.otherResults.subList(0,4)).isEqualTo(listOf(50000L, 10274L, 9481L, "CHF")) + assertThat(results.otherResults.subList(4,8)).isEqualTo(listOf(40000L, 10343L, 9351L, "GBP")) + assertThat(results.otherResults.subList(8,12)).isEqualTo(listOf(60000L, 11298L, 8702L, "USD")) + } + } + @Test fun `aggregate functions sum by issuer and currency and sort by aggregate sum`() { database.transaction {