ENT-10013: Vault service refactoring backport

This commit is contained in:
Shams Asari 2023-06-06 16:16:59 +01:00
parent a817218b08
commit 0bfce451ea
15 changed files with 253 additions and 161 deletions

View File

@ -1,3 +1,5 @@
@file:Suppress("LongParameterList")
package net.corda.core.node.services package net.corda.core.node.services
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
@ -197,8 +199,7 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
* 4) Status types used in this query: [StateStatus.UNCONSUMED], [StateStatus.CONSUMED], [StateStatus.ALL]. * 4) Status types used in this query: [StateStatus.UNCONSUMED], [StateStatus.CONSUMED], [StateStatus.ALL].
* 5) Other results as a [List] of any type (eg. aggregate function results with/without group by). * 5) Other results as a [List] of any type (eg. aggregate function results with/without group by).
* *
* Note: currently otherResults are used only for Aggregate Functions (in which case, the states and statesMetadata * Note: currently [otherResults] is used only for aggregate functions (in which case, [states] and [statesMetadata] will be empty).
* results will be empty).
*/ */
@CordaSerializable @CordaSerializable
data class Page<out T : ContractState>(val states: List<StateAndRef<T>>, data class Page<out T : ContractState>(val states: List<StateAndRef<T>>,
@ -213,11 +214,11 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
val contractStateClassName: String, val contractStateClassName: String,
val recordedTime: Instant, val recordedTime: Instant,
val consumedTime: Instant?, val consumedTime: Instant?,
val status: Vault.StateStatus, val status: StateStatus,
val notary: AbstractParty?, val notary: AbstractParty?,
val lockId: String?, val lockId: String?,
val lockUpdateTime: Instant?, val lockUpdateTime: Instant?,
val relevancyStatus: Vault.RelevancyStatus? = null, val relevancyStatus: RelevancyStatus? = null,
val constraintInfo: ConstraintInfo? = null val constraintInfo: ConstraintInfo? = null
) { ) {
fun copy( fun copy(
@ -225,7 +226,7 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
contractStateClassName: String = this.contractStateClassName, contractStateClassName: String = this.contractStateClassName,
recordedTime: Instant = this.recordedTime, recordedTime: Instant = this.recordedTime,
consumedTime: Instant? = this.consumedTime, consumedTime: Instant? = this.consumedTime,
status: Vault.StateStatus = this.status, status: StateStatus = this.status,
notary: AbstractParty? = this.notary, notary: AbstractParty? = this.notary,
lockId: String? = this.lockId, lockId: String? = this.lockId,
lockUpdateTime: Instant? = this.lockUpdateTime lockUpdateTime: Instant? = this.lockUpdateTime
@ -237,11 +238,11 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
contractStateClassName: String = this.contractStateClassName, contractStateClassName: String = this.contractStateClassName,
recordedTime: Instant = this.recordedTime, recordedTime: Instant = this.recordedTime,
consumedTime: Instant? = this.consumedTime, consumedTime: Instant? = this.consumedTime,
status: Vault.StateStatus = this.status, status: StateStatus = this.status,
notary: AbstractParty? = this.notary, notary: AbstractParty? = this.notary,
lockId: String? = this.lockId, lockId: String? = this.lockId,
lockUpdateTime: Instant? = this.lockUpdateTime, lockUpdateTime: Instant? = this.lockUpdateTime,
relevancyStatus: Vault.RelevancyStatus? relevancyStatus: RelevancyStatus?
): StateMetadata { ): StateMetadata {
return StateMetadata(ref, contractStateClassName, recordedTime, consumedTime, status, notary, lockId, lockUpdateTime, relevancyStatus, ConstraintInfo(AlwaysAcceptAttachmentConstraint)) return StateMetadata(ref, contractStateClassName, recordedTime, consumedTime, status, notary, lockId, lockUpdateTime, relevancyStatus, ConstraintInfo(AlwaysAcceptAttachmentConstraint))
} }
@ -249,9 +250,9 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
companion object { companion object {
@Deprecated("No longer used. The vault does not emit empty updates") @Deprecated("No longer used. The vault does not emit empty updates")
val NoUpdate = Update(emptySet(), emptySet(), type = Vault.UpdateType.GENERAL, references = emptySet()) val NoUpdate = Update(emptySet(), emptySet(), type = UpdateType.GENERAL, references = emptySet())
@Deprecated("No longer used. The vault does not emit empty updates") @Deprecated("No longer used. The vault does not emit empty updates")
val NoNotaryUpdate = Vault.Update(emptySet(), emptySet(), type = Vault.UpdateType.NOTARY_CHANGE, references = emptySet()) val NoNotaryUpdate = Update(emptySet(), emptySet(), type = UpdateType.NOTARY_CHANGE, references = emptySet())
} }
} }
@ -302,7 +303,7 @@ interface VaultService {
fun whenConsumed(ref: StateRef): CordaFuture<Vault.Update<ContractState>> { fun whenConsumed(ref: StateRef): CordaFuture<Vault.Update<ContractState>> {
val query = QueryCriteria.VaultQueryCriteria( val query = QueryCriteria.VaultQueryCriteria(
stateRefs = listOf(ref), stateRefs = listOf(ref),
status = Vault.StateStatus.CONSUMED status = StateStatus.CONSUMED
) )
val result = trackBy<ContractState>(query) val result = trackBy<ContractState>(query)
val snapshot = result.snapshot.states val snapshot = result.snapshot.states
@ -358,8 +359,8 @@ interface VaultService {
/** /**
* Helper function to determine spendable states and soft locking them. * Helper function to determine spendable states and soft locking them.
* Currently performance will be worse than for the hand optimised version in * Currently performance will be worse than for the hand optimised version in
* [Cash.unconsumedCashStatesForSpending]. However, this is fully generic and can operate with custom [FungibleState] * [net.corda.finance.workflows.asset.selection.AbstractCashSelection.unconsumedCashStatesForSpending]. However, this is fully generic
* and [FungibleAsset] states. * and can operate with custom [FungibleState] and [FungibleAsset] states.
* @param lockId The [FlowLogic.runId]'s [UUID] of the current flow used to soft lock the states. * @param lockId The [FlowLogic.runId]'s [UUID] of the current flow used to soft lock the states.
* @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the * @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the
* [contractStateType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the * [contractStateType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the

View File

@ -1077,7 +1077,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
networkParameters: NetworkParameters) networkParameters: NetworkParameters)
protected open fun makeVaultService(keyManagementService: KeyManagementService, protected open fun makeVaultService(keyManagementService: KeyManagementService,
services: ServicesForResolution, services: NodeServicesForResolution,
database: CordaPersistence, database: CordaPersistence,
cordappLoader: CordappLoader): VaultServiceInternal { cordappLoader: CordappLoader): VaultServiceInternal {
return NodeVaultService(platformClock, keyManagementService, services, database, schemaService, cordappLoader.appClassLoader) return NodeVaultService(platformClock, keyManagementService, services, database, schemaService, cordappLoader.appClassLoader)

View File

@ -0,0 +1,15 @@
package net.corda.node.internal
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionResolutionException
import net.corda.core.node.ServicesForResolution
import java.util.LinkedHashSet
interface NodeServicesForResolution : ServicesForResolution {
@Throws(TransactionResolutionException::class)
override fun loadStates(stateRefs: Set<StateRef>): Set<StateAndRef<ContractState>> = loadStates(stateRefs, LinkedHashSet())
fun <T : ContractState, C : MutableCollection<StateAndRef<T>>> loadStates(input: Iterable<StateRef>, output: C): C
}

View File

@ -1,11 +1,18 @@
package net.corda.node.internal package net.corda.node.internal
import net.corda.core.contracts.* import net.corda.core.contracts.Attachment
import net.corda.core.contracts.AttachmentResolutionException
import net.corda.core.contracts.ContractAttachment
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionResolutionException
import net.corda.core.contracts.TransactionState
import net.corda.core.cordapp.CordappProvider import net.corda.core.cordapp.CordappProvider
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.internal.SerializedStateAndRef import net.corda.core.internal.SerializedStateAndRef
import net.corda.core.internal.uncheckedCast
import net.corda.core.node.NetworkParameters import net.corda.core.node.NetworkParameters
import net.corda.core.node.ServicesForResolution
import net.corda.core.node.services.AttachmentStorage import net.corda.core.node.services.AttachmentStorage
import net.corda.core.node.services.IdentityService import net.corda.core.node.services.IdentityService
import net.corda.core.node.services.NetworkParametersService import net.corda.core.node.services.NetworkParametersService
@ -23,7 +30,7 @@ data class ServicesForResolutionImpl(
override val cordappProvider: CordappProvider, override val cordappProvider: CordappProvider,
override val networkParametersService: NetworkParametersService, override val networkParametersService: NetworkParametersService,
private val validatedTransactions: TransactionStorage private val validatedTransactions: TransactionStorage
) : ServicesForResolution { ) : NodeServicesForResolution {
override val networkParameters: NetworkParameters get() = networkParametersService.lookup(networkParametersService.currentHash) ?: override val networkParameters: NetworkParameters get() = networkParametersService.lookup(networkParametersService.currentHash) ?:
throw IllegalArgumentException("No current parameters in network parameters storage") throw IllegalArgumentException("No current parameters in network parameters storage")
@ -32,12 +39,11 @@ data class ServicesForResolutionImpl(
return toBaseTransaction(stateRef.txhash).outputs[stateRef.index] return toBaseTransaction(stateRef.txhash).outputs[stateRef.index]
} }
@Throws(TransactionResolutionException::class) override fun <T : ContractState, C : MutableCollection<StateAndRef<T>>> loadStates(input: Iterable<StateRef>, output: C): C {
override fun loadStates(stateRefs: Set<StateRef>): Set<StateAndRef<ContractState>> {
val baseTxs = HashMap<SecureHash, BaseTransaction>() val baseTxs = HashMap<SecureHash, BaseTransaction>()
return stateRefs.mapTo(LinkedHashSet()) { stateRef -> return input.mapTo(output) { stateRef ->
val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction) val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction)
StateAndRef(baseTx.outputs[stateRef.index], stateRef) StateAndRef(uncheckedCast(baseTx.outputs[stateRef.index]), stateRef)
} }
} }

View File

@ -2,7 +2,6 @@ package net.corda.node.migration
import liquibase.database.Database import liquibase.database.Database
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.node.services.Vault import net.corda.core.node.services.Vault
import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.MappedSchema
@ -18,6 +17,7 @@ import net.corda.node.services.persistence.DBTransactionStorage
import net.corda.node.services.persistence.NodeAttachmentService import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1 import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.services.vault.toStateRef
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.nodeapi.internal.persistence.SchemaMigration import net.corda.nodeapi.internal.persistence.SchemaMigration
@ -61,8 +61,7 @@ class VaultStateMigration : CordaMigration() {
private fun getStateAndRef(persistentState: VaultSchemaV1.VaultStates): StateAndRef<ContractState> { private fun getStateAndRef(persistentState: VaultSchemaV1.VaultStates): StateAndRef<ContractState> {
val persistentStateRef = persistentState.stateRef ?: val persistentStateRef = persistentState.stateRef ?:
throw VaultStateMigrationException("Persistent state ref missing from state") throw VaultStateMigrationException("Persistent state ref missing from state")
val txHash = SecureHash.create(persistentStateRef.txId) val stateRef = persistentStateRef.toStateRef()
val stateRef = StateRef(txHash, persistentStateRef.index)
val state = try { val state = try {
servicesForResolution.loadState(stateRef) servicesForResolution.loadState(stateRef)
} catch (e: Exception) { } catch (e: Exception) {

View File

@ -2,8 +2,8 @@ package net.corda.node.services.events
import net.corda.core.contracts.ScheduledStateRef import net.corda.core.contracts.ScheduledStateRef
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.crypto.SecureHash
import net.corda.core.schemas.PersistentStateRef import net.corda.core.schemas.PersistentStateRef
import net.corda.node.services.vault.toStateRef
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
interface ScheduledFlowRepository { interface ScheduledFlowRepository {
@ -25,9 +25,8 @@ class PersistentScheduledFlowRepository(val database: CordaPersistence) : Schedu
} }
private fun fromPersistentEntity(scheduledStateRecord: NodeSchedulerService.PersistentScheduledState): Pair<StateRef, ScheduledStateRef> { private fun fromPersistentEntity(scheduledStateRecord: NodeSchedulerService.PersistentScheduledState): Pair<StateRef, ScheduledStateRef> {
val txId = scheduledStateRecord.output.txId val stateRef = scheduledStateRecord.output.toStateRef()
val index = scheduledStateRecord.output.index return Pair(stateRef, ScheduledStateRef(stateRef, scheduledStateRecord.scheduledAt))
return Pair(StateRef(SecureHash.create(txId), index), ScheduledStateRef(StateRef(SecureHash.create(txId), index), scheduledStateRecord.scheduledAt))
} }
override fun delete(key: StateRef): Boolean { override fun delete(key: StateRef): Boolean {

View File

@ -25,6 +25,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.node.services.vault.toStateRef
import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
@ -157,13 +158,7 @@ class PersistentUniquenessProvider(val clock: Clock, val database: CordaPersiste
toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) }, toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) },
fromPersistentEntity = { fromPersistentEntity = {
//TODO null check will become obsolete after making DB/JPA columns not nullable //TODO null check will become obsolete after making DB/JPA columns not nullable
val txId = it.id.txId Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash))
val index = it.id.index
Pair(
StateRef(txhash = SecureHash.create(txId), index = index),
SecureHash.create(it.consumingTxHash)
)
}, },
toPersistentEntity = { (txHash, index): StateRef, id: SecureHash -> toPersistentEntity = { (txHash, index): StateRef, id: SecureHash ->
CommittedState( CommittedState(

View File

@ -3,28 +3,65 @@ package net.corda.node.services.vault
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import net.corda.core.CordaRuntimeException import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.* import net.corda.core.contracts.Amount
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.FungibleState
import net.corda.core.contracts.Issued
import net.corda.core.contracts.OwnableState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.containsAny import net.corda.core.crypto.containsAny
import net.corda.core.flows.HospitalizeFlowException import net.corda.core.flows.HospitalizeFlowException
import net.corda.core.internal.* import net.corda.core.internal.ThreadBox
import net.corda.core.internal.TransactionDeserialisationException
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.tee
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.node.ServicesForResolution
import net.corda.core.node.StatesToRecord import net.corda.core.node.StatesToRecord
import net.corda.core.node.services.* import net.corda.core.node.services.KeyManagementService
import net.corda.core.node.services.Vault.ConstraintInfo.Companion.constraintInfo import net.corda.core.node.services.StatesNotAvailableException
import net.corda.core.node.services.vault.* import net.corda.core.node.services.Vault
import net.corda.core.node.services.VaultQueryException
import net.corda.core.node.services.VaultService
import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.DEFAULT_PAGE_NUM
import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE
import net.corda.core.node.services.vault.PageSpecification
import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.Sort
import net.corda.core.node.services.vault.SortAttribute
import net.corda.core.node.services.vault.builder
import net.corda.core.observable.internal.OnResilientSubscribe import net.corda.core.observable.internal.OnResilientSubscribe
import net.corda.core.schemas.PersistentStateRef import net.corda.core.schemas.PersistentStateRef
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.* import net.corda.core.transactions.ContractUpgradeWireTransaction
import net.corda.core.utilities.* import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.FullTransaction
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.toNonEmptySet
import net.corda.core.utilities.trace
import net.corda.node.internal.NodeServicesForResolution
import net.corda.node.services.api.SchemaService import net.corda.node.services.api.SchemaService
import net.corda.node.services.api.VaultServiceInternal import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.schema.PersistentStateService import net.corda.node.services.schema.PersistentStateService
import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.nodeapi.internal.persistence.* import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import net.corda.nodeapi.internal.persistence.currentDBSession
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import org.hibernate.Session import org.hibernate.Session
import org.hibernate.query.Query
import rx.Observable import rx.Observable
import rx.exceptions.OnErrorNotImplementedException import rx.exceptions.OnErrorNotImplementedException
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
@ -32,9 +69,11 @@ import java.security.PublicKey
import java.sql.SQLException import java.sql.SQLException
import java.time.Clock import java.time.Clock
import java.time.Instant import java.time.Instant
import java.util.* import java.util.Arrays
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArraySet import java.util.concurrent.CopyOnWriteArraySet
import java.util.stream.Stream
import javax.persistence.PersistenceException import javax.persistence.PersistenceException
import javax.persistence.Tuple import javax.persistence.Tuple
import javax.persistence.criteria.CriteriaBuilder import javax.persistence.criteria.CriteriaBuilder
@ -54,9 +93,9 @@ import javax.persistence.criteria.Root
class NodeVaultService( class NodeVaultService(
private val clock: Clock, private val clock: Clock,
private val keyManagementService: KeyManagementService, private val keyManagementService: KeyManagementService,
private val servicesForResolution: ServicesForResolution, private val servicesForResolution: NodeServicesForResolution,
private val database: CordaPersistence, private val database: CordaPersistence,
private val schemaService: SchemaService, schemaService: SchemaService,
private val appClassloader: ClassLoader private val appClassloader: ClassLoader
) : SingletonSerializeAsToken(), VaultServiceInternal { ) : SingletonSerializeAsToken(), VaultServiceInternal {
companion object { companion object {
@ -196,7 +235,7 @@ class NodeVaultService(
if (lockId != null) { if (lockId != null) {
lockId = null lockId = null
lockUpdateTime = clock.instant() lockUpdateTime = clock.instant()
log.trace("Releasing soft lock on consumed state: $stateRef") log.trace { "Releasing soft lock on consumed state: $stateRef" }
} }
session.save(state) session.save(state)
} }
@ -227,7 +266,7 @@ class NodeVaultService(
} }
// we are not inside a flow, we are most likely inside a CordaService; // we are not inside a flow, we are most likely inside a CordaService;
// we will expose, by default, subscribing of -non unsubscribing- rx.Observers to rawUpdates. // we will expose, by default, subscribing of -non unsubscribing- rx.Observers to rawUpdates.
return _rawUpdatesPublisher.resilientOnError() _rawUpdatesPublisher.resilientOnError()
} }
override val updates: Observable<Vault.Update<ContractState>> override val updates: Observable<Vault.Update<ContractState>>
@ -639,7 +678,23 @@ class NodeVaultService(
@Throws(VaultQueryException::class) @Throws(VaultQueryException::class)
override fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): Vault.Page<T> { override fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): Vault.Page<T> {
try { try {
return _queryBy(criteria, paging, sorting, contractStateType, false) // We decrement by one if the client requests MAX_VALUE, assuming they can not notice this because they don't have enough memory
// to request MAX_VALUE states at once.
val validPaging = if (paging.pageSize == Integer.MAX_VALUE) {
paging.copy(pageSize = Integer.MAX_VALUE - 1)
} else {
checkVaultQuery(paging.pageSize >= 1) { "Page specification: invalid page size ${paging.pageSize} [minimum is 1]" }
paging
}
if (!validPaging.isDefault) {
checkVaultQuery(validPaging.pageNumber >= DEFAULT_PAGE_NUM) {
"Page specification: invalid page number ${validPaging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]"
}
}
log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $validPaging, sorting: $sorting" }
return database.transaction {
queryBy(criteria, validPaging, sorting, contractStateType)
}
} catch (e: VaultQueryException) { } catch (e: VaultQueryException) {
throw e throw e
} catch (e: Exception) { } catch (e: Exception) {
@ -647,100 +702,90 @@ class NodeVaultService(
} }
} }
@Throws(VaultQueryException::class) private fun <T : ContractState> queryBy(criteria: QueryCriteria,
private fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging_: PageSpecification, sorting: Sort, contractStateType: Class<out T>, skipPagingChecks: Boolean): Vault.Page<T> { paging: PageSpecification,
// We decrement by one if the client requests MAX_PAGE_SIZE, assuming they can not notice this because they don't have enough memory sorting: Sort,
// to request `MAX_PAGE_SIZE` states at once. contractStateType: Class<out T>): Vault.Page<T> {
val paging = if (paging_.pageSize == Integer.MAX_VALUE) { // calculate total results where a page specification has been defined
paging_.copy(pageSize = Integer.MAX_VALUE - 1) val totalStatesAvailable = if (paging.isDefault) -1 else queryTotalStateCount(criteria, contractStateType)
} else {
paging_ val (query, stateTypes) = createQuery(criteria, contractStateType, sorting)
query.setResultWindow(paging)
val statesMetadata: MutableList<Vault.StateMetadata> = mutableListOf()
val otherResults: MutableList<Any> = mutableListOf()
query.resultStream(paging).use { results ->
results.forEach { result ->
val result0 = result[0]
if (result0 is VaultSchemaV1.VaultStates) {
statesMetadata.add(result0.toStateMetadata())
} else {
log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" }
otherResults.addAll(result.toArray().asList())
}
}
} }
log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $paging, sorting: $sorting" }
return database.transaction {
// calculate total results where a page specification has been defined
var totalStates = -1L
if (!skipPagingChecks && !paging.isDefault) {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val results = _queryBy(criteria.and(countCriteria), PageSpecification(), Sort(emptyList()), contractStateType, true) // only skip pagination checks for total results count query
totalStates = results.otherResults.last() as Long
}
val session = getSession() val states: List<StateAndRef<T>> = servicesForResolution.loadStates(
statesMetadata.mapTo(LinkedHashSet()) { it.ref },
ArrayList()
)
val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java) return Vault.Page(states, statesMetadata, totalStatesAvailable, stateTypes, otherResults)
val queryRootVaultStates = criteriaQuery.from(VaultSchemaV1.VaultStates::class.java) }
// TODO: revisit (use single instance of parser for all queries)
val criteriaParser = HibernateQueryCriteriaParser(contractStateType, contractStateTypeMappings, criteriaBuilder, criteriaQuery, queryRootVaultStates)
// parse criteria and build where predicates
criteriaParser.parse(criteria, sorting)
// prepare query for execution
val query = session.createQuery(criteriaQuery)
// pagination checks
if (!skipPagingChecks && !paging.isDefault) {
// pagination
if (paging.pageNumber < DEFAULT_PAGE_NUM) throw VaultQueryException("Page specification: invalid page number ${paging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]")
if (paging.pageSize < 1) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [minimum is 1]")
if (paging.pageSize > MAX_PAGE_SIZE) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [maximum is $MAX_PAGE_SIZE]")
}
// For both SQLServer and PostgresSQL, firstResult must be >= 0. So we set a floor at 0.
// TODO: This is a catch-all solution. But why is the default pageNumber set to be -1 in the first place?
// Even if we set the default pageNumber to be 1 instead, that may not cover the non-default cases.
// So the floor may be necessary anyway.
query.firstResult = maxOf(0, (paging.pageNumber - 1) * paging.pageSize)
val pageSize = paging.pageSize + 1
query.maxResults = if (pageSize > 0) pageSize else Integer.MAX_VALUE // detection too many results, protected against overflow
// execution
val results = query.resultList
private fun <R> Query<R>.resultStream(paging: PageSpecification): Stream<R> {
return if (paging.isDefault) {
val allResults = resultList
// final pagination check (fail-fast on too many results when no pagination specified) // final pagination check (fail-fast on too many results when no pagination specified)
if (!skipPagingChecks && paging.isDefault && results.size > DEFAULT_PAGE_SIZE) { checkVaultQuery(allResults.size != paging.pageSize + 1) {
throw VaultQueryException("There are ${results.size} results, which exceeds the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. In order to retrieve these results, provide a `PageSpecification(pageNumber, pageSize)` to the method invoked.") "There are more results than the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. " +
"In order to retrieve these results, provide a PageSpecification to the method invoked."
} }
val statesAndRefs: MutableList<StateAndRef<T>> = mutableListOf() allResults.stream()
val statesMeta: MutableList<Vault.StateMetadata> = mutableListOf() } else {
val otherResults: MutableList<Any> = mutableListOf() stream()
val stateRefs = mutableSetOf<StateRef>()
results.asSequence()
.forEachIndexed { index, result ->
if (result[0] is VaultSchemaV1.VaultStates) {
if (!paging.isDefault && index == paging.pageSize) // skip last result if paged
return@forEachIndexed
val vaultState = result[0] as VaultSchemaV1.VaultStates
val stateRef = StateRef(SecureHash.create(vaultState.stateRef!!.txId), vaultState.stateRef!!.index)
stateRefs.add(stateRef)
statesMeta.add(Vault.StateMetadata(stateRef,
vaultState.contractStateClassName,
vaultState.recordedTime,
vaultState.consumedTime,
vaultState.stateStatus,
vaultState.notary,
vaultState.lockId,
vaultState.lockUpdateTime,
vaultState.relevancyStatus,
constraintInfo(vaultState.constraintType, vaultState.constraintData)
))
} else {
// TODO: improve typing of returned other results
log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" }
otherResults.addAll(result.toArray().asList())
}
}
if (stateRefs.isNotEmpty())
statesAndRefs.addAll(uncheckedCast(servicesForResolution.loadStates(stateRefs)))
Vault.Page(states = statesAndRefs, statesMetadata = statesMeta, stateTypes = criteriaParser.stateTypes, totalStatesAvailable = totalStates, otherResults = otherResults)
} }
} }
private fun Query<*>.setResultWindow(paging: PageSpecification) {
if (paging.isDefault) {
// For both SQLServer and PostgresSQL, firstResult must be >= 0.
firstResult = 0
// Peek ahead and see if there are more results in case pagination should be done
maxResults = paging.pageSize + 1
} else {
firstResult = (paging.pageNumber - 1) * paging.pageSize
maxResults = paging.pageSize
}
}
private fun <T : ContractState> queryTotalStateCount(baseCriteria: QueryCriteria, contractStateType: Class<out T>): Long {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val criteria = baseCriteria.and(countCriteria)
val (query) = createQuery(criteria, contractStateType, null)
val results = query.resultList
return results.last().toArray().last() as Long
}
private fun <T : ContractState> createQuery(criteria: QueryCriteria,
contractStateType: Class<out T>,
sorting: Sort?): Pair<Query<Tuple>, Vault.StateStatus> {
val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java)
val criteriaParser = HibernateQueryCriteriaParser(
contractStateType,
contractStateTypeMappings,
criteriaBuilder,
criteriaQuery,
criteriaQuery.from(VaultSchemaV1.VaultStates::class.java)
)
criteriaParser.parse(criteria, sorting)
val query = getSession().createQuery(criteriaQuery)
return Pair(query, criteriaParser.stateTypes)
}
/** /**
* Returns a [DataFeed] containing the results of the provided query, along with the associated observable, containing any subsequent updates. * Returns a [DataFeed] containing the results of the provided query, along with the associated observable, containing any subsequent updates.
* *
@ -775,6 +820,12 @@ class NodeVaultService(
} }
} }
private inline fun checkVaultQuery(value: Boolean, lazyMessage: () -> Any) {
if (!value) {
throw VaultQueryException(lazyMessage().toString())
}
}
private fun <T : ContractState> filterContractStates(update: Vault.Update<T>, contractStateType: Class<out T>) = private fun <T : ContractState> filterContractStates(update: Vault.Update<T>, contractStateType: Class<out T>) =
update.copy(consumed = filterByContractState(contractStateType, update.consumed), update.copy(consumed = filterByContractState(contractStateType, update.consumed),
produced = filterByContractState(contractStateType, update.produced)) produced = filterByContractState(contractStateType, update.produced))
@ -802,6 +853,7 @@ class NodeVaultService(
} }
private fun getSession() = database.currentOrNew().session private fun getSession() = database.currentOrNew().session
/** /**
* Derive list from existing vault states and then incrementally update using vault observables * Derive list from existing vault states and then incrementally update using vault observables
*/ */

View File

@ -2,7 +2,9 @@ package net.corda.node.services.vault
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.contracts.MAX_ISSUER_REF_SIZE import net.corda.core.contracts.MAX_ISSUER_REF_SIZE
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.UniqueIdentifier import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.toStringShort import net.corda.core.crypto.toStringShort
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party import net.corda.core.identity.Party
@ -192,3 +194,19 @@ object VaultSchemaV1 : MappedSchema(
) : IndirectStatePersistable<PersistentStateRefAndKey> ) : IndirectStatePersistable<PersistentStateRefAndKey>
} }
fun PersistentStateRef.toStateRef(): StateRef = StateRef(SecureHash.create(txId), index)
fun VaultSchemaV1.VaultStates.toStateMetadata(): Vault.StateMetadata {
return Vault.StateMetadata(
stateRef!!.toStateRef(),
contractStateClassName,
recordedTime,
consumedTime,
stateStatus,
notary,
lockId,
lockUpdateTime,
relevancyStatus,
Vault.ConstraintInfo.constraintInfo(constraintType, constraintData)
)
}

View File

@ -21,6 +21,7 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.node.services.vault.toStateRef
import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import java.security.PublicKey import java.security.PublicKey
@ -41,6 +42,8 @@ class BFTSmartNotaryService(
) : NotaryService() { ) : NotaryService() {
companion object { companion object {
private val log = contextLogger() private val log = contextLogger()
@Suppress("unused") // Used by NotaryLoader via reflection
@JvmStatic @JvmStatic
val serializationFilter val serializationFilter
get() = { clazz: Class<*> -> get() = { clazz: Class<*> ->
@ -147,12 +150,7 @@ class BFTSmartNotaryService(
toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) }, toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) },
fromPersistentEntity = { fromPersistentEntity = {
//TODO null check will become obsolete after making DB/JPA columns not nullable //TODO null check will become obsolete after making DB/JPA columns not nullable
val txId = it.id.txId Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash))
val index = it.id.index
Pair(
StateRef(txhash = SecureHash.create(txId), index = index),
SecureHash.create(it.consumingTxHash)
)
}, },
toPersistentEntity = { (txHash, index): StateRef, id: SecureHash -> toPersistentEntity = { (txHash, index): StateRef, id: SecureHash ->
CommittedState( CommittedState(

View File

@ -24,6 +24,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.node.services.vault.toStateRef
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.notary.common.InternalResult import net.corda.notary.common.InternalResult
@ -142,10 +143,6 @@ class JPAUniquenessProvider(
fun encodeStateRef(s: StateRef): PersistentStateRef { fun encodeStateRef(s: StateRef): PersistentStateRef {
return PersistentStateRef(s.txhash.toString(), s.index) return PersistentStateRef(s.txhash.toString(), s.index)
} }
fun decodeStateRef(s: PersistentStateRef): StateRef {
return StateRef(txhash = SecureHash.create(s.txId), index = s.index)
}
} }
/** /**
@ -215,15 +212,15 @@ class JPAUniquenessProvider(
committedStates.addAll(existing) committedStates.addAll(existing)
} }
return committedStates.map { return committedStates.associate {
val stateRef = StateRef(txhash = SecureHash.create(it.id.txId), index = it.id.index) val stateRef = it.id.toStateRef()
val consumingTxId = SecureHash.create(it.consumingTxHash) val consumingTxId = SecureHash.create(it.consumingTxHash)
if (stateRef in references) { if (stateRef in references) {
stateRef to StateConsumptionDetails(consumingTxId.reHash(), type = StateConsumptionDetails.ConsumedStateType.REFERENCE_INPUT_STATE) stateRef to StateConsumptionDetails(consumingTxId.reHash(), type = StateConsumptionDetails.ConsumedStateType.REFERENCE_INPUT_STATE)
} else { } else {
stateRef to StateConsumptionDetails(consumingTxId.reHash()) stateRef to StateConsumptionDetails(consumingTxId.reHash())
} }
}.toMap() }
} }
private fun<T> withRetry(block: () -> T): T { private fun<T> withRetry(block: () -> T): T {

View File

@ -28,12 +28,14 @@ import net.corda.finance.schemas.CashSchemaV1
import net.corda.finance.test.SampleCashSchemaV1 import net.corda.finance.test.SampleCashSchemaV1
import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV2
import net.corda.finance.test.SampleCashSchemaV3 import net.corda.finance.test.SampleCashSchemaV3
import net.corda.node.internal.NodeServicesForResolution
import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.services.schema.ContractStateAndRef import net.corda.node.services.schema.ContractStateAndRef
import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.schema.PersistentStateService import net.corda.node.services.schema.PersistentStateService
import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1 import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.services.vault.toStateRef
import net.corda.node.testing.DummyFungibleContract import net.corda.node.testing.DummyFungibleContract
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
@ -48,7 +50,6 @@ import net.corda.testing.internal.vault.VaultFiller
import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.`in`
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.hibernate.SessionFactory import org.hibernate.SessionFactory
@ -122,7 +123,14 @@ class HibernateConfigurationTest {
services = object : MockServices(cordappPackages, BOB_NAME, mock<IdentityService>().also { services = object : MockServices(cordappPackages, BOB_NAME, mock<IdentityService>().also {
doReturn(null).whenever(it).verifyAndRegisterIdentity(argThat { name == BOB_NAME }) doReturn(null).whenever(it).verifyAndRegisterIdentity(argThat { name == BOB_NAME })
}, generateKeyPair(), dummyNotary.keyPair) { }, generateKeyPair(), dummyNotary.keyPair) {
override val vaultService = NodeVaultService(Clock.systemUTC(), keyManagementService, servicesForResolution, database, schemaService, cordappClassloader).apply { start() } override val vaultService = NodeVaultService(
Clock.systemUTC(),
keyManagementService,
servicesForResolution as NodeServicesForResolution,
database,
schemaService,
cordappClassloader
).apply { start() }
override fun recordTransactions(statesToRecord: StatesToRecord, txs: Iterable<SignedTransaction>) { override fun recordTransactions(statesToRecord: StatesToRecord, txs: Iterable<SignedTransaction>) {
for (stx in txs) { for (stx in txs) {
(validatedTransactions as WritableTransactionStorage).addTransaction(stx) (validatedTransactions as WritableTransactionStorage).addTransaction(stx)
@ -183,7 +191,7 @@ class HibernateConfigurationTest {
// execute query // execute query
val queryResults = entityManager.createQuery(criteriaQuery).resultList val queryResults = entityManager.createQuery(criteriaQuery).resultList
val coins = queryResults.map { val coins = queryResults.map {
services.loadState(toStateRef(it.stateRef!!)).data services.loadState(it.stateRef!!.toStateRef()).data
}.sumCash() }.sumCash()
assertThat(coins.toDecimal() >= BigDecimal("50.00")) assertThat(coins.toDecimal() >= BigDecimal("50.00"))
} }
@ -739,7 +747,7 @@ class HibernateConfigurationTest {
val queryResults = entityManager.createQuery(criteriaQuery).resultList val queryResults = entityManager.createQuery(criteriaQuery).resultList
queryResults.forEach { queryResults.forEach {
val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State
println("${it.stateRef} with owner: ${cashState.owner.owningKey.toBase58String()}") println("${it.stateRef} with owner: ${cashState.owner.owningKey.toBase58String()}")
} }
@ -823,7 +831,7 @@ class HibernateConfigurationTest {
// execute query // execute query
val queryResults = entityManager.createQuery(criteriaQuery).resultList val queryResults = entityManager.createQuery(criteriaQuery).resultList
queryResults.forEach { queryResults.forEach {
val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State
println("${it.stateRef} with owner ${cashState.owner.owningKey.toBase58String()} and participants ${cashState.participants.map { it.owningKey.toBase58String() }}") println("${it.stateRef} with owner ${cashState.owner.owningKey.toBase58String()} and participants ${cashState.participants.map { it.owningKey.toBase58String() }}")
} }
@ -961,10 +969,6 @@ class HibernateConfigurationTest {
} }
} }
private fun toStateRef(pStateRef: PersistentStateRef): StateRef {
return StateRef(SecureHash.create(pStateRef.txId), pStateRef.index)
}
@Test(timeout=300_000) @Test(timeout=300_000)
fun `schema change`() { fun `schema change`() {
fun createNewDB(schemas: Set<MappedSchema>, initialiseSchema: Boolean = true): CordaPersistence { fun createNewDB(schemas: Set<MappedSchema>, initialiseSchema: Boolean = true): CordaPersistence {

View File

@ -1674,7 +1674,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// pagination: last page // pagination: last page
@Test(timeout=300_000) @Test(timeout=300_000)
fun `all states with paging specification - last`() { fun `all states with paging specification - last`() {
database.transaction { database.transaction {
vaultFiller.fillWithSomeTestCash(95.DOLLARS, notaryServices, 95, DUMMY_CASH_ISSUER) vaultFiller.fillWithSomeTestCash(95.DOLLARS, notaryServices, 95, DUMMY_CASH_ISSUER)
// Last page implies we need to perform a row count for the Query first, // Last page implies we need to perform a row count for the Query first,
@ -1723,7 +1723,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
@Test(timeout=300_000) @Test(timeout=300_000)
fun `pagination not specified but more than default results available`() { fun `pagination not specified but more than default results available`() {
expectedEx.expect(VaultQueryException::class.java) expectedEx.expect(VaultQueryException::class.java)
expectedEx.expectMessage("provide a `PageSpecification(pageNumber, pageSize)`") expectedEx.expectMessage("provide a PageSpecification")
database.transaction { database.transaction {
vaultFiller.fillWithSomeTestCash(201.DOLLARS, notaryServices, 201, DUMMY_CASH_ISSUER) vaultFiller.fillWithSomeTestCash(201.DOLLARS, notaryServices, 201, DUMMY_CASH_ISSUER)

View File

@ -10,7 +10,6 @@ import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.node.ServicesForResolution
import net.corda.core.node.services.KeyManagementService import net.corda.core.node.services.KeyManagementService
import net.corda.core.node.services.queryBy import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.QueryCriteria.SoftLockingCondition import net.corda.core.node.services.vault.QueryCriteria.SoftLockingCondition
@ -29,6 +28,7 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.core.singleIdentity import net.corda.testing.core.singleIdentity
import net.corda.testing.flows.registerCoreFlowFactory import net.corda.testing.flows.registerCoreFlowFactory
import net.corda.coretesting.internal.rigorousMock import net.corda.coretesting.internal.rigorousMock
import net.corda.node.internal.NodeServicesForResolution
import net.corda.testing.node.internal.InternalMockNetwork import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.enclosedCordapp import net.corda.testing.node.internal.enclosedCordapp
import net.corda.testing.node.internal.startFlow import net.corda.testing.node.internal.startFlow
@ -86,7 +86,7 @@ class VaultSoftLockManagerTest {
private val mockNet = InternalMockNetwork(cordappsForAllNodes = listOf(enclosedCordapp()), defaultFactory = { args -> private val mockNet = InternalMockNetwork(cordappsForAllNodes = listOf(enclosedCordapp()), defaultFactory = { args ->
object : InternalMockNetwork.MockNode(args) { object : InternalMockNetwork.MockNode(args) {
override fun makeVaultService(keyManagementService: KeyManagementService, override fun makeVaultService(keyManagementService: KeyManagementService,
services: ServicesForResolution, services: NodeServicesForResolution,
database: CordaPersistence, database: CordaPersistence,
cordappLoader: CordappLoader): VaultServiceInternal { cordappLoader: CordappLoader): VaultServiceInternal {
val node = this val node = this

View File

@ -26,6 +26,7 @@ import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.internal.ServicesForResolutionImpl import net.corda.node.internal.ServicesForResolutionImpl
import net.corda.node.internal.NodeServicesForResolution
import net.corda.node.internal.cordapp.JarScanningCordappLoader import net.corda.node.internal.cordapp.JarScanningCordappLoader
import net.corda.node.services.api.* import net.corda.node.services.api.*
import net.corda.node.services.diagnostics.NodeDiagnosticsService import net.corda.node.services.diagnostics.NodeDiagnosticsService
@ -460,7 +461,14 @@ open class MockServices private constructor(
get() = ServicesForResolutionImpl(identityService, attachments, cordappProvider, networkParametersService, validatedTransactions) get() = ServicesForResolutionImpl(identityService, attachments, cordappProvider, networkParametersService, validatedTransactions)
internal fun makeVaultService(schemaService: SchemaService, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal { internal fun makeVaultService(schemaService: SchemaService, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal {
return NodeVaultService(clock, keyManagementService, servicesForResolution, database, schemaService, cordappLoader.appClassLoader).apply { start() } return NodeVaultService(
clock,
keyManagementService,
servicesForResolution as NodeServicesForResolution,
database,
schemaService,
cordappLoader.appClassLoader
).apply { start() }
} }
// This needs to be internal as MutableClassToInstanceMap is a guava type and shouldn't be part of our public API // This needs to be internal as MutableClassToInstanceMap is a guava type and shouldn't be part of our public API