mirror of
https://github.com/corda/corda.git
synced 2025-05-03 09:12:55 +00:00
Merge pull request #6632 from corda/nnagy-os-4.5-os-4.6-20200813
NOTICK - OS 4.5 to OS 4.6 merge 20200813
This commit is contained in:
commit
2fb21373a4
@ -35,7 +35,6 @@ import java.util.*
|
|||||||
import javax.persistence.Tuple
|
import javax.persistence.Tuple
|
||||||
import javax.persistence.criteria.*
|
import javax.persistence.criteria.*
|
||||||
|
|
||||||
|
|
||||||
abstract class AbstractQueryCriteriaParser<Q : GenericQueryCriteria<Q,P>, in P: BaseQueryCriteriaParser<Q, P, S>, in S: BaseSort> : BaseQueryCriteriaParser<Q, P, S> {
|
abstract class AbstractQueryCriteriaParser<Q : GenericQueryCriteria<Q,P>, in P: BaseQueryCriteriaParser<Q, P, S>, in S: BaseSort> : BaseQueryCriteriaParser<Q, P, S> {
|
||||||
|
|
||||||
abstract val criteriaBuilder: CriteriaBuilder
|
abstract val criteriaBuilder: CriteriaBuilder
|
||||||
@ -277,6 +276,7 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
|||||||
val vaultStates: Root<VaultSchemaV1.VaultStates>) : AbstractQueryCriteriaParser<QueryCriteria, IQueryCriteriaParser, Sort>(), IQueryCriteriaParser {
|
val vaultStates: Root<VaultSchemaV1.VaultStates>) : AbstractQueryCriteriaParser<QueryCriteria, IQueryCriteriaParser, Sort>(), IQueryCriteriaParser {
|
||||||
private companion object {
|
private companion object {
|
||||||
private val log = contextLogger()
|
private val log = contextLogger()
|
||||||
|
private val disableCorda3879 = System.getProperty("net.corda.vault.query.disable.corda3879")?.toBoolean() ?: false
|
||||||
}
|
}
|
||||||
|
|
||||||
// incrementally build list of join predicates
|
// incrementally build list of join predicates
|
||||||
@ -550,7 +550,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
|||||||
|
|
||||||
// ensure we re-use any existing instance of the same root entity
|
// ensure we re-use any existing instance of the same root entity
|
||||||
val vaultLinearStatesRoot = getVaultLinearStatesRoot()
|
val vaultLinearStatesRoot = getVaultLinearStatesRoot()
|
||||||
|
|
||||||
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"),
|
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"),
|
||||||
vaultLinearStatesRoot.get<PersistentStateRef>("stateRef"))
|
vaultLinearStatesRoot.get<PersistentStateRef>("stateRef"))
|
||||||
predicateSet.add(joinPredicate)
|
predicateSet.add(joinPredicate)
|
||||||
@ -636,6 +635,7 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
|||||||
return predicateSet
|
return predicateSet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Suppress("SpreadOperator")
|
||||||
override fun parse(criteria: QueryCriteria, sorting: Sort?): Collection<Predicate> {
|
override fun parse(criteria: QueryCriteria, sorting: Sort?): Collection<Predicate> {
|
||||||
val predicateSet = criteria.visit(this)
|
val predicateSet = criteria.visit(this)
|
||||||
|
|
||||||
@ -650,12 +650,37 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
|||||||
else
|
else
|
||||||
aggregateExpressions
|
aggregateExpressions
|
||||||
criteriaQuery.multiselect(selections)
|
criteriaQuery.multiselect(selections)
|
||||||
val combinedPredicates = commonPredicates.values.plus(predicateSet).plus(constraintPredicates).plus(joinPredicates)
|
val combinedPredicates = commonPredicates.values.plus(predicateSet)
|
||||||
criteriaQuery.where(*combinedPredicates.toTypedArray())
|
.plus(constraintPredicates)
|
||||||
|
.plus(joinPredicates)
|
||||||
|
|
||||||
|
val forceJoinPredicates = joinStateRefPredicate()
|
||||||
|
|
||||||
|
if(forceJoinPredicates.isEmpty() || disableCorda3879) {
|
||||||
|
criteriaQuery.where(*combinedPredicates.toTypedArray())
|
||||||
|
} else {
|
||||||
|
criteriaQuery.where(*combinedPredicates.toTypedArray(), criteriaBuilder.or(*forceJoinPredicates.toTypedArray()))
|
||||||
|
}
|
||||||
|
|
||||||
return predicateSet
|
return predicateSet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun joinStateRefPredicate(): Set<Predicate> {
|
||||||
|
val returnSet = mutableSetOf<Predicate>()
|
||||||
|
|
||||||
|
rootEntities.values.forEach {
|
||||||
|
if (it != vaultStates) {
|
||||||
|
if(IndirectStatePersistable::class.java.isAssignableFrom(it.javaType)) {
|
||||||
|
returnSet.add(criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), it.get<IndirectStatePersistable<*>>("compositeKey").get<PersistentStateRef>("stateRef")))
|
||||||
|
} else {
|
||||||
|
returnSet.add(criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), it.get<PersistentStateRef>("stateRef")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return returnSet
|
||||||
|
}
|
||||||
|
|
||||||
override fun parseCriteria(criteria: CommonQueryCriteria): Collection<Predicate> {
|
override fun parseCriteria(criteria: CommonQueryCriteria): Collection<Predicate> {
|
||||||
log.trace { "Parsing CommonQueryCriteria: $criteria" }
|
log.trace { "Parsing CommonQueryCriteria: $criteria" }
|
||||||
|
|
||||||
@ -852,8 +877,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
|||||||
// scenario where sorting on attributes not parsed as criteria
|
// scenario where sorting on attributes not parsed as criteria
|
||||||
val entityRoot = criteriaQuery.from(entityStateClass)
|
val entityRoot = criteriaQuery.from(entityStateClass)
|
||||||
rootEntities[entityStateClass] = entityRoot
|
rootEntities[entityStateClass] = entityRoot
|
||||||
val joinPredicate = criteriaBuilder.equal(vaultStates.get<PersistentStateRef>("stateRef"), entityRoot.get<PersistentStateRef>("stateRef"))
|
|
||||||
joinPredicates.add(joinPredicate)
|
|
||||||
entityRoot
|
entityRoot
|
||||||
}
|
}
|
||||||
when (direction) {
|
when (direction) {
|
||||||
@ -872,7 +895,6 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
|
|||||||
}
|
}
|
||||||
if (orderCriteria.isNotEmpty()) {
|
if (orderCriteria.isNotEmpty()) {
|
||||||
criteriaQuery.orderBy(orderCriteria)
|
criteriaQuery.orderBy(orderCriteria)
|
||||||
criteriaQuery.where(*joinPredicates.toTypedArray())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,157 @@
|
|||||||
|
package net.corda.node.services.vault
|
||||||
|
|
||||||
|
import net.corda.core.contracts.BelongsToContract
|
||||||
|
import net.corda.core.contracts.CommandData
|
||||||
|
import net.corda.core.contracts.Contract
|
||||||
|
import net.corda.core.contracts.ContractState
|
||||||
|
import net.corda.core.contracts.StateRef
|
||||||
|
import net.corda.core.identity.AbstractParty
|
||||||
|
import net.corda.core.node.services.Vault
|
||||||
|
import net.corda.core.node.services.queryBy
|
||||||
|
import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE
|
||||||
|
import net.corda.core.node.services.vault.QueryCriteria
|
||||||
|
import net.corda.core.node.services.vault.Sort
|
||||||
|
import net.corda.core.node.services.vault.SortAttribute
|
||||||
|
import net.corda.core.node.services.vault.builder
|
||||||
|
import net.corda.core.schemas.MappedSchema
|
||||||
|
import net.corda.core.schemas.PersistentState
|
||||||
|
import net.corda.core.schemas.QueryableState
|
||||||
|
import net.corda.core.serialization.CordaSerializable
|
||||||
|
import net.corda.core.transactions.LedgerTransaction
|
||||||
|
import net.corda.core.transactions.TransactionBuilder
|
||||||
|
import net.corda.testing.core.ALICE_NAME
|
||||||
|
import net.corda.testing.node.MockNetwork
|
||||||
|
import net.corda.testing.node.MockNetworkParameters
|
||||||
|
import net.corda.testing.node.internal.cordappsForPackages
|
||||||
|
import org.junit.BeforeClass
|
||||||
|
import org.junit.Test
|
||||||
|
import javax.persistence.Column
|
||||||
|
import javax.persistence.Entity
|
||||||
|
import javax.persistence.Index
|
||||||
|
import javax.persistence.Table
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class VaultQueryJoinTest {
|
||||||
|
companion object {
|
||||||
|
private val mockNetwork = MockNetwork(
|
||||||
|
MockNetworkParameters(
|
||||||
|
cordappsForAllNodes = cordappsForPackages(
|
||||||
|
listOf(
|
||||||
|
"net.corda.node.services.vault"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
private val aliceNode = mockNetwork.createPartyNode(ALICE_NAME)
|
||||||
|
private val notaryNode = mockNetwork.defaultNotaryNode
|
||||||
|
private val serviceHubHandle = aliceNode.services
|
||||||
|
private val createdStateRefs = mutableListOf<StateRef>()
|
||||||
|
private const val numObjectsInLedger = DEFAULT_PAGE_SIZE + 1
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
@JvmStatic
|
||||||
|
fun setup() {
|
||||||
|
repeat(numObjectsInLedger) { index ->
|
||||||
|
createdStateRefs.add(addSimpleObjectToLedger(DummyData(index)))
|
||||||
|
}
|
||||||
|
|
||||||
|
System.setProperty("net.corda.vault.query.disable.corda3879", "false");
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun addSimpleObjectToLedger(dummyObject: DummyData): StateRef {
|
||||||
|
val tx = TransactionBuilder(notaryNode.info.legalIdentities.first())
|
||||||
|
tx.addOutputState(
|
||||||
|
DummyState(dummyObject, listOf(aliceNode.info.identityFromX500Name(ALICE_NAME)))
|
||||||
|
)
|
||||||
|
tx.addCommand(DummyContract.Commands.AddDummy(), aliceNode.info.legalIdentitiesAndCerts.first().owningKey)
|
||||||
|
tx.verify(serviceHubHandle)
|
||||||
|
val stx = serviceHubHandle.signInitialTransaction(tx)
|
||||||
|
serviceHubHandle.recordTransactions(listOf(stx))
|
||||||
|
return StateRef(stx.id, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private val queryToCheckId = builder {
|
||||||
|
val conditionToCheckId =
|
||||||
|
DummySchema.DummyState::id
|
||||||
|
.equal(0)
|
||||||
|
QueryCriteria.VaultCustomQueryCriteria(conditionToCheckId, Vault.StateStatus.UNCONSUMED)
|
||||||
|
}
|
||||||
|
|
||||||
|
private val queryToCheckStateRef =
|
||||||
|
QueryCriteria.VaultQueryCriteria(Vault.StateStatus.UNCONSUMED, stateRefs = listOf(createdStateRefs[numObjectsInLedger-1]))
|
||||||
|
|
||||||
|
@Test(timeout = 300_000)
|
||||||
|
fun `filter query with OR operator`() {
|
||||||
|
val results = serviceHubHandle.vaultService.queryBy<DummyState>(
|
||||||
|
queryToCheckId.or(queryToCheckStateRef)
|
||||||
|
)
|
||||||
|
assertEquals(2, results.states.size)
|
||||||
|
assertEquals(2, results.statesMetadata.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(timeout = 300_000)
|
||||||
|
fun `filter query with sorting`() {
|
||||||
|
val sorting = Sort(listOf(Sort.SortColumn(SortAttribute.Custom(DummySchema.DummyState::class.java, "stateRef"), Sort.Direction.DESC)))
|
||||||
|
|
||||||
|
val results = serviceHubHandle.vaultService.queryBy<DummyState>(
|
||||||
|
queryToCheckStateRef, sorting = sorting
|
||||||
|
)
|
||||||
|
|
||||||
|
assertEquals(1, results.states.size)
|
||||||
|
assertEquals(1, results.statesMetadata.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(timeout = 300_000)
|
||||||
|
fun `filter query with OR operator and sorting`() {
|
||||||
|
val sorting = Sort(listOf(Sort.SortColumn(SortAttribute.Custom(DummySchema.DummyState::class.java, "stateRef"), Sort.Direction.DESC)))
|
||||||
|
|
||||||
|
val results = serviceHubHandle.vaultService.queryBy<DummyState>(
|
||||||
|
queryToCheckId.or(queryToCheckStateRef), sorting = sorting
|
||||||
|
)
|
||||||
|
|
||||||
|
assertEquals(2, results.states.size)
|
||||||
|
assertEquals(2, results.statesMetadata.size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
object DummyStatesV
|
||||||
|
|
||||||
|
@Suppress("MagicNumber") // SQL column length
|
||||||
|
@CordaSerializable
|
||||||
|
object DummySchema : MappedSchema(schemaFamily = DummyStatesV.javaClass, version = 1, mappedTypes = listOf(DummyState::class.java)){
|
||||||
|
|
||||||
|
@Entity
|
||||||
|
@Table(name = "dummy_states", indexes = [Index(name = "dummy_id_index", columnList = "id")])
|
||||||
|
class DummyState (
|
||||||
|
@Column(name = "id", length = 4, nullable = false)
|
||||||
|
var id: Int
|
||||||
|
) : PersistentState()
|
||||||
|
}
|
||||||
|
|
||||||
|
@CordaSerializable
|
||||||
|
data class DummyData(
|
||||||
|
val id: Int
|
||||||
|
)
|
||||||
|
|
||||||
|
@BelongsToContract(DummyContract::class)
|
||||||
|
data class DummyState(val dummyData: DummyData, override val participants: List<AbstractParty>) :
|
||||||
|
ContractState, QueryableState {
|
||||||
|
override fun supportedSchemas(): Iterable<MappedSchema> = listOf(DummySchema)
|
||||||
|
|
||||||
|
|
||||||
|
override fun generateMappedObject(schema: MappedSchema) =
|
||||||
|
when (schema) {
|
||||||
|
is DummySchema -> DummySchema.DummyState(
|
||||||
|
dummyData.id
|
||||||
|
)
|
||||||
|
else -> throw IllegalArgumentException("Unsupported Schema")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyContract : Contract {
|
||||||
|
override fun verify(tx: LedgerTransaction) { }
|
||||||
|
interface Commands : CommandData {
|
||||||
|
class AddDummy : Commands
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user