From 4b7e2a399502cb8cefc6a18487355041c1753f25 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Wed, 26 Jul 2023 10:46:25 +0100 Subject: [PATCH 1/2] ENT-10045: Fix vault query bug on externalId and mapping with multiple keys --- .../vault/HibernateQueryCriteriaParser.kt | 2 +- .../node/services/vault/NodeVaultService.kt | 37 +++++------ .../vault/VaultQueryExceptionsTests.kt | 2 +- .../node/services/vault/VaultQueryTests.kt | 61 ++++++++++++++++--- 4 files changed, 72 insertions(+), 30 deletions(-) diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt index 3e8387803f..b5f9b327c2 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt @@ -272,7 +272,7 @@ class HibernateAttachmentQueryCriteriaParser(override val criteriaBuilder: class HibernateQueryCriteriaParser(val contractStateType: Class, val contractStateTypeMappings: Map>, override val criteriaBuilder: CriteriaBuilder, - val criteriaQuery: CriteriaQuery, + val criteriaQuery: CriteriaQuery<*>, val vaultStates: Root) : AbstractQueryCriteriaParser(), IQueryCriteriaParser { private companion object { private val log = contextLogger() diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index ac0913604c..cccac84910 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -35,7 +35,6 @@ import net.corda.core.node.services.vault.PageSpecification import net.corda.core.node.services.vault.QueryCriteria import net.corda.core.node.services.vault.Sort import net.corda.core.node.services.vault.SortAttribute -import net.corda.core.node.services.vault.builder import net.corda.core.observable.internal.OnResilientSubscribe import net.corda.core.schemas.PersistentStateRef import net.corda.core.serialization.SingletonSerializeAsToken @@ -69,17 +68,21 @@ import java.security.PublicKey import java.sql.SQLException import java.time.Clock import java.time.Instant -import java.util.Arrays -import java.util.UUID +import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.CopyOnWriteArraySet import java.util.stream.Stream import javax.persistence.PersistenceException import javax.persistence.Tuple import javax.persistence.criteria.CriteriaBuilder +import javax.persistence.criteria.CriteriaQuery import javax.persistence.criteria.CriteriaUpdate import javax.persistence.criteria.Predicate import javax.persistence.criteria.Root +import kotlin.collections.ArrayList +import kotlin.collections.LinkedHashSet +import kotlin.collections.component1 +import kotlin.collections.component2 /** * The vault service handles storage, retrieval and querying of states. @@ -709,7 +712,8 @@ class NodeVaultService( // calculate total results where a page specification has been defined val totalStatesAvailable = if (paging.isDefault) -1 else queryTotalStateCount(criteria, contractStateType) - val (query, stateTypes) = createQuery(criteria, contractStateType, sorting) + val (criteriaQuery, criteriaParser) = buildCriteriaQuery(criteria, contractStateType, sorting) + val query = getSession().createQuery(criteriaQuery) query.setResultWindow(paging) val statesMetadata: MutableList = mutableListOf() @@ -732,7 +736,7 @@ class NodeVaultService( ArrayList() ) - return Vault.Page(states, statesMetadata, totalStatesAvailable, stateTypes, otherResults) + return Vault.Page(states, statesMetadata, totalStatesAvailable, criteriaParser.stateTypes, otherResults) } private fun Query.resultStream(paging: PageSpecification): Stream { @@ -761,19 +765,17 @@ class NodeVaultService( } } - private fun queryTotalStateCount(baseCriteria: QueryCriteria, contractStateType: Class): Long { - val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) - val criteria = baseCriteria.and(countCriteria) - val (query) = createQuery(criteria, contractStateType, null) - val results = query.resultList - return results.last().toArray().last() as Long + private fun queryTotalStateCount(criteria: QueryCriteria, contractStateType: Class): Long { + val (criteriaQuery, criteriaParser) = buildCriteriaQuery(criteria, contractStateType, null) + criteriaQuery.select(criteriaBuilder.countDistinct(criteriaParser.vaultStates)) + val query = getSession().createQuery(criteriaQuery) + return query.singleResult } - private fun createQuery(criteria: QueryCriteria, - contractStateType: Class, - sorting: Sort?): Pair, Vault.StateStatus> { - val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java) + private inline fun buildCriteriaQuery(criteria: QueryCriteria, + contractStateType: Class, + sorting: Sort?): Pair, HibernateQueryCriteriaParser> { + val criteriaQuery = criteriaBuilder.createQuery(T::class.java) val criteriaParser = HibernateQueryCriteriaParser( contractStateType, contractStateTypeMappings, @@ -782,8 +784,7 @@ class NodeVaultService( criteriaQuery.from(VaultSchemaV1.VaultStates::class.java) ) criteriaParser.parse(criteria, sorting) - val query = getSession().createQuery(criteriaQuery) - return Pair(query, criteriaParser.stateTypes) + return Pair(criteriaQuery, criteriaParser) } /** diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryExceptionsTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryExceptionsTests.kt index d1a96ccda5..2dbd77b3ed 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryExceptionsTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryExceptionsTests.kt @@ -22,7 +22,7 @@ class VaultQueryExceptionsTests : VaultQueryParties by rule { @ClassRule @JvmField - val rule = object : VaultQueryTestRule() { + val rule = object : VaultQueryTestRule(persistentServices = false) { override val cordappPackages = listOf( "net.corda.testing.contracts", "net.corda.finance.contracts", diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index b06518667c..94a6eda019 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -4,6 +4,7 @@ import com.nhaarman.mockito_kotlin.mock import net.corda.core.contracts.* import net.corda.core.crypto.* import net.corda.core.identity.AbstractParty +import net.corda.core.identity.AnonymousParty import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.internal.packageName @@ -37,6 +38,7 @@ import net.corda.testing.internal.configureDatabase import net.corda.testing.internal.vault.* import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices.Companion.makeTestDatabaseAndMockServices +import net.corda.testing.node.MockServices.Companion.makeTestDatabaseAndPersistentServices import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatCode @@ -102,7 +104,7 @@ interface VaultQueryParties { val cordappPackages: List } -open class VaultQueryTestRule : ExternalResource(), VaultQueryParties { +open class VaultQueryTestRule(private val persistentServices: Boolean) : ExternalResource(), VaultQueryParties { override val alice = TestIdentity(ALICE_NAME, 70) override val bankOfCorda = TestIdentity(BOC_NAME) override val bigCorp = TestIdentity(CordaX500Name("BigCorporation", "New York", "US")) @@ -135,12 +137,22 @@ open class VaultQueryTestRule : ExternalResource(), VaultQueryParties { override fun before() { - // register additional identities - val databaseAndServices = makeTestDatabaseAndMockServices( - cordappPackages, - makeTestIdentityService(MEGA_CORP_IDENTITY, MINI_CORP_IDENTITY, dummyCashIssuer.identity, dummyNotary.identity), - megaCorp, - moreKeys = *arrayOf(DUMMY_NOTARY_KEY)) + val databaseAndServices = if (persistentServices) { + makeTestDatabaseAndPersistentServices( + cordappPackages, + megaCorp, + moreKeys = setOf(DUMMY_NOTARY_KEY), + moreIdentities = setOf(MEGA_CORP_IDENTITY, MINI_CORP_IDENTITY, dummyCashIssuer.identity, dummyNotary.identity) + ) + } else { + @Suppress("SpreadOperator") + makeTestDatabaseAndMockServices( + cordappPackages, + makeTestIdentityService(MEGA_CORP_IDENTITY, MINI_CORP_IDENTITY, dummyCashIssuer.identity, dummyNotary.identity), + megaCorp, + moreKeys = *arrayOf(DUMMY_NOTARY_KEY) + ) + } database = databaseAndServices.first services = databaseAndServices.second vaultFiller = VaultFiller(services, dummyNotary) @@ -2832,9 +2844,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { - companion object { - val delegate = VaultQueryTestRule() + val delegate = VaultQueryTestRule(persistentServices = false) } @Rule @@ -3137,4 +3148,34 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { ) } } -} \ No newline at end of file +} + + +class PersistentServicesVaultQueryTests : VaultQueryParties by delegate { + companion object { + val delegate = VaultQueryTestRule(persistentServices = true) + + @ClassRule + @JvmField + val testSerialization = SerializationEnvironmentRule() + } + + @Rule + @JvmField + val vaultQueryTestRule = delegate + + @Test(timeout = 300_000) + fun `query on externalId which maps to multiple keys`() { + val externalId = UUID.randomUUID() + val page = database.transaction { + val keys = Array(2) { services.keyManagementService.freshKey(externalId) } + vaultFiller.fillWithDummyState(participants = keys.map(::AnonymousParty)) + services.vaultService.queryBy( + VaultQueryCriteria(externalIds = listOf(externalId)), + paging = PageSpecification(DEFAULT_PAGE_NUM, 10) + ) + } + assertThat(page.states).hasSize(1) + assertThat(page.totalStatesAvailable).isEqualTo(1) + } +} From 5cdbec9ddf8236d4f381fc5b654e487da31a4d94 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Mon, 21 Aug 2023 10:30:42 +0100 Subject: [PATCH 2/2] ENT-6876: Optimised vault query to not query for total state count if the first page isn't full (#7449) --- .../net/corda/node/services/vault/NodeVaultService.kt | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index ac0913604c..9c90c36dd3 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -706,9 +706,6 @@ class NodeVaultService( paging: PageSpecification, sorting: Sort, contractStateType: Class): Vault.Page { - // calculate total results where a page specification has been defined - val totalStatesAvailable = if (paging.isDefault) -1 else queryTotalStateCount(criteria, contractStateType) - val (query, stateTypes) = createQuery(criteria, contractStateType, sorting) query.setResultWindow(paging) @@ -732,6 +729,13 @@ class NodeVaultService( ArrayList() ) + val totalStatesAvailable = when { + paging.isDefault -> -1L + // If the first page isn't full then we know that's all the states that are available + paging.pageNumber == DEFAULT_PAGE_NUM && states.size < paging.pageSize -> states.size.toLong() + else -> queryTotalStateCount(criteria, contractStateType) + } + return Vault.Page(states, statesMetadata, totalStatesAvailable, stateTypes, otherResults) }