CORDA-1858 - Vault query fails to find a state if it extends from class (#3722)

* Included Contract State parent classes in list of queryable types.

* Added changelog entry.
This commit is contained in:
josecoll 2018-07-31 17:16:27 +01:00 committed by GitHub
parent b7f7dcc510
commit 8b501b1b80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 16 deletions

View File

@ -6,6 +6,8 @@ release, see :doc:`upgrade-notes`.
Unreleased Unreleased
---------- ----------
* Vault query fix: support query by parent classes of Contract State classes (see https://github.com/corda/corda/issues/3714)
* Added ``registerResponderFlow`` method to ``StartedMockNode``, to support isolated testing of responder flow behaviour. * Added ``registerResponderFlow`` method to ``StartedMockNode``, to support isolated testing of responder flow behaviour.
* "app", "rpc", "p2p" and "unknown" are no longer allowed as uploader values when importing attachments. These are used * "app", "rpc", "p2p" and "unknown" are no longer allowed as uploader values when importing attachments. These are used

View File

@ -85,10 +85,10 @@ class NodeVaultService(
log.trace { "State update of type: $concreteType" } log.trace { "State update of type: $concreteType" }
val seen = contractStateTypeMappings.any { it.value.contains(concreteType.name) } val seen = contractStateTypeMappings.any { it.value.contains(concreteType.name) }
if (!seen) { if (!seen) {
val contractInterfaces = deriveContractInterfaces(concreteType) val contractTypes = deriveContractTypes(concreteType)
contractInterfaces.map { contractTypes.map {
val contractInterface = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() } val contractStateType = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() }
contractInterface.add(concreteType.name) contractStateType.add(concreteType.name)
} }
} }
} }
@ -532,10 +532,10 @@ class NodeVaultService(
null null
} }
concreteType?.let { concreteType?.let {
val contractInterfaces = deriveContractInterfaces(it) val contractTypes = deriveContractTypes(it)
contractInterfaces.map { contractTypes.map {
val contractInterface = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() } val contractStateType = contractStateTypeMappings.getOrPut(it.name) { mutableSetOf() }
contractInterface.add(it.name) contractStateType.add(it.name)
} }
} }
} }
@ -544,14 +544,20 @@ class NodeVaultService(
} }
} }
private fun <T : ContractState> deriveContractInterfaces(clazz: Class<T>): Set<Class<T>> { private fun <T : ContractState> deriveContractTypes(clazz: Class<T>): Set<Class<T>> {
val myInterfaces: MutableSet<Class<T>> = mutableSetOf() val myTypes : MutableSet<Class<T>> = mutableSetOf()
clazz.interfaces.forEach { clazz.superclass?.let {
if (it != ContractState::class.java) { if (!it.isInstance(Any::class)) {
myInterfaces.add(uncheckedCast(it)) myTypes.add(uncheckedCast(it))
myInterfaces.addAll(deriveContractInterfaces(uncheckedCast(it))) myTypes.addAll(deriveContractTypes(uncheckedCast(it)))
} }
} }
return myInterfaces clazz.interfaces.forEach {
if (it != ContractState::class.java) {
myTypes.add(uncheckedCast(it))
myTypes.addAll(deriveContractTypes(uncheckedCast(it)))
}
}
return myTypes
} }
} }

View File

@ -2,12 +2,14 @@ package net.corda.node.services.vault
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.packageName import net.corda.core.internal.packageName
import net.corda.core.node.services.* import net.corda.core.node.services.*
import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.*
import net.corda.core.node.services.vault.QueryCriteria.* import net.corda.core.node.services.vault.QueryCriteria.*
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.* import net.corda.core.utilities.*
import net.corda.finance.* import net.corda.finance.*
@ -27,6 +29,7 @@ import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.testing.core.* import net.corda.testing.core.*
import net.corda.testing.internal.TEST_TX_TIME import net.corda.testing.internal.TEST_TX_TIME
import net.corda.testing.internal.chooseIdentity
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.vault.* import net.corda.testing.internal.vault.*
import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices
@ -115,7 +118,8 @@ open class VaultQueryTestRule : ExternalResource(), VaultQueryParties {
"net.corda.finance.contracts", "net.corda.finance.contracts",
CashSchemaV1::class.packageName, CashSchemaV1::class.packageName,
DummyLinearStateSchemaV1::class.packageName, DummyLinearStateSchemaV1::class.packageName,
SampleCashSchemaV3::class.packageName) SampleCashSchemaV3::class.packageName,
VaultQueryTestsBase.MyContractClass::class.packageName)
override lateinit var services: MockServices override lateinit var services: MockServices
override lateinit var vaultFiller: VaultFiller override lateinit var vaultFiller: VaultFiller
@ -253,6 +257,43 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
} }
} }
@Test
fun `query by interface for a contract class extending a parent contract class`() {
database.transaction {
// build custom contract and store in vault
val me = services.myInfo.chooseIdentity()
val state = MyState("myState", listOf(me))
val stateAndContract = StateAndContract(state, MYCONTRACT_ID)
val utx = TransactionBuilder(notary = notaryServices.myInfo.singleIdentity()).withItems(stateAndContract).withItems(dummyCommand())
services.recordTransactions(services.signInitialTransaction(utx))
// query vault by Child class
val criteria = VaultQueryCriteria() // default is UNCONSUMED
val queryByMyState = vaultService.queryBy<MyState>(criteria)
assertThat(queryByMyState.states).hasSize(1)
// query vault by Parent class
val queryByBaseState = vaultService.queryBy<BaseState>(criteria)
assertThat(queryByBaseState.states).hasSize(1)
// query vault by extended Contract Interface
val queryByContract = vaultService.queryBy<MyContractInterface>(criteria)
assertThat(queryByContract.states).hasSize(1)
}
}
// Beware: do not use `MyContractClass::class.qualifiedName` as this returns a fully qualified name using "dot" notation for enclosed class
val MYCONTRACT_ID = "net.corda.node.services.vault.VaultQueryTestsBase\$MyContractClass"
open class MyContractClass : Contract {
override fun verify(tx: LedgerTransaction) {}
}
interface MyContractInterface : ContractState
open class BaseState(override val participants: List<AbstractParty> = emptyList()) : MyContractInterface
data class MyState(val name: String, override val participants: List<AbstractParty> = emptyList()) : BaseState(participants)
@Test @Test
fun `unconsumed states simple`() { fun `unconsumed states simple`() {
database.transaction { database.transaction {