mirror of
https://github.com/corda/corda.git
synced 2024-12-19 04:57:58 +00:00
CORDA-3879 - query with OR combinator returns too many results (#6456)
* fix suggestion and tests * detekt suppress * making sure the forced join works with IndirectStatePersistable and removing unnecessary joinPredicates from parse with sorting * remove joinPredicates and add tests * rename sorting * revert deleting joinPredicates and modify the force join to use `OR` instead of `AND` * add system property switch
This commit is contained in:
parent
38cad333c8
commit
a6b2a3159d
@ -35,7 +35,6 @@ import java.util.*
|
||||
import javax.persistence.Tuple
|
||||
import javax.persistence.criteria.*
|
||||
|
||||
|
||||
abstract class AbstractQueryCriteriaParser<Q : GenericQueryCriteria<Q,P>, in P: BaseQueryCriteriaParser<Q, P, S>, in S: BaseSort> : BaseQueryCriteriaParser<Q, P, S> {
|
||||
|
||||
abstract val criteriaBuilder: CriteriaBuilder
|
||||
@ -277,6 +276,7 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
||||
val vaultStates: Root<VaultSchemaV1.VaultStates>) : AbstractQueryCriteriaParser<QueryCriteria, IQueryCriteriaParser, Sort>(), IQueryCriteriaParser {
|
||||
private companion object {
|
||||
private val log = contextLogger()
|
||||
private val disableCorda3879 = System.getProperty("net.corda.vault.query.disable.corda3879")?.toBoolean() ?: false
|
||||
}
|
||||
|
||||
// incrementally build list of join predicates
|
||||
@ -550,7 +550,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
||||
|
||||
// ensure we re-use any existing instance of the same root entity
|
||||
val vaultLinearStatesRoot = getVaultLinearStatesRoot()
|
||||
|
||||
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"),
|
||||
vaultLinearStatesRoot.get<PersistentStateRef>("stateRef"))
|
||||
predicateSet.add(joinPredicate)
|
||||
@ -613,8 +612,8 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
||||
}
|
||||
|
||||
val joinPredicate = if(IndirectStatePersistable::class.java.isAssignableFrom(entityRoot.javaType)) {
|
||||
criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<IndirectStatePersistable<*>>("compositeKey").get<PersistentStateRef>("stateRef"))
|
||||
} else {
|
||||
criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<IndirectStatePersistable<*>>("compositeKey").get<PersistentStateRef>("stateRef"))
|
||||
} else {
|
||||
criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<PersistentStateRef>("stateRef"))
|
||||
}
|
||||
predicateSet.add(joinPredicate)
|
||||
@ -633,6 +632,7 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
||||
return predicateSet
|
||||
}
|
||||
|
||||
@Suppress("SpreadOperator")
|
||||
override fun parse(criteria: QueryCriteria, sorting: Sort?): Collection<Predicate> {
|
||||
val predicateSet = criteria.visit(this)
|
||||
|
||||
@ -647,12 +647,37 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
||||
else
|
||||
aggregateExpressions
|
||||
criteriaQuery.multiselect(selections)
|
||||
val combinedPredicates = commonPredicates.values.plus(predicateSet).plus(constraintPredicates).plus(joinPredicates)
|
||||
criteriaQuery.where(*combinedPredicates.toTypedArray())
|
||||
val combinedPredicates = commonPredicates.values.plus(predicateSet)
|
||||
.plus(constraintPredicates)
|
||||
.plus(joinPredicates)
|
||||
|
||||
val forceJoinPredicates = joinStateRefPredicate()
|
||||
|
||||
if(forceJoinPredicates.isEmpty() || disableCorda3879) {
|
||||
criteriaQuery.where(*combinedPredicates.toTypedArray())
|
||||
} else {
|
||||
criteriaQuery.where(*combinedPredicates.toTypedArray(), criteriaBuilder.or(*forceJoinPredicates.toTypedArray()))
|
||||
}
|
||||
|
||||
return predicateSet
|
||||
}
|
||||
|
||||
private fun joinStateRefPredicate(): Set<Predicate> {
|
||||
val returnSet = mutableSetOf<Predicate>()
|
||||
|
||||
rootEntities.values.forEach {
|
||||
if (it != vaultStates) {
|
||||
if(IndirectStatePersistable::class.java.isAssignableFrom(it.javaType)) {
|
||||
returnSet.add(criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), it.get<IndirectStatePersistable<*>>("compositeKey").get<PersistentStateRef>("stateRef")))
|
||||
} else {
|
||||
returnSet.add(criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), it.get<PersistentStateRef>("stateRef")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return returnSet
|
||||
}
|
||||
|
||||
override fun parseCriteria(criteria: CommonQueryCriteria): Collection<Predicate> {
|
||||
log.trace { "Parsing CommonQueryCriteria: $criteria" }
|
||||
|
||||
@ -849,8 +874,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
||||
// scenario where sorting on attributes not parsed as criteria
|
||||
val entityRoot = criteriaQuery.from(entityStateClass)
|
||||
rootEntities[entityStateClass] = entityRoot
|
||||
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<PersistentStateRef>("stateRef"))
|
||||
joinPredicates.add(joinPredicate)
|
||||
entityRoot
|
||||
}
|
||||
when (direction) {
|
||||
@ -869,7 +892,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
||||
}
|
||||
if (orderCriteria.isNotEmpty()) {
|
||||
criteriaQuery.orderBy(orderCriteria)
|
||||
criteriaQuery.where(*joinPredicates.toTypedArray())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,157 @@
|
||||
package net.corda.node.services.vault
|
||||
|
||||
import net.corda.core.contracts.BelongsToContract
|
||||
import net.corda.core.contracts.CommandData
|
||||
import net.corda.core.contracts.Contract
|
||||
import net.corda.core.contracts.ContractState
|
||||
import net.corda.core.contracts.StateRef
|
||||
import net.corda.core.identity.AbstractParty
|
||||
import net.corda.core.node.services.Vault
|
||||
import net.corda.core.node.services.queryBy
|
||||
import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE
|
||||
import net.corda.core.node.services.vault.QueryCriteria
|
||||
import net.corda.core.node.services.vault.Sort
|
||||
import net.corda.core.node.services.vault.SortAttribute
|
||||
import net.corda.core.node.services.vault.builder
|
||||
import net.corda.core.schemas.MappedSchema
|
||||
import net.corda.core.schemas.PersistentState
|
||||
import net.corda.core.schemas.QueryableState
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.transactions.LedgerTransaction
|
||||
import net.corda.core.transactions.TransactionBuilder
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.node.MockNetwork
|
||||
import net.corda.testing.node.MockNetworkParameters
|
||||
import net.corda.testing.node.internal.cordappsForPackages
|
||||
import org.junit.BeforeClass
|
||||
import org.junit.Test
|
||||
import javax.persistence.Column
|
||||
import javax.persistence.Entity
|
||||
import javax.persistence.Index
|
||||
import javax.persistence.Table
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class VaultQueryJoinTest {
|
||||
companion object {
|
||||
private val mockNetwork = MockNetwork(
|
||||
MockNetworkParameters(
|
||||
cordappsForAllNodes = cordappsForPackages(
|
||||
listOf(
|
||||
"net.corda.node.services.vault"
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
private val aliceNode = mockNetwork.createPartyNode(ALICE_NAME)
|
||||
private val notaryNode = mockNetwork.defaultNotaryNode
|
||||
private val serviceHubHandle = aliceNode.services
|
||||
private val createdStateRefs = mutableListOf<StateRef>()
|
||||
private const val numObjectsInLedger = DEFAULT_PAGE_SIZE + 1
|
||||
|
||||
@BeforeClass
|
||||
@JvmStatic
|
||||
fun setup() {
|
||||
repeat(numObjectsInLedger) { index ->
|
||||
createdStateRefs.add(addSimpleObjectToLedger(DummyData(index)))
|
||||
}
|
||||
|
||||
System.setProperty("net.corda.vault.query.disable.corda3879", "false");
|
||||
}
|
||||
|
||||
private fun addSimpleObjectToLedger(dummyObject: DummyData): StateRef {
|
||||
val tx = TransactionBuilder(notaryNode.info.legalIdentities.first())
|
||||
tx.addOutputState(
|
||||
DummyState(dummyObject, listOf(aliceNode.info.identityFromX500Name(ALICE_NAME)))
|
||||
)
|
||||
tx.addCommand(DummyContract.Commands.AddDummy(), aliceNode.info.legalIdentitiesAndCerts.first().owningKey)
|
||||
tx.verify(serviceHubHandle)
|
||||
val stx = serviceHubHandle.signInitialTransaction(tx)
|
||||
serviceHubHandle.recordTransactions(listOf(stx))
|
||||
return StateRef(stx.id, 0)
|
||||
}
|
||||
}
|
||||
|
||||
private val queryToCheckId = builder {
|
||||
val conditionToCheckId =
|
||||
DummySchema.DummyState::id
|
||||
.equal(0)
|
||||
QueryCriteria.VaultCustomQueryCriteria(conditionToCheckId, Vault.StateStatus.UNCONSUMED)
|
||||
}
|
||||
|
||||
private val queryToCheckStateRef =
|
||||
QueryCriteria.VaultQueryCriteria(Vault.StateStatus.UNCONSUMED, stateRefs = listOf(createdStateRefs[numObjectsInLedger-1]))
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `filter query with OR operator`() {
|
||||
val results = serviceHubHandle.vaultService.queryBy<DummyState>(
|
||||
queryToCheckId.or(queryToCheckStateRef)
|
||||
)
|
||||
assertEquals(2, results.states.size)
|
||||
assertEquals(2, results.statesMetadata.size)
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `filter query with sorting`() {
|
||||
val sorting = Sort(listOf(Sort.SortColumn(SortAttribute.Custom(DummySchema.DummyState::class.java, "stateRef"), Sort.Direction.DESC)))
|
||||
|
||||
val results = serviceHubHandle.vaultService.queryBy<DummyState>(
|
||||
queryToCheckStateRef, sorting = sorting
|
||||
)
|
||||
|
||||
assertEquals(1, results.states.size)
|
||||
assertEquals(1, results.statesMetadata.size)
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `filter query with OR operator and sorting`() {
|
||||
val sorting = Sort(listOf(Sort.SortColumn(SortAttribute.Custom(DummySchema.DummyState::class.java, "stateRef"), Sort.Direction.DESC)))
|
||||
|
||||
val results = serviceHubHandle.vaultService.queryBy<DummyState>(
|
||||
queryToCheckId.or(queryToCheckStateRef), sorting = sorting
|
||||
)
|
||||
|
||||
assertEquals(2, results.states.size)
|
||||
assertEquals(2, results.statesMetadata.size)
|
||||
}
|
||||
}
|
||||
|
||||
object DummyStatesV
|
||||
|
||||
@Suppress("MagicNumber") // SQL column length
|
||||
@CordaSerializable
|
||||
object DummySchema : MappedSchema(schemaFamily = DummyStatesV.javaClass, version = 1, mappedTypes = listOf(DummyState::class.java)){
|
||||
|
||||
@Entity
|
||||
@Table(name = "dummy_states", indexes = [Index(name = "dummy_id_index", columnList = "id")])
|
||||
class DummyState (
|
||||
@Column(name = "id", length = 4, nullable = false)
|
||||
var id: Int
|
||||
) : PersistentState()
|
||||
}
|
||||
|
||||
@CordaSerializable
|
||||
data class DummyData(
|
||||
val id: Int
|
||||
)
|
||||
|
||||
@BelongsToContract(DummyContract::class)
|
||||
data class DummyState(val dummyData: DummyData, override val participants: List<AbstractParty>) :
|
||||
ContractState, QueryableState {
|
||||
override fun supportedSchemas(): Iterable<MappedSchema> = listOf(DummySchema)
|
||||
|
||||
|
||||
override fun generateMappedObject(schema: MappedSchema) =
|
||||
when (schema) {
|
||||
is DummySchema -> DummySchema.DummyState(
|
||||
dummyData.id
|
||||
)
|
||||
else -> throw IllegalArgumentException("Unsupported Schema")
|
||||
}
|
||||
}
|
||||
|
||||
class DummyContract : Contract {
|
||||
override fun verify(tx: LedgerTransaction) { }
|
||||
interface Commands : CommandData {
|
||||
class AddDummy : Commands
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user