diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 426b192264..e852b26f77 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -6,6 +6,8 @@ release, see :doc:`upgrade-notes`. Unreleased ---------- +* Vault query fix: support query by parent classes of Contract State classes (see https://github.com/corda/corda/issues/3714) + * Added ``registerResponderFlow`` method to ``StartedMockNode``, to support isolated testing of responder flow behaviour. * "app", "rpc", "p2p" and "unknown" are no longer allowed as uploader values when importing attachments. These are used diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 2062762711..17f905313d 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -85,10 +85,10 @@ class NodeVaultService( log.trace { "State update of type: $concreteType" } val seen = contractStateTypeMappings.any { it.value.contains(concreteType.name) } if (!seen) { - val contractInterfaces = deriveContractInterfaces(concreteType) - contractInterfaces.map { - val contractInterface = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() } - contractInterface.add(concreteType.name) + val contractTypes = deriveContractTypes(concreteType) + contractTypes.map { + val contractStateType = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() } + contractStateType.add(concreteType.name) } } } @@ -532,10 +532,10 @@ class NodeVaultService( null } concreteType?.let { - val contractInterfaces = deriveContractInterfaces(it) - contractInterfaces.map { - val contractInterface = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() } - contractInterface.add(it.name) + val contractTypes = deriveContractTypes(it) + contractTypes.map { + val contractStateType = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() } + contractStateType.add(it.name) } } } @@ -544,14 +544,20 @@ class NodeVaultService( } } - private fun deriveContractInterfaces(clazz: Class): Set> { - val myInterfaces: MutableSet> = mutableSetOf() - clazz.interfaces.forEach { - if (it != ContractState::class.java) { - myInterfaces.add(uncheckedCast(it)) - myInterfaces.addAll(deriveContractInterfaces(uncheckedCast(it))) + private fun deriveContractTypes(clazz: Class): Set> { + val myTypes : MutableSet> = mutableSetOf() + clazz.superclass?.let { + if (!it.isInstance(Any::class)) { + myTypes.add(uncheckedCast(it)) + myTypes.addAll(deriveContractTypes(uncheckedCast(it))) } } - return myInterfaces + clazz.interfaces.forEach { + if (it != ContractState::class.java) { + myTypes.add(uncheckedCast(it)) + myTypes.addAll(deriveContractTypes(uncheckedCast(it))) + } + } + return myTypes } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 54e47d3f08..3f8f4ef7d2 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -2,12 +2,14 @@ package net.corda.node.services.vault import net.corda.core.contracts.* import net.corda.core.crypto.* +import net.corda.core.identity.AbstractParty import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.internal.packageName import net.corda.core.node.services.* import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.QueryCriteria.* +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.* import net.corda.finance.* @@ -27,6 +29,7 @@ import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.testing.core.* import net.corda.testing.internal.TEST_TX_TIME +import net.corda.testing.internal.chooseIdentity import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.vault.* import net.corda.testing.node.MockServices @@ -115,7 +118,8 @@ open class VaultQueryTestRule : ExternalResource(), VaultQueryParties { "net.corda.finance.contracts", CashSchemaV1::class.packageName, DummyLinearStateSchemaV1::class.packageName, - SampleCashSchemaV3::class.packageName) + SampleCashSchemaV3::class.packageName, + VaultQueryTestsBase.MyContractClass::class.packageName) override lateinit var services: MockServices override lateinit var vaultFiller: VaultFiller @@ -253,6 +257,43 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } } + @Test + fun `query by interface for a contract class extending a parent contract class`() { + database.transaction { + + // build custom contract and store in vault + val me = services.myInfo.chooseIdentity() + val state = MyState("myState", listOf(me)) + val stateAndContract = StateAndContract(state, MYCONTRACT_ID) + val utx = TransactionBuilder(notary = notaryServices.myInfo.singleIdentity()).withItems(stateAndContract).withItems(dummyCommand()) + services.recordTransactions(services.signInitialTransaction(utx)) + + // query vault by Child class + val criteria = VaultQueryCriteria() // default is UNCONSUMED + val queryByMyState = vaultService.queryBy(criteria) + assertThat(queryByMyState.states).hasSize(1) + + // query vault by Parent class + val queryByBaseState = vaultService.queryBy(criteria) + assertThat(queryByBaseState.states).hasSize(1) + + // query vault by extended Contract Interface + val queryByContract = vaultService.queryBy(criteria) + assertThat(queryByContract.states).hasSize(1) + } + } + + // Beware: do not use `MyContractClass::class.qualifiedName` as this returns a fully qualified name using "dot" notation for enclosed class + val MYCONTRACT_ID = "net.corda.node.services.vault.VaultQueryTestsBase\$MyContractClass" + + open class MyContractClass : Contract { + override fun verify(tx: LedgerTransaction) {} + } + + interface MyContractInterface : ContractState + open class BaseState(override val participants: List = emptyList()) : MyContractInterface + data class MyState(val name: String, override val participants: List = emptyList()) : BaseState(participants) + @Test fun `unconsumed states simple`() { database.transaction {