diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 63854dda4a..38fdb4c132 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -6,6 +6,9 @@ release, see :doc:`upgrade-notes`. Unreleased ========== + +* Vault query fix: support query by parent classes of Contract State classes (see https://github.com/corda/corda/issues/3714) + * Fixed an issue preventing Shell from returning control to the user when CTRL+C is pressed in the terminal. * Fixed a problem that sometimes prevented nodes from starting in presence of custom state types in the database without a corresponding type from installed CorDapps. diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index bc1e4eaa7d..5b6c3bb9e5 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -371,9 +371,10 @@ class NodeVaultService( * Maintain a list of contract state interfaces to concrete types stored in the vault * for usage in generic queries of type queryBy or queryBy> */ - private val contractStateTypeMappings = bootstrapContractStateTypes() + private val contractStateTypeMappings = mutableMapOf>() init { + bootstrapContractStateTypes() rawUpdates.subscribe { update -> update.produced.forEach { val concreteType = it.state.data.javaClass @@ -479,7 +480,7 @@ class NodeVaultService( /** * Derive list from existing vault states and then incrementally update using vault observables */ - private fun bootstrapContractStateTypes(): MutableMap> { + private fun bootstrapContractStateTypes() { val criteria = criteriaBuilder.createQuery(String::class.java) val vaultStates = criteria.from(VaultSchemaV1.VaultStates::class.java) criteria.select(vaultStates.get("contractStateClassName")).distinct(true) @@ -491,25 +492,19 @@ class NodeVaultService( val contractInterfaceToConcreteTypes = mutableMapOf>() val unknownTypes = mutableSetOf() - distinctTypes.forEach { type -> - val concreteType: Class? = try { - uncheckedCast(Class.forName(type)) - } catch (e: ClassNotFoundException) { - unknownTypes += type - null - } + distinctTypes.forEach { type -> + val concreteType: Class = uncheckedCast(Class.forName(type)) concreteType?.let { val contractTypes = deriveContractTypes(it) contractTypes.map { - val contractStateType = contractInterfaceToConcreteTypes.getOrPut(it.name) { mutableSetOf() } - contractStateType.add(concreteType.name) + val contractStateType = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() } + contractStateType.add(it.name) } } } if (unknownTypes.isNotEmpty()) { log.warn("There are unknown contract state types in the vault, which will prevent these states from being used. The relevant CorDapps must be loaded for these states to be used. The types not on the classpath are ${unknownTypes.joinToString(", ", "[", "]")}.") } - return contractInterfaceToConcreteTypes } private fun deriveContractTypes(clazz: Class): Set> { diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 11988f5524..0bfcb2c820 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -9,6 +9,8 @@ import net.corda.core.internal.packageName import net.corda.core.node.services.* import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.QueryCriteria.* +import net.corda.core.transactions.LedgerTransaction +import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.days import net.corda.core.utilities.seconds @@ -23,11 +25,13 @@ import net.corda.finance.schemas.CashSchemaV1.PersistentCashState import net.corda.finance.schemas.CommercialPaperSchemaV1 import net.corda.finance.schemas.SampleCashSchemaV2 import net.corda.finance.schemas.SampleCashSchemaV3 +import net.corda.core.identity.AbstractParty import net.corda.node.internal.configureDatabase import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.testing.core.* import net.corda.testing.internal.TEST_TX_TIME +import net.corda.testing.internal.chooseIdentity import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.vault.DUMMY_LINEAR_CONTRACT_PROGRAM_ID import net.corda.testing.internal.vault.DummyLinearContract @@ -98,7 +102,8 @@ class VaultQueryTests { "net.corda.testing.contracts", "net.corda.finance.contracts", CashSchemaV1::class.packageName, - DummyLinearStateSchemaV1::class.packageName) + DummyLinearStateSchemaV1::class.packageName, + VaultQueryTests.MyContractClass::class.packageName) private lateinit var services: MockServices private lateinit var vaultFiller: VaultFiller private lateinit var vaultFillerCashNotary: VaultFiller @@ -106,6 +111,7 @@ class VaultQueryTests { private val vaultService: VaultService get() = services.vaultService private lateinit var identitySvc: IdentityService private lateinit var database: CordaPersistence + @Before fun setUp() { // register additional identities @@ -209,6 +215,43 @@ class VaultQueryTests { } } + @Test + fun `query by interface for a contract class extending a parent contract class`() { + database.transaction { + + // build custom contract and store in vault + val me = services.myInfo.chooseIdentity() + val state = MyState("myState", listOf(me)) + val stateAndContract = StateAndContract(state, MYCONTRACT_ID) + val utx = TransactionBuilder(notary = notaryServices.myInfo.singleIdentity()).withItems(stateAndContract).withItems(dummyCommand()) + services.recordTransactions(services.signInitialTransaction(utx)) + + // query vault by Child class + val criteria = VaultQueryCriteria() // default is UNCONSUMED + val queryByMyState = vaultService.queryBy(criteria) + assertThat(queryByMyState.states).hasSize(1) + + // query vault by Parent class + val queryByBaseState = vaultService.queryBy(criteria) + assertThat(queryByBaseState.states).hasSize(1) + + // query vault by extended Contract Interface + val queryByContract = vaultService.queryBy(criteria) + assertThat(queryByContract.states).hasSize(1) + } + } + + // Beware: do not use `MyContractClass::class.qualifiedName` as this returns a fully qualified name using "dot" notation for enclosed class + val MYCONTRACT_ID = "net.corda.node.services.vault.VaultQueryTests\$MyContractClass" + + open class MyContractClass : Contract { + override fun verify(tx: LedgerTransaction) {} + } + + interface MyContractInterface : ContractState + open class BaseState(override val participants: List = emptyList()) : MyContractInterface + data class MyState(val name: String, override val participants: List = emptyList()) : BaseState(participants) + @Test fun `unconsumed states simple`() { database.transaction {