diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index a5a0b9b7e0..b734859b2c 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -772,7 +772,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, override val networkMapCache by lazy { InMemoryNetworkMapCache(this) } override val vaultService by lazy { NodeVaultService(this, configuration.dataSourceProperties, configuration.database) } override val vaultQueryService by lazy { - HibernateVaultQueryImpl(HibernateConfiguration(schemaService, configuration.database ?: Properties(), { identityService }), vaultService.updatesPublisher) + HibernateVaultQueryImpl(HibernateConfiguration(schemaService, configuration.database ?: Properties(), { identityService }), vaultService) } // Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because // the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt index 2ee2218160..114fc8b792 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt @@ -21,7 +21,7 @@ import javax.persistence.criteria.* class HibernateQueryCriteriaParser(val contractType: Class, - val contractTypeMappings: Map>, + val contractTypeMappings: Map>, val criteriaBuilder: CriteriaBuilder, val criteriaQuery: CriteriaQuery, val vaultStates: Root) : IQueryCriteriaParser { @@ -97,7 +97,7 @@ class HibernateQueryCriteriaParser(val contractType: Class, private fun deriveContractTypes(contractStateTypes: Set>? = null): List { val combinedContractStateTypes = contractStateTypes?.plus(contractType) ?: setOf(contractType) combinedContractStateTypes.filter { it.name != ContractState::class.java.name }.let { - val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() } + val interfaces = it.flatMap { contractTypeMappings[it.name] ?: listOf(it.name) } val concrete = it.filter { !it.isInterface }.map { it.name } return interfaces.plus(concrete) } diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt index 3bccc8597c..1b0156364b 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt @@ -1,16 +1,17 @@ package net.corda.node.services.vault import net.corda.core.internal.ThreadBox -import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateRef import net.corda.core.contracts.TransactionState import net.corda.core.crypto.SecureHash +import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.messaging.DataFeed import net.corda.core.node.services.Vault import net.corda.core.node.services.VaultQueryException import net.corda.core.node.services.VaultQueryService +import net.corda.core.node.services.VaultService import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.QueryCriteria.VaultCustomQueryCriteria import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT @@ -18,18 +19,18 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor +import net.corda.core.utilities.trace import net.corda.node.services.database.HibernateConfiguration +import org.hibernate.Session import org.jetbrains.exposed.sql.transactions.TransactionManager -import rx.subjects.PublishSubject import rx.Observable import java.lang.Exception import java.util.* -import javax.persistence.EntityManager import javax.persistence.Tuple class HibernateVaultQueryImpl(hibernateConfig: HibernateConfiguration, - val updatesPublisher: PublishSubject>) : SingletonSerializeAsToken(), VaultQueryService { + val vault: VaultService) : SingletonSerializeAsToken(), VaultQueryService { companion object { val log = loggerFor() } @@ -37,6 +38,29 @@ class HibernateVaultQueryImpl(hibernateConfig: HibernateConfiguration, private val sessionFactory = hibernateConfig.sessionFactoryForRegisteredSchemas() private val criteriaBuilder = sessionFactory.criteriaBuilder + /** + * Maintain a list of contract state interfaces to concrete types stored in the vault + * for usage in generic queries of type queryBy or queryBy> + */ + private val contractTypeMappings = bootstrapContractStateTypes() + + init { + vault.rawUpdates.subscribe { update -> + update.produced.forEach { + val concreteType = it.state.data.javaClass + log.trace { "State update of type: $concreteType" } + val seen = contractTypeMappings.any { it.value.contains(concreteType.name) } + if (!seen) { + val contractInterfaces = deriveContractInterfaces(concreteType) + contractInterfaces.map { + val contractInterface = contractTypeMappings.getOrPut(it.name, { mutableSetOf() }) + contractInterface.add(concreteType.name) + } + } + } + } + } + @Throws(VaultQueryException::class) override fun _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractType: Class): Vault.Page { log.info("Vault Query for contract type: $contractType, criteria: $criteria, pagination: $paging, sorting: $sorting") @@ -50,15 +74,12 @@ class HibernateVaultQueryImpl(hibernateConfig: HibernateConfiguration, totalStates = results.otherResults[0] as Long } - val session = sessionFactory.withOptions(). - connection(TransactionManager.current().connection). - openSession() + val session = getSession() session.use { val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java) val queryRootVaultStates = criteriaQuery.from(VaultSchemaV1.VaultStates::class.java) - val contractTypeMappings = resolveUniqueContractStateTypes(session) // TODO: revisit (use single instance of parser for all queries) val criteriaParser = HibernateQueryCriteriaParser(contractType, contractTypeMappings, criteriaBuilder, criteriaQuery, queryRootVaultStates) @@ -116,41 +137,49 @@ class HibernateVaultQueryImpl(hibernateConfig: HibernateConfiguration, } } - private val mutex = ThreadBox({ updatesPublisher }) + private val mutex = ThreadBox({ vault.updatesPublisher }) @Throws(VaultQueryException::class) override fun _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractType: Class): DataFeed, Vault.Update> { return mutex.locked { val snapshotResults = _queryBy(criteria, paging, sorting, contractType) @Suppress("UNCHECKED_CAST") - val updates = updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractType, snapshotResults.stateTypes) } as Observable> + val updates = vault.updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractType, snapshotResults.stateTypes) } as Observable> DataFeed(snapshotResults, updates) } } + private fun getSession(): Session { + return sessionFactory.withOptions(). + connection(TransactionManager.current().connection). + openSession() + } + /** - * Maintain a list of contract state interfaces to concrete types stored in the vault - * for usage in generic queries of type queryBy or queryBy> + * Derive list from existing vault states and then incrementally update using vault observables */ - fun resolveUniqueContractStateTypes(session: EntityManager): Map> { + fun bootstrapContractStateTypes(): MutableMap> { val criteria = criteriaBuilder.createQuery(String::class.java) val vaultStates = criteria.from(VaultSchemaV1.VaultStates::class.java) criteria.select(vaultStates.get("contractStateClassName")).distinct(true) - val query = session.createQuery(criteria) - val results = query.resultList - val distinctTypes = results.map { it } + val session = getSession() + session.use { + val query = session.createQuery(criteria) + val results = query.resultList + val distinctTypes = results.map { it } - val contractInterfaceToConcreteTypes = mutableMapOf>() - distinctTypes.forEach { it -> - @Suppress("UNCHECKED_CAST") - val concreteType = Class.forName(it) as Class - val contractInterfaces = deriveContractInterfaces(concreteType) - contractInterfaces.map { - val contractInterface = contractInterfaceToConcreteTypes.getOrPut(it.name, { mutableListOf() }) - contractInterface.add(concreteType.name) + val contractInterfaceToConcreteTypes = mutableMapOf>() + distinctTypes.forEach { type -> + @Suppress("UNCHECKED_CAST") + val concreteType = Class.forName(type) as Class + val contractInterfaces = deriveContractInterfaces(concreteType) + contractInterfaces.map { + val contractInterface = contractInterfaceToConcreteTypes.getOrPut(it.name, { mutableSetOf() }) + contractInterface.add(concreteType.name) + } } + return contractInterfaceToConcreteTypes } - return contractInterfaceToConcreteTypes } private fun deriveContractInterfaces(clazz: Class): Set> { diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt b/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt index 43db1c6550..0b4ff61add 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt @@ -235,7 +235,7 @@ fun makeTestDatabaseAndMockServices(customSchemas: Set = setOf(Com vaultService.notifyAll(txs.map { it.tx }) } - override val vaultQueryService: VaultQueryService = HibernateVaultQueryImpl(hibernateConfig, vaultService.updatesPublisher) + override val vaultQueryService: VaultQueryService = HibernateVaultQueryImpl(hibernateConfig, vaultService) override fun jdbcSession(): Connection = database.createSession() }