CORDA-3879 - query with OR combinator returns too many results (#6456)

* fix suggestion and tests

* detekt suppress

* making sure the forced join works with IndirectStatePersistable and removing unnecessary joinPredicates from parse with sorting

* remove joinPredicates and add tests

* rename sorting

* revert deleting joinPredicates and modify the force join to use `OR` instead of `AND`

* add system property switch
This commit is contained in:
Nikolett Nagy 2020-08-13 10:04:53 +01:00 committed by GitHub
parent 38cad333c8
commit a6b2a3159d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 188 additions and 9 deletions

View File

@ -35,7 +35,6 @@ import java.util.*
import javax.persistence.Tuple
import javax.persistence.criteria.*
abstract class AbstractQueryCriteriaParser<Q : GenericQueryCriteria<Q,P>, in P: BaseQueryCriteriaParser<Q, P, S>, in S: BaseSort> : BaseQueryCriteriaParser<Q, P, S> {
abstract val criteriaBuilder: CriteriaBuilder
@ -277,6 +276,7 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
val vaultStates: Root<VaultSchemaV1.VaultStates>) : AbstractQueryCriteriaParser<QueryCriteria, IQueryCriteriaParser, Sort>(), IQueryCriteriaParser {
private companion object {
private val log = contextLogger()
private val disableCorda3879 = System.getProperty("net.corda.vault.query.disable.corda3879")?.toBoolean() ?: false
}
// incrementally build list of join predicates
@ -550,7 +550,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
// ensure we re-use any existing instance of the same root entity
val vaultLinearStatesRoot = getVaultLinearStatesRoot()
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"),
vaultLinearStatesRoot.get<PersistentStateRef>("stateRef"))
predicateSet.add(joinPredicate)
@ -613,8 +612,8 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
}
val joinPredicate = if(IndirectStatePersistable::class.java.isAssignableFrom(entityRoot.javaType)) {
criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<IndirectStatePersistable<*>>("compositeKey").get<PersistentStateRef>("stateRef"))
} else {
criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<IndirectStatePersistable<*>>("compositeKey").get<PersistentStateRef>("stateRef"))
} else {
criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<PersistentStateRef>("stateRef"))
}
predicateSet.add(joinPredicate)
@ -633,6 +632,7 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
return predicateSet
}
@Suppress("SpreadOperator")
override fun parse(criteria: QueryCriteria, sorting: Sort?): Collection<Predicate> {
val predicateSet = criteria.visit(this)
@ -647,12 +647,37 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
else
aggregateExpressions
criteriaQuery.multiselect(selections)
val combinedPredicates = commonPredicates.values.plus(predicateSet).plus(constraintPredicates).plus(joinPredicates)
criteriaQuery.where(*combinedPredicates.toTypedArray())
val combinedPredicates = commonPredicates.values.plus(predicateSet)
.plus(constraintPredicates)
.plus(joinPredicates)
val forceJoinPredicates = joinStateRefPredicate()
if(forceJoinPredicates.isEmpty() || disableCorda3879) {
criteriaQuery.where(*combinedPredicates.toTypedArray())
} else {
criteriaQuery.where(*combinedPredicates.toTypedArray(), criteriaBuilder.or(*forceJoinPredicates.toTypedArray()))
}
return predicateSet
}
private fun joinStateRefPredicate(): Set<Predicate> {
val returnSet = mutableSetOf<Predicate>()
rootEntities.values.forEach {
if (it != vaultStates) {
if(IndirectStatePersistable::class.java.isAssignableFrom(it.javaType)) {
returnSet.add(criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), it.get<IndirectStatePersistable<*>>("compositeKey").get<PersistentStateRef>("stateRef")))
} else {
returnSet.add(criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), it.get<PersistentStateRef>("stateRef")))
}
}
}
return returnSet
}
override fun parseCriteria(criteria: CommonQueryCriteria): Collection<Predicate> {
log.trace { "Parsing CommonQueryCriteria: $criteria" }
@ -849,8 +874,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
// scenario where sorting on attributes not parsed as criteria
val entityRoot = criteriaQuery.from(entityStateClass)
rootEntities[entityStateClass] = entityRoot
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<PersistentStateRef>("stateRef"))
joinPredicates.add(joinPredicate)
entityRoot
}
when (direction) {
@ -869,7 +892,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
}
if (orderCriteria.isNotEmpty()) {
criteriaQuery.orderBy(orderCriteria)
criteriaQuery.where(*joinPredicates.toTypedArray())
}
}

View File

@ -0,0 +1,157 @@
package net.corda.node.services.vault
import net.corda.core.contracts.BelongsToContract
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.Contract
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateRef
import net.corda.core.identity.AbstractParty
import net.corda.core.node.services.Vault
import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE
import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.Sort
import net.corda.core.node.services.vault.SortAttribute
import net.corda.core.node.services.vault.builder
import net.corda.core.schemas.MappedSchema
import net.corda.core.schemas.PersistentState
import net.corda.core.schemas.QueryableState
import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetworkParameters
import net.corda.testing.node.internal.cordappsForPackages
import org.junit.BeforeClass
import org.junit.Test
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Index
import javax.persistence.Table
import kotlin.test.assertEquals
class VaultQueryJoinTest {
companion object {
private val mockNetwork = MockNetwork(
MockNetworkParameters(
cordappsForAllNodes = cordappsForPackages(
listOf(
"net.corda.node.services.vault"
)
)
)
)
private val aliceNode = mockNetwork.createPartyNode(ALICE_NAME)
private val notaryNode = mockNetwork.defaultNotaryNode
private val serviceHubHandle = aliceNode.services
private val createdStateRefs = mutableListOf<StateRef>()
private const val numObjectsInLedger = DEFAULT_PAGE_SIZE + 1
@BeforeClass
@JvmStatic
fun setup() {
repeat(numObjectsInLedger) { index ->
createdStateRefs.add(addSimpleObjectToLedger(DummyData(index)))
}
System.setProperty("net.corda.vault.query.disable.corda3879", "false");
}
private fun addSimpleObjectToLedger(dummyObject: DummyData): StateRef {
val tx = TransactionBuilder(notaryNode.info.legalIdentities.first())
tx.addOutputState(
DummyState(dummyObject, listOf(aliceNode.info.identityFromX500Name(ALICE_NAME)))
)
tx.addCommand(DummyContract.Commands.AddDummy(), aliceNode.info.legalIdentitiesAndCerts.first().owningKey)
tx.verify(serviceHubHandle)
val stx = serviceHubHandle.signInitialTransaction(tx)
serviceHubHandle.recordTransactions(listOf(stx))
return StateRef(stx.id, 0)
}
}
private val queryToCheckId = builder {
val conditionToCheckId =
DummySchema.DummyState::id
.equal(0)
QueryCriteria.VaultCustomQueryCriteria(conditionToCheckId, Vault.StateStatus.UNCONSUMED)
}
private val queryToCheckStateRef =
QueryCriteria.VaultQueryCriteria(Vault.StateStatus.UNCONSUMED, stateRefs = listOf(createdStateRefs[numObjectsInLedger-1]))
@Test(timeout = 300_000)
fun `filter query with OR operator`() {
val results = serviceHubHandle.vaultService.queryBy<DummyState>(
queryToCheckId.or(queryToCheckStateRef)
)
assertEquals(2, results.states.size)
assertEquals(2, results.statesMetadata.size)
}
@Test(timeout = 300_000)
fun `filter query with sorting`() {
val sorting = Sort(listOf(Sort.SortColumn(SortAttribute.Custom(DummySchema.DummyState::class.java, "stateRef"), Sort.Direction.DESC)))
val results = serviceHubHandle.vaultService.queryBy<DummyState>(
queryToCheckStateRef, sorting = sorting
)
assertEquals(1, results.states.size)
assertEquals(1, results.statesMetadata.size)
}
@Test(timeout = 300_000)
fun `filter query with OR operator and sorting`() {
val sorting = Sort(listOf(Sort.SortColumn(SortAttribute.Custom(DummySchema.DummyState::class.java, "stateRef"), Sort.Direction.DESC)))
val results = serviceHubHandle.vaultService.queryBy<DummyState>(
queryToCheckId.or(queryToCheckStateRef), sorting = sorting
)
assertEquals(2, results.states.size)
assertEquals(2, results.statesMetadata.size)
}
}
object DummyStatesV
@Suppress("MagicNumber") // SQL column length
@CordaSerializable
object DummySchema : MappedSchema(schemaFamily = DummyStatesV.javaClass, version = 1, mappedTypes = listOf(DummyState::class.java)){
@Entity
@Table(name = "dummy_states", indexes = [Index(name = "dummy_id_index", columnList = "id")])
class DummyState (
@Column(name = "id", length = 4, nullable = false)
var id: Int
) : PersistentState()
}
@CordaSerializable
data class DummyData(
val id: Int
)
@BelongsToContract(DummyContract::class)
data class DummyState(val dummyData: DummyData, override val participants: List<AbstractParty>) :
ContractState, QueryableState {
override fun supportedSchemas(): Iterable<MappedSchema> = listOf(DummySchema)
override fun generateMappedObject(schema: MappedSchema) =
when (schema) {
is DummySchema -> DummySchema.DummyState(
dummyData.id
)
else -> throw IllegalArgumentException("Unsupported Schema")
}
}
class DummyContract : Contract {
override fun verify(tx: LedgerTransaction) { }
interface Commands : CommandData {
class AddDummy : Commands
}
}