diff --git a/core/src/test/kotlin/net/corda/core/flows/ReferencedStatesFlowTests.kt b/core/src/test/kotlin/net/corda/core/flows/ReferencedStatesFlowTests.kt index 5505fd3b3f..6c2401041b 100644 --- a/core/src/test/kotlin/net/corda/core/flows/ReferencedStatesFlowTests.kt +++ b/core/src/test/kotlin/net/corda/core/flows/ReferencedStatesFlowTests.kt @@ -131,9 +131,8 @@ class ReferencedStatesFlowTests { // Wait until node 1 stores the new tx. nodes[1].services.validatedTransactions.updates.filter { it.id == newTx.id }.toFuture().getOrThrow() // Check that nodes[1] has finished recording the transaction (and updating the vault.. hopefully!). - val allRefStates = nodes[1].services.vaultService.queryBy() - // nodes[1] should have two states. The newly created output and the reference state created by nodes[0]. - assertEquals(2, allRefStates.states.size) + // nodes[1] should have two states. The newly created output of type "Regular.State" and the reference state created by nodes[0]. + assertEquals(2, nodes[1].services.vaultService.queryBy().states.size) // Now let's find the specific reference state on nodes[1]. val refStateLinearId = newRefState.state.data.linearId val query = QueryCriteria.LinearStateQueryCriteria(linearId = listOf(refStateLinearId)) @@ -145,14 +144,27 @@ class ReferencedStatesFlowTests { val nodeZeroQuery = QueryCriteria.LinearStateQueryCriteria(linearId = listOf(refStateLinearId)) val theReferencedStateOnNodeZero = nodes[0].services.vaultService.queryBy(nodeZeroQuery) assertEquals(newRefState, theReferencedStateOnNodeZero.states.single()) - println(theReferencedStateOnNodeZero.statesMetadata.single()) // nodes[0] sends the tx that created the reference state to nodes[1]. nodes[0].services.startFlow(Initiator(newRefState)).resultFuture.getOrThrow() // Query again. val theReferencedStateAgain = nodes[1].services.vaultService.queryBy(query) // There should be one result - the reference state. assertEquals(newRefState, theReferencedStateAgain.states.single()) - println(theReferencedStateAgain.statesMetadata.single()) + } + + @Test + fun `check schema mappings are updated for reference states`() { + // 1. Create a state to be used as a reference state. Don't share it. + val newRefTx = nodes[0].services.startFlow(CreateRefState()).resultFuture.getOrThrow() + val newRefState = newRefTx.tx.outRefsOfType().single() + // 2. Use the "newRefState" a transaction involving another party (nodes[1]) which creates a new state. They should store the new state and the reference state. + val newTx = nodes[0].services.startFlow(UseRefState(nodes[1].info.legalIdentities.first(), newRefState.state.data.linearId)).resultFuture.getOrThrow() + // Wait until node 1 stores the new tx. + nodes[1].services.validatedTransactions.updates.filter { it.id == newTx.id }.toFuture().getOrThrow() + // Check that nodes[1] has finished recording the transaction (and updating the vault.. hopefully!). + val allRefStates = nodes[1].services.vaultService.queryBy() + // nodes[1] should have two states. The newly created output and the reference state created by nodes[0]. + assertEquals(2, allRefStates.states.size) } // A dummy reference state contract. @@ -172,6 +184,20 @@ class ReferencedStatesFlowTests { class Update : CommandData } + class RegularState : Contract { + companion object { + val CONTRACT_ID: String = RegularState::class.java.name + } + + override fun verify(tx: LedgerTransaction) = Unit + + data class State(val owner: Party, override val linearId: UniqueIdentifier = UniqueIdentifier()) : LinearState { + override val participants: List get() = listOf(owner) + } + + class Create : CommandData + } + // A flow to create a reference state. class CreateRefState : FlowLogic() { @Suspendable @@ -239,7 +265,7 @@ class ReferencedStatesFlowTests { val stx = serviceHub.signInitialTransaction(TransactionBuilder(notary = notary).apply { addReferenceState(referenceState.referenced()) - addOutputState(RefState.State(participant), RefState.CONTRACT_ID) + addOutputState(RegularState.State(participant), RefState.CONTRACT_ID) addCommand(RefState.Create(), listOf(ourIdentity.owningKey)) }) return if (participant != ourIdentity) { diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 67dd1d1a32..495b49daaf 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -81,12 +81,12 @@ class NodeVaultService( * Maintain a list of contract state interfaces to concrete types stored in the vault * for usage in generic queries of type queryBy or queryBy> */ - private val contractStateTypeMappings = mutableMapOf>() + private val contractStateTypeMappings = mutableMapOf>().toSynchronised() override fun start() { bootstrapContractStateTypes() rawUpdates.subscribe { update -> - update.produced.forEach { + (update.produced + update.references).forEach { val concreteType = it.state.data.javaClass log.trace { "State update of type: $concreteType" } val seen = contractStateTypeMappings.any { it.value.contains(concreteType.name) }