CORDA-599 Fix circular dependency between vault and SH (#1630)

Fix circular dependency between the 2 vault objects and SH.
This commit is contained in:
Andrzej Cichocki
2017-10-09 12:49:07 +01:00
committed by Mike Hearn
parent 0c2289de8c
commit f83f1b7010
17 changed files with 148 additions and 122 deletions

View File

@ -7,10 +7,6 @@ import net.corda.core.crypto.SignableData
import net.corda.core.crypto.SignatureMetadata import net.corda.core.crypto.SignatureMetadata
import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.TransactionSignature
import net.corda.core.flows.ContractUpgradeFlow import net.corda.core.flows.ContractUpgradeFlow
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.AnonymousParty
import net.corda.core.identity.Party
import net.corda.core.internal.toMultiMap
import net.corda.core.node.services.* import net.corda.core.node.services.*
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsToken
import net.corda.core.transactions.FilteredTransaction import net.corda.core.transactions.FilteredTransaction
@ -20,11 +16,24 @@ import java.security.PublicKey
import java.sql.Connection import java.sql.Connection
import java.time.Clock import java.time.Clock
/**
* Part of [ServiceHub].
*/
interface StateLoader {
/**
* Given a [StateRef] loads the referenced transaction and looks up the specified output [ContractState].
*
* @throws TransactionResolutionException if [stateRef] points to a non-existent transaction.
*/
@Throws(TransactionResolutionException::class)
fun loadState(stateRef: StateRef): TransactionState<*>
}
/** /**
* Subset of node services that are used for loading transactions from the wire into fully resolved, looked up * Subset of node services that are used for loading transactions from the wire into fully resolved, looked up
* forms ready for verification. * forms ready for verification.
*/ */
interface ServicesForResolution { interface ServicesForResolution : StateLoader {
/** /**
* An identity service maintains a directory of parties by their associated distinguished name/public keys and thus * An identity service maintains a directory of parties by their associated distinguished name/public keys and thus
* supports lookup of a party given its key, or name. The service also manages the certificates linking confidential * supports lookup of a party given its key, or name. The service also manages the certificates linking confidential
@ -37,14 +46,6 @@ interface ServicesForResolution {
/** Provides access to anything relating to cordapps including contract attachment resolution and app context */ /** Provides access to anything relating to cordapps including contract attachment resolution and app context */
val cordappProvider: CordappProvider val cordappProvider: CordappProvider
/**
* Given a [StateRef] loads the referenced transaction and looks up the specified output [ContractState].
*
* @throws TransactionResolutionException if the [StateRef] points to a non-existent transaction.
*/
@Throws(TransactionResolutionException::class)
fun loadState(stateRef: StateRef): TransactionState<*>
} }
/** /**
@ -155,19 +156,6 @@ interface ServiceHub : ServicesForResolution {
recordTransactions(true, txs) recordTransactions(true, txs)
} }
/**
* Given a [StateRef] loads the referenced transaction and looks up the specified output [ContractState].
*
* @throws TransactionResolutionException if [stateRef] points to a non-existent transaction.
*/
@Throws(TransactionResolutionException::class)
override fun loadState(stateRef: StateRef): TransactionState<*> {
val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash)
return if (stx.isNotaryChangeTransaction()) {
stx.resolveNotaryChangeTransaction(this).outputs[stateRef.index]
} else stx.tx.outputs[stateRef.index]
}
/** /**
* Converts the given [StateRef] into a [StateAndRef] object. * Converts the given [StateRef] into a [StateAndRef] object.
* *

View File

@ -7,6 +7,7 @@ import net.corda.core.crypto.serializedHash
import net.corda.core.utilities.toBase58String import net.corda.core.utilities.toBase58String
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.node.StateLoader
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import java.security.PublicKey import java.security.PublicKey
@ -39,9 +40,10 @@ data class NotaryChangeWireTransaction(
*/ */
override val id: SecureHash by lazy { serializedHash(inputs + notary + newNotary) } override val id: SecureHash by lazy { serializedHash(inputs + notary + newNotary) }
fun resolve(services: ServiceHub, sigs: List<TransactionSignature>): NotaryChangeLedgerTransaction { fun resolve(services: ServiceHub, sigs: List<TransactionSignature>) = resolve(services as StateLoader, sigs)
fun resolve(stateLoader: StateLoader, sigs: List<TransactionSignature>): NotaryChangeLedgerTransaction {
val resolvedInputs = inputs.map { ref -> val resolvedInputs = inputs.map { ref ->
services.loadState(ref).let { StateAndRef(it, ref) } stateLoader.loadState(ref).let { StateAndRef(it, ref) }
} }
return NotaryChangeLedgerTransaction(resolvedInputs, notary, newNotary, id, sigs) return NotaryChangeLedgerTransaction(resolvedInputs, notary, newNotary, id, sigs)
} }

View File

@ -7,6 +7,7 @@ import net.corda.core.crypto.*
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.VisibleForTesting
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.node.StateLoader
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
@ -181,10 +182,12 @@ data class SignedTransaction(val txBits: SerializedBytes<CoreTransaction>,
* If [transaction] is a [NotaryChangeWireTransaction], loads the input states and resolves it to a * If [transaction] is a [NotaryChangeWireTransaction], loads the input states and resolves it to a
* [NotaryChangeLedgerTransaction] so the signatures can be verified. * [NotaryChangeLedgerTransaction] so the signatures can be verified.
*/ */
fun resolveNotaryChangeTransaction(services: ServiceHub): NotaryChangeLedgerTransaction { fun resolveNotaryChangeTransaction(services: ServiceHub) = resolveNotaryChangeTransaction(services as StateLoader)
fun resolveNotaryChangeTransaction(stateLoader: StateLoader): NotaryChangeLedgerTransaction {
val ntx = transaction as? NotaryChangeWireTransaction val ntx = transaction as? NotaryChangeWireTransaction
?: throw IllegalStateException("Expected a ${NotaryChangeWireTransaction::class.simpleName} but found ${transaction::class.simpleName}") ?: throw IllegalStateException("Expected a ${NotaryChangeWireTransaction::class.simpleName} but found ${transaction::class.simpleName}")
return ntx.resolve(services, sigs) return ntx.resolve(stateLoader, sigs)
} }
override fun toString(): String = "${javaClass.simpleName}(id=$id)" override fun toString(): String = "${javaClass.simpleName}(id=$id)"

View File

@ -22,7 +22,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() {
@Before @Before
fun setup() { fun setup() {
services.mockCordappProvider.addMockCordapp(DummyContract.PROGRAM_ID, services) services.mockCordappProvider.addMockCordapp(DummyContract.PROGRAM_ID, services.attachments)
} }
interface Commands { interface Commands {

View File

@ -19,6 +19,10 @@ import net.corda.core.internal.concurrent.flatMap
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.toX509CertHolder import net.corda.core.internal.toX509CertHolder
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.*
import net.corda.core.messaging.* import net.corda.core.messaging.*
import net.corda.core.node.AppServiceHub import net.corda.core.node.AppServiceHub
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
@ -453,7 +457,8 @@ abstract class AbstractNode(config: NodeConfiguration,
private fun makeServices(schemaService: SchemaService): MutableList<Any> { private fun makeServices(schemaService: SchemaService): MutableList<Any> {
checkpointStorage = DBCheckpointStorage() checkpointStorage = DBCheckpointStorage()
cordappProvider = CordappProviderImpl(cordappLoader) cordappProvider = CordappProviderImpl(cordappLoader)
_services = ServiceHubInternalImpl(schemaService) val transactionStorage = makeTransactionStorage()
_services = ServiceHubInternalImpl(schemaService, transactionStorage, StateLoaderImpl(transactionStorage))
attachments = NodeAttachmentService(services.monitoringService.metrics) attachments = NodeAttachmentService(services.monitoringService.metrics)
cordappProvider.start(attachments) cordappProvider.start(attachments)
legalIdentity = obtainIdentity(notaryConfig = null) legalIdentity = obtainIdentity(notaryConfig = null)
@ -752,15 +757,18 @@ abstract class AbstractNode(config: NodeConfiguration,
protected open fun generateKeyPair() = cryptoGenerateKeyPair() protected open fun generateKeyPair() = cryptoGenerateKeyPair()
private inner class ServiceHubInternalImpl(override val schemaService: SchemaService) : ServiceHubInternal, SingletonSerializeAsToken() { private inner class ServiceHubInternalImpl(
override val schemaService: SchemaService,
override val validatedTransactions: WritableTransactionStorage,
private val stateLoader: StateLoader
) : SingletonSerializeAsToken(), ServiceHubInternal, StateLoader by stateLoader {
override val rpcFlows = ArrayList<Class<out FlowLogic<*>>>() override val rpcFlows = ArrayList<Class<out FlowLogic<*>>>()
override val stateMachineRecordedTransactionMapping = DBTransactionMappingStorage() override val stateMachineRecordedTransactionMapping = DBTransactionMappingStorage()
override val auditService = DummyAuditService() override val auditService = DummyAuditService()
override val monitoringService = MonitoringService(MetricRegistry()) override val monitoringService = MonitoringService(MetricRegistry())
override val validatedTransactions = makeTransactionStorage()
override val transactionVerifierService by lazy { makeTransactionVerifierService() } override val transactionVerifierService by lazy { makeTransactionVerifierService() }
override val networkMapCache by lazy { PersistentNetworkMapCache(this) } override val networkMapCache by lazy { PersistentNetworkMapCache(this) }
override val vaultService by lazy { NodeVaultService(this, database.hibernateConfig) } override val vaultService by lazy { NodeVaultService(platformClock, keyManagementService, stateLoader, this@AbstractNode.database.hibernateConfig) }
override val contractUpgradeService by lazy { ContractUpgradeServiceImpl() } override val contractUpgradeService by lazy { ContractUpgradeServiceImpl() }
// Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because // Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because

View File

@ -1,8 +1,13 @@
package net.corda.node.internal package net.corda.node.internal
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionResolutionException
import net.corda.core.contracts.TransactionState
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.CordaRPCOps
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.StateLoader
import net.corda.core.node.services.TransactionStorage
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.MessagingService import net.corda.node.services.messaging.MessagingService
@ -25,3 +30,13 @@ interface StartedNode<out N : AbstractNode> {
fun dispose() = internals.stop() fun dispose() = internals.stop()
fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>) = internals.registerInitiatedFlow(initiatedFlowClass) fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>) = internals.registerInitiatedFlow(initiatedFlowClass)
} }
class StateLoaderImpl(private val validatedTransactions: TransactionStorage) : StateLoader {
@Throws(TransactionResolutionException::class)
override fun loadState(stateRef: StateRef): TransactionState<*> {
val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash)
return if (stx.isNotaryChangeTransaction()) {
stx.resolveNotaryChangeTransaction(this).outputs[stateRef.index]
} else stx.tx.outputs[stateRef.index]
}
}

View File

@ -25,7 +25,6 @@ import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.MessagingService import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl
import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
interface NetworkMapCacheInternal : NetworkMapCache { interface NetworkMapCacheInternal : NetworkMapCache {
@ -73,6 +72,7 @@ interface ServiceHubInternal : ServiceHub {
private val log = loggerFor<ServiceHubInternal>() private val log = loggerFor<ServiceHubInternal>()
} }
override val vaultService: VaultServiceInternal
/** /**
* A map of hash->tx where tx has been signature/contract validated and the states are known to be correct. * A map of hash->tx where tx has been signature/contract validated and the states are known to be correct.
* The signatures aren't technically needed after that point, but we keep them around so that we can relay * The signatures aren't technically needed after that point, but we keep them around so that we can relay
@ -104,7 +104,7 @@ interface ServiceHubInternal : ServiceHub {
if (notifyVault) { if (notifyVault) {
val toNotify = recordedTransactions.map { if (it.isNotaryChangeTransaction()) it.notaryChangeTx else it.tx } val toNotify = recordedTransactions.map { if (it.isNotaryChangeTransaction()) it.notaryChangeTx else it.tx }
(vaultService as NodeVaultService).notifyAll(toNotify) vaultService.notifyAll(toNotify)
} }
} }

View File

@ -0,0 +1,19 @@
package net.corda.node.services.api
import net.corda.core.node.services.VaultService
import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.WireTransaction
interface VaultServiceInternal : VaultService {
/**
* Splits the provided [txns] into batches of [WireTransaction] and [NotaryChangeWireTransaction].
* This is required because the batches get aggregated into single updates, and we want to be able to
* indicate whether an update consists entirely of regular or notary change transactions, which may require
* different processing logic.
*/
fun notifyAll(txns: Iterable<CoreTransaction>)
/** Same as notifyAll but with a single transaction. */
fun notify(tx: CoreTransaction) = notifyAll(listOf(tx))
}

View File

@ -4,16 +4,16 @@ import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.internal.ThreadBox import net.corda.core.internal.*
import net.corda.core.internal.VisibleForTesting import net.corda.core.node.StateLoader
import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.node.services.*
import net.corda.core.internal.tee
import net.corda.core.messaging.DataFeed
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.StatesNotAvailableException import net.corda.core.node.services.StatesNotAvailableException
import net.corda.core.node.services.Vault import net.corda.core.node.services.Vault
import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.Sort
import net.corda.core.node.services.vault.SortAttribute
import net.corda.core.messaging.DataFeed
import net.corda.core.node.services.VaultQueryException import net.corda.core.node.services.VaultQueryException
import net.corda.core.node.services.VaultService
import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.*
import net.corda.core.schemas.PersistentStateRef import net.corda.core.schemas.PersistentStateRef
import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT
@ -24,6 +24,7 @@ import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.* import net.corda.core.utilities.*
import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.persistence.HibernateConfiguration import net.corda.node.services.persistence.HibernateConfiguration
import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.utilities.DatabaseTransactionManager import net.corda.node.utilities.DatabaseTransactionManager
@ -33,6 +34,7 @@ import org.hibernate.Session
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.security.PublicKey import java.security.PublicKey
import java.time.Clock
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
import javax.persistence.Tuple import javax.persistence.Tuple
@ -47,7 +49,7 @@ import javax.persistence.Tuple
* TODO: keep an audit trail with time stamps of previously unconsumed states "as of" a particular point in time. * TODO: keep an audit trail with time stamps of previously unconsumed states "as of" a particular point in time.
* TODO: have transaction storage do some caching. * TODO: have transaction storage do some caching.
*/ */
class NodeVaultService(private val services: ServiceHub, private val hibernateConfig: HibernateConfiguration) : SingletonSerializeAsToken(), VaultService { class NodeVaultService(private val clock: Clock, private val keyManagementService: KeyManagementService, private val stateLoader: StateLoader, private val hibernateConfig: HibernateConfiguration) : SingletonSerializeAsToken(), VaultServiceInternal {
private companion object { private companion object {
val log = loggerFor<NodeVaultService>() val log = loggerFor<NodeVaultService>()
@ -78,7 +80,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
contractStateClassName = stateAndRef.value.state.data.javaClass.name, contractStateClassName = stateAndRef.value.state.data.javaClass.name,
contractState = stateAndRef.value.state.serialize(context = STORAGE_CONTEXT).bytes, contractState = stateAndRef.value.state.serialize(context = STORAGE_CONTEXT).bytes,
stateStatus = Vault.StateStatus.UNCONSUMED, stateStatus = Vault.StateStatus.UNCONSUMED,
recordedTime = services.clock.instant()) recordedTime = clock.instant())
state.stateRef = PersistentStateRef(stateAndRef.key) state.stateRef = PersistentStateRef(stateAndRef.key)
session.save(state) session.save(state)
} }
@ -86,11 +88,11 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
val state = session.get<VaultSchemaV1.VaultStates>(VaultSchemaV1.VaultStates::class.java, PersistentStateRef(stateRef)) val state = session.get<VaultSchemaV1.VaultStates>(VaultSchemaV1.VaultStates::class.java, PersistentStateRef(stateRef))
state?.run { state?.run {
stateStatus = Vault.StateStatus.CONSUMED stateStatus = Vault.StateStatus.CONSUMED
consumedTime = services.clock.instant() consumedTime = clock.instant()
// remove lock (if held) // remove lock (if held)
if (lockId != null) { if (lockId != null) {
lockId = null lockId = null
lockUpdateTime = services.clock.instant() lockUpdateTime = clock.instant()
log.trace("Releasing soft lock on consumed state: $stateRef") log.trace("Releasing soft lock on consumed state: $stateRef")
} }
session.save(state) session.save(state)
@ -106,13 +108,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
override val updates: Observable<Vault.Update<ContractState>> override val updates: Observable<Vault.Update<ContractState>>
get() = mutex.locked { _updatesInDbTx } get() = mutex.locked { _updatesInDbTx }
/** override fun notifyAll(txns: Iterable<CoreTransaction>) {
* Splits the provided [txns] into batches of [WireTransaction] and [NotaryChangeWireTransaction].
* This is required because the batches get aggregated into single updates, and we want to be able to
* indicate whether an update consists entirely of regular or notary change transactions, which may require
* different processing logic.
*/
fun notifyAll(txns: Iterable<CoreTransaction>) {
// It'd be easier to just group by type, but then we'd lose ordering. // It'd be easier to just group by type, but then we'd lose ordering.
val regularTxns = mutableListOf<WireTransaction>() val regularTxns = mutableListOf<WireTransaction>()
val notaryChangeTxns = mutableListOf<NotaryChangeWireTransaction>() val notaryChangeTxns = mutableListOf<NotaryChangeWireTransaction>()
@ -140,12 +136,9 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
if (notaryChangeTxns.isNotEmpty()) notifyNotaryChange(notaryChangeTxns.toList()) if (notaryChangeTxns.isNotEmpty()) notifyNotaryChange(notaryChangeTxns.toList())
} }
/** Same as notifyAll but with a single transaction. */
fun notify(tx: CoreTransaction) = notifyAll(listOf(tx))
private fun notifyRegular(txns: Iterable<WireTransaction>) { private fun notifyRegular(txns: Iterable<WireTransaction>) {
fun makeUpdate(tx: WireTransaction): Vault.Update<ContractState> { fun makeUpdate(tx: WireTransaction): Vault.Update<ContractState> {
val myKeys = services.keyManagementService.filterMyKeys(tx.outputs.flatMap { it.data.participants.map { it.owningKey } }) val myKeys = keyManagementService.filterMyKeys(tx.outputs.flatMap { it.data.participants.map { it.owningKey } })
val ourNewStates = tx.outputs. val ourNewStates = tx.outputs.
filter { isRelevant(it.data, myKeys.toSet()) }. filter { isRelevant(it.data, myKeys.toSet()) }.
map { tx.outRef<ContractState>(it.data) } map { tx.outRef<ContractState>(it.data) }
@ -171,8 +164,8 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
// We need to resolve the full transaction here because outputs are calculated from inputs // We need to resolve the full transaction here because outputs are calculated from inputs
// We also can't do filtering beforehand, since output encumbrance pointers get recalculated based on // We also can't do filtering beforehand, since output encumbrance pointers get recalculated based on
// input positions // input positions
val ltx = tx.resolve(services, emptyList()) val ltx = tx.resolve(stateLoader, emptyList())
val myKeys = services.keyManagementService.filterMyKeys(ltx.outputs.flatMap { it.data.participants.map { it.owningKey } }) val myKeys = keyManagementService.filterMyKeys(ltx.outputs.flatMap { it.data.participants.map { it.owningKey } })
val (consumedStateAndRefs, producedStates) = ltx.inputs. val (consumedStateAndRefs, producedStates) = ltx.inputs.
zip(ltx.outputs). zip(ltx.outputs).
filter { (_, output) -> isRelevant(output.data, myKeys.toSet()) }. filter { (_, output) -> isRelevant(output.data, myKeys.toSet()) }.
@ -246,7 +239,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
@Throws(StatesNotAvailableException::class) @Throws(StatesNotAvailableException::class)
override fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet<StateRef>) { override fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet<StateRef>) {
val softLockTimestamp = services.clock.instant() val softLockTimestamp = clock.instant()
try { try {
val session = DatabaseTransactionManager.current().session val session = DatabaseTransactionManager.current().session
val criteriaBuilder = session.criteriaBuilder val criteriaBuilder = session.criteriaBuilder
@ -292,7 +285,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
} }
override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet<StateRef>?) { override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet<StateRef>?) {
val softLockTimestamp = services.clock.instant() val softLockTimestamp = clock.instant()
val session = DatabaseTransactionManager.current().session val session = DatabaseTransactionManager.current().session
val criteriaBuilder = session.criteriaBuilder val criteriaBuilder = session.criteriaBuilder
if (stateRefs == null) { if (stateRefs == null) {
@ -440,8 +433,8 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
// pagination checks // pagination checks
if (!paging.isDefault) { if (!paging.isDefault) {
// pagination // pagination
if (paging.pageNumber < DEFAULT_PAGE_NUM) throw VaultQueryException("Page specification: invalid page number ${paging.pageNumber} [page numbers start from ${DEFAULT_PAGE_NUM}]") if (paging.pageNumber < DEFAULT_PAGE_NUM) throw VaultQueryException("Page specification: invalid page number ${paging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]")
if (paging.pageSize < 1) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [must be a value between 1 and ${MAX_PAGE_SIZE}]") if (paging.pageSize < 1) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [must be a value between 1 and $MAX_PAGE_SIZE]")
} }
query.firstResult = (paging.pageNumber - 1) * paging.pageSize query.firstResult = (paging.pageNumber - 1) * paging.pageSize
@ -452,7 +445,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
// final pagination check (fail-fast on too many results when no pagination specified) // final pagination check (fail-fast on too many results when no pagination specified)
if (paging.isDefault && results.size > DEFAULT_PAGE_SIZE) if (paging.isDefault && results.size > DEFAULT_PAGE_SIZE)
throw VaultQueryException("Please specify a `PageSpecification` as there are more results [${results.size}] than the default page size [${DEFAULT_PAGE_SIZE}]") throw VaultQueryException("Please specify a `PageSpecification` as there are more results [${results.size}] than the default page size [$DEFAULT_PAGE_SIZE]")
val statesAndRefs: MutableList<StateAndRef<T>> = mutableListOf() val statesAndRefs: MutableList<StateAndRef<T>> = mutableListOf()
val statesMeta: MutableList<Vault.StateMetadata> = mutableListOf() val statesMeta: MutableList<Vault.StateMetadata> = mutableListOf()
@ -495,8 +488,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
override fun <T : ContractState> _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): DataFeed<Vault.Page<T>, Vault.Update<T>> { override fun <T : ContractState> _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): DataFeed<Vault.Page<T>, Vault.Update<T>> {
return mutex.locked { return mutex.locked {
val snapshotResults = _queryBy(criteria, paging, sorting, contractStateType) val snapshotResults = _queryBy(criteria, paging, sorting, contractStateType)
@Suppress("UNCHECKED_CAST") val updates: Observable<Vault.Update<T>> = uncheckedCast(_updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractStateType, snapshotResults.stateTypes) })
val updates = _updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractStateType, snapshotResults.stateTypes) } as Observable<Vault.Update<T>>
DataFeed(snapshotResults, updates) DataFeed(snapshotResults, updates)
} }
} }
@ -522,8 +514,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
val contractInterfaceToConcreteTypes = mutableMapOf<String, MutableSet<String>>() val contractInterfaceToConcreteTypes = mutableMapOf<String, MutableSet<String>>()
distinctTypes.forEach { type -> distinctTypes.forEach { type ->
@Suppress("UNCHECKED_CAST") val concreteType: Class<ContractState> = uncheckedCast(Class.forName(type))
val concreteType = Class.forName(type) as Class<ContractState>
val contractInterfaces = deriveContractInterfaces(concreteType) val contractInterfaces = deriveContractInterfaces(concreteType)
contractInterfaces.map { contractInterfaces.map {
val contractInterface = contractInterfaceToConcreteTypes.getOrPut(it.name, { mutableSetOf() }) val contractInterface = contractInterfaceToConcreteTypes.getOrPut(it.name, { mutableSetOf() })
@ -537,10 +528,9 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
private fun <T : ContractState> deriveContractInterfaces(clazz: Class<T>): Set<Class<T>> { private fun <T : ContractState> deriveContractInterfaces(clazz: Class<T>): Set<Class<T>> {
val myInterfaces: MutableSet<Class<T>> = mutableSetOf() val myInterfaces: MutableSet<Class<T>> = mutableSetOf()
clazz.interfaces.forEach { clazz.interfaces.forEach {
if (!it.equals(ContractState::class.java)) { if (it != ContractState::class.java) {
@Suppress("UNCHECKED_CAST") myInterfaces.add(uncheckedCast(it))
myInterfaces.add(it as Class<T>) myInterfaces.addAll(deriveContractInterfaces(uncheckedCast(it)))
myInterfaces.addAll(deriveContractInterfaces(it))
} }
} }
return myInterfaces return myInterfaces

View File

@ -9,12 +9,12 @@ import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.node.services.VaultService
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.identity.InMemoryIdentityService
import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.persistence.DBCheckpointStorage
import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl
@ -96,7 +96,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
overrideClock = testClock, overrideClock = testClock,
keyManagement = kms, keyManagement = kms,
network = mockMessagingService), TestReference { network = mockMessagingService), TestReference {
override val vaultService: VaultService = NodeVaultService(this, database.hibernateConfig) override val vaultService: VaultServiceInternal = NodeVaultService(testClock, kms, stateLoader, database.hibernateConfig)
override val testReference = this@NodeSchedulerServiceTest override val testReference = this@NodeSchedulerServiceTest
override val cordappProvider: CordappProviderImpl = CordappProviderImpl(CordappLoader.createWithTestPackages()).start(attachments) override val cordappProvider: CordappProviderImpl = CordappProviderImpl(CordappLoader.createWithTestPackages()).start(attachments)
} }

View File

@ -6,18 +6,14 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureMetadata import net.corda.core.crypto.SignatureMetadata
import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.TransactionSignature
import net.corda.core.node.services.VaultService import net.corda.core.node.services.VaultService
import net.corda.core.schemas.MappedSchema
import net.corda.core.toFuture import net.corda.core.toFuture
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.finance.schemas.CashSchemaV1 import net.corda.node.services.api.VaultServiceInternal
import net.corda.finance.schemas.SampleCashSchemaV2
import net.corda.finance.schemas.SampleCashSchemaV3
import net.corda.node.services.schema.HibernateObserver import net.corda.node.services.schema.HibernateObserver
import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.testing.* import net.corda.testing.*
@ -46,8 +42,8 @@ class DBTransactionStorageTests : TestDependencyInjectionBase() {
database.transaction { database.transaction {
services = object : MockServices(BOB_KEY) { services = object : MockServices(BOB_KEY) {
override val vaultService: VaultService get() { override val vaultService: VaultServiceInternal get() {
val vaultService = NodeVaultService(this, database.hibernateConfig) val vaultService = NodeVaultService(clock, keyManagementService, stateLoader, database.hibernateConfig)
hibernatePersister = HibernateObserver(vaultService.rawUpdates, database.hibernateConfig) hibernatePersister = HibernateObserver(vaultService.rawUpdates, database.hibernateConfig)
return vaultService return vaultService
} }
@ -57,7 +53,7 @@ class DBTransactionStorageTests : TestDependencyInjectionBase() {
validatedTransactions.addTransaction(stx) validatedTransactions.addTransaction(stx)
} }
// Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions.
(vaultService as NodeVaultService).notifyAll(txs.map { it.tx }) vaultService.notifyAll(txs.map { it.tx })
} }
} }
} }

View File

@ -26,7 +26,6 @@ import net.corda.finance.schemas.SampleCashSchemaV3
import net.corda.finance.utils.sumCash import net.corda.finance.utils.sumCash
import net.corda.node.services.schema.HibernateObserver import net.corda.node.services.schema.HibernateObserver
import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1 import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
@ -82,14 +81,13 @@ class HibernateConfigurationTest : TestDependencyInjectionBase() {
database.transaction { database.transaction {
hibernateConfig = database.hibernateConfig hibernateConfig = database.hibernateConfig
services = object : MockServices(BOB_KEY, BOC_KEY, DUMMY_NOTARY_KEY) { services = object : MockServices(BOB_KEY, BOC_KEY, DUMMY_NOTARY_KEY) {
override val vaultService: VaultService = makeVaultService(database.hibernateConfig) override val vaultService = makeVaultService(database.hibernateConfig)
override fun recordTransactions(notifyVault: Boolean, txs: Iterable<SignedTransaction>) { override fun recordTransactions(notifyVault: Boolean, txs: Iterable<SignedTransaction>) {
for (stx in txs) { for (stx in txs) {
validatedTransactions.addTransaction(stx) validatedTransactions.addTransaction(stx)
} }
// Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions.
(vaultService as NodeVaultService).notifyAll(txs.map { it.tx }) vaultService.notifyAll(txs.map { it.tx })
} }
override fun jdbcSession() = database.createSession() override fun jdbcSession() = database.createSession()
} }

View File

@ -47,7 +47,7 @@ import kotlin.test.assertTrue
class NodeVaultServiceTest : TestDependencyInjectionBase() { class NodeVaultServiceTest : TestDependencyInjectionBase() {
lateinit var services: MockServices lateinit var services: MockServices
lateinit var issuerServices: MockServices lateinit var issuerServices: MockServices
val vaultService: VaultService get() = services.vaultService val vaultService get() = services.vaultService as NodeVaultService
lateinit var database: CordaPersistence lateinit var database: CordaPersistence
@Before @Before
@ -98,7 +98,7 @@ class NodeVaultServiceTest : TestDependencyInjectionBase() {
val originalVault = vaultService val originalVault = vaultService
val services2 = object : MockServices() { val services2 = object : MockServices() {
override val vaultService: NodeVaultService get() = originalVault as NodeVaultService override val vaultService: NodeVaultService get() = originalVault
override fun recordTransactions(notifyVault: Boolean, txs: Iterable<SignedTransaction>) { override fun recordTransactions(notifyVault: Boolean, txs: Iterable<SignedTransaction>) {
for (stx in txs) { for (stx in txs) {
validatedTransactions.addTransaction(stx) validatedTransactions.addTransaction(stx)
@ -473,7 +473,7 @@ class NodeVaultServiceTest : TestDependencyInjectionBase() {
@Test @Test
fun `is ownable state relevant`() { fun `is ownable state relevant`() {
val service = (services.vaultService as NodeVaultService) val service = vaultService
val amount = Amount(1000, Issued(BOC.ref(1), GBP)) val amount = Amount(1000, Issued(BOC.ref(1), GBP))
val wellKnownCash = Cash.State(amount, services.myInfo.chooseIdentity()) val wellKnownCash = Cash.State(amount, services.myInfo.chooseIdentity())
val myKeys = services.keyManagementService.filterMyKeys(listOf(wellKnownCash.owner.owningKey)) val myKeys = services.keyManagementService.filterMyKeys(listOf(wellKnownCash.owner.owningKey))
@ -494,7 +494,7 @@ class NodeVaultServiceTest : TestDependencyInjectionBase() {
@Test @Test
fun `correct updates are generated for general transactions`() { fun `correct updates are generated for general transactions`() {
val service = (services.vaultService as NodeVaultService) val service = vaultService
val vaultSubscriber = TestSubscriber<Vault.Update<*>>().apply { val vaultSubscriber = TestSubscriber<Vault.Update<*>>().apply {
service.updates.subscribe(this) service.updates.subscribe(this)
} }
@ -527,7 +527,7 @@ class NodeVaultServiceTest : TestDependencyInjectionBase() {
@Test @Test
fun `correct updates are generated when changing notaries`() { fun `correct updates are generated when changing notaries`() {
val service = (services.vaultService as NodeVaultService) val service = vaultService
val notary = services.myInfo.chooseIdentity() val notary = services.myInfo.chooseIdentity()
val vaultSubscriber = TestSubscriber<Vault.Update<*>>().apply { val vaultSubscriber = TestSubscriber<Vault.Update<*>>().apply {

View File

@ -6,9 +6,11 @@ import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.StateLoader
import net.corda.core.node.services.* import net.corda.core.node.services.*
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsToken
import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.internal.StateLoaderImpl
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.node.serialization.NodeClock import net.corda.node.serialization.NodeClock
@ -34,7 +36,7 @@ import java.time.Clock
open class MockServiceHubInternal( open class MockServiceHubInternal(
override val database: CordaPersistence, override val database: CordaPersistence,
override val configuration: NodeConfiguration, override val configuration: NodeConfiguration,
val customVault: VaultService? = null, val customVault: VaultServiceInternal? = null,
val keyManagement: KeyManagementService? = null, val keyManagement: KeyManagementService? = null,
val network: MessagingService? = null, val network: MessagingService? = null,
val identity: IdentityService? = MOCK_IDENTITY_SERVICE, val identity: IdentityService? = MOCK_IDENTITY_SERVICE,
@ -47,11 +49,12 @@ open class MockServiceHubInternal(
val schemas: SchemaService? = NodeSchemaService(), val schemas: SchemaService? = NodeSchemaService(),
val customContractUpgradeService: ContractUpgradeService? = null, val customContractUpgradeService: ContractUpgradeService? = null,
val customTransactionVerifierService: TransactionVerifierService? = InMemoryTransactionVerifierService(2), val customTransactionVerifierService: TransactionVerifierService? = InMemoryTransactionVerifierService(2),
override val cordappProvider: CordappProvider = CordappProviderImpl(CordappLoader.createDefault(Paths.get("."))).start(attachments) override val cordappProvider: CordappProvider = CordappProviderImpl(CordappLoader.createDefault(Paths.get("."))).start(attachments),
) : ServiceHubInternal { protected val stateLoader: StateLoaderImpl = StateLoaderImpl(validatedTransactions)
) : ServiceHubInternal, StateLoader by stateLoader {
override val transactionVerifierService: TransactionVerifierService override val transactionVerifierService: TransactionVerifierService
get() = customTransactionVerifierService ?: throw UnsupportedOperationException() get() = customTransactionVerifierService ?: throw UnsupportedOperationException()
override val vaultService: VaultService override val vaultService: VaultServiceInternal
get() = customVault ?: throw UnsupportedOperationException() get() = customVault ?: throw UnsupportedOperationException()
override val contractUpgradeService: ContractUpgradeService override val contractUpgradeService: ContractUpgradeService
get() = customContractUpgradeService ?: throw UnsupportedOperationException() get() = customContractUpgradeService ?: throw UnsupportedOperationException()

View File

@ -7,14 +7,17 @@ import net.corda.core.identity.PartyAndCertificate
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.node.StateLoader
import net.corda.core.node.services.* import net.corda.core.node.services.*
import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.MappedSchema
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.internal.StateLoaderImpl
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage
import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.identity.InMemoryIdentityService
import net.corda.node.services.keys.freshCertificate import net.corda.node.services.keys.freshCertificate
@ -42,7 +45,12 @@ import java.util.*
* A singleton utility that only provides a mock identity, key and storage service. However, this is sufficient for * A singleton utility that only provides a mock identity, key and storage service. However, this is sufficient for
* building chains of transactions and verifying them. It isn't sufficient for testing flows however. * building chains of transactions and verifying them. It isn't sufficient for testing flows however.
*/ */
open class MockServices(cordappPackages: List<String> = emptyList(), vararg val keys: KeyPair) : ServiceHub { open class MockServices(
cordappPackages: List<String>,
override val validatedTransactions: WritableTransactionStorage,
protected val stateLoader: StateLoaderImpl = StateLoaderImpl(validatedTransactions),
vararg val keys: KeyPair
) : ServiceHub, StateLoader by stateLoader {
companion object { companion object {
@JvmStatic @JvmStatic
@ -105,14 +113,14 @@ open class MockServices(cordappPackages: List<String> = emptyList(), vararg val
val mockService = database.transaction { val mockService = database.transaction {
object : MockServices(cordappPackages, *(keys.toTypedArray())) { object : MockServices(cordappPackages, *(keys.toTypedArray())) {
override val identityService: IdentityService = database.transaction { identityServiceRef } override val identityService: IdentityService = database.transaction { identityServiceRef }
override val vaultService: VaultService = makeVaultService(database.hibernateConfig) override val vaultService = makeVaultService(database.hibernateConfig)
override fun recordTransactions(notifyVault: Boolean, txs: Iterable<SignedTransaction>) { override fun recordTransactions(notifyVault: Boolean, txs: Iterable<SignedTransaction>) {
for (stx in txs) { for (stx in txs) {
validatedTransactions.addTransaction(stx) validatedTransactions.addTransaction(stx)
} }
// Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions.
(vaultService as NodeVaultService).notifyAll(txs.map { it.tx }) vaultService.notifyAll(txs.map { it.tx })
} }
override fun jdbcSession(): Connection = database.createSession() override fun jdbcSession(): Connection = database.createSession()
@ -122,9 +130,9 @@ open class MockServices(cordappPackages: List<String> = emptyList(), vararg val
} }
} }
constructor(cordappPackages: List<String>, vararg keys: KeyPair) : this(cordappPackages, MockTransactionStorage(), keys = *keys)
constructor(vararg keys: KeyPair) : this(emptyList(), *keys) constructor(vararg keys: KeyPair) : this(emptyList(), *keys)
constructor() : this(generateKeyPair())
constructor() : this(emptyList(), generateKeyPair())
val key: KeyPair get() = keys.first() val key: KeyPair get() = keys.first()
@ -137,8 +145,7 @@ open class MockServices(cordappPackages: List<String> = emptyList(), vararg val
} }
} }
final override val attachments: AttachmentStorage = MockAttachmentStorage() final override val attachments = MockAttachmentStorage()
override val validatedTransactions: WritableTransactionStorage = MockTransactionStorage()
val stateMachineRecordedTransactionMapping: StateMachineRecordedTransactionMappingStorage = MockStateMachineRecordedTransactionMappingStorage() val stateMachineRecordedTransactionMapping: StateMachineRecordedTransactionMappingStorage = MockStateMachineRecordedTransactionMappingStorage()
override val identityService: IdentityService = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DEV_TRUST_ROOT) override val identityService: IdentityService = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DEV_TRUST_ROOT)
override val keyManagementService: KeyManagementService by lazy { MockKeyManagementService(identityService, *keys) } override val keyManagementService: KeyManagementService by lazy { MockKeyManagementService(identityService, *keys) }
@ -157,8 +164,8 @@ open class MockServices(cordappPackages: List<String> = emptyList(), vararg val
lateinit var hibernatePersister: HibernateObserver lateinit var hibernatePersister: HibernateObserver
fun makeVaultService(hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService(), makeTestDatabaseProperties(), { identityService })): VaultService { fun makeVaultService(hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService(), makeTestDatabaseProperties(), { identityService })): VaultServiceInternal {
val vaultService = NodeVaultService(this, hibernateConfig) val vaultService = NodeVaultService(Clock.systemUTC(), keyManagementService, stateLoader, hibernateConfig)
hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig) hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig)
return vaultService return vaultService
} }

View File

@ -8,10 +8,12 @@ import net.corda.core.flows.FlowException
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.node.ServicesForResolution
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.node.MockAttachmentStorage
import net.corda.testing.node.MockCordappProvider import net.corda.testing.node.MockCordappProvider
import java.io.InputStream import java.io.InputStream
import java.security.KeyPair import java.security.KeyPair
@ -72,7 +74,7 @@ data class TestTransactionDSLInterpreter private constructor(
transactionBuilder: TransactionBuilder transactionBuilder: TransactionBuilder
) : this(ledgerInterpreter, transactionBuilder, HashMap()) ) : this(ledgerInterpreter, transactionBuilder, HashMap())
val services = object : ServiceHub by ledgerInterpreter.services { val services = object : ServicesForResolution by ledgerInterpreter.services {
override fun loadState(stateRef: StateRef) = ledgerInterpreter.resolveStateRef<ContractState>(stateRef) override fun loadState(stateRef: StateRef) = ledgerInterpreter.resolveStateRef<ContractState>(stateRef)
override val cordappProvider: CordappProvider = ledgerInterpreter.services.cordappProvider override val cordappProvider: CordappProvider = ledgerInterpreter.services.cordappProvider
} }
@ -136,7 +138,7 @@ data class TestTransactionDSLInterpreter private constructor(
) = dsl(TransactionDSL(copy())) ) = dsl(TransactionDSL(copy()))
override fun _attachment(contractClassName: ContractClassName) { override fun _attachment(contractClassName: ContractClassName) {
(services.cordappProvider as MockCordappProvider).addMockCordapp(contractClassName, services) (services.cordappProvider as MockCordappProvider).addMockCordapp(contractClassName, services.attachments as MockAttachmentStorage)
} }
} }

View File

@ -3,7 +3,6 @@ package net.corda.testing.node
import net.corda.core.contracts.ContractClassName import net.corda.core.contracts.ContractClassName
import net.corda.core.cordapp.Cordapp import net.corda.core.cordapp.Cordapp
import net.corda.core.internal.cordapp.CordappImpl import net.corda.core.internal.cordapp.CordappImpl
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.AttachmentId import net.corda.core.node.services.AttachmentId
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl import net.corda.node.internal.cordapp.CordappProviderImpl
@ -13,27 +12,23 @@ import java.util.*
class MockCordappProvider(cordappLoader: CordappLoader) : CordappProviderImpl(cordappLoader) { class MockCordappProvider(cordappLoader: CordappLoader) : CordappProviderImpl(cordappLoader) {
val cordappRegistry = mutableListOf<Pair<Cordapp, AttachmentId>>() val cordappRegistry = mutableListOf<Pair<Cordapp, AttachmentId>>()
fun addMockCordapp(contractClassName: ContractClassName, services: ServiceHub) { fun addMockCordapp(contractClassName: ContractClassName, attachments: MockAttachmentStorage) {
val cordapp = CordappImpl(listOf(contractClassName), emptyList(), emptyList(), emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), Paths.get(".").toUri().toURL()) val cordapp = CordappImpl(listOf(contractClassName), emptyList(), emptyList(), emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), Paths.get(".").toUri().toURL())
if (cordappRegistry.none { it.first.contractClassNames.contains(contractClassName) }) { if (cordappRegistry.none { it.first.contractClassNames.contains(contractClassName) }) {
cordappRegistry.add(Pair(cordapp, findOrImportAttachment(contractClassName.toByteArray(), services))) cordappRegistry.add(Pair(cordapp, findOrImportAttachment(contractClassName.toByteArray(), attachments)))
} }
} }
override fun getContractAttachmentID(contractClassName: ContractClassName): AttachmentId? = cordappRegistry.find { it.first.contractClassNames.contains(contractClassName) }?.second ?: super.getContractAttachmentID(contractClassName) override fun getContractAttachmentID(contractClassName: ContractClassName): AttachmentId? = cordappRegistry.find { it.first.contractClassNames.contains(contractClassName) }?.second ?: super.getContractAttachmentID(contractClassName)
private fun findOrImportAttachment(data: ByteArray, services: ServiceHub): AttachmentId { private fun findOrImportAttachment(data: ByteArray, attachments: MockAttachmentStorage): AttachmentId {
return if (services.attachments is MockAttachmentStorage) { val existingAttachment = attachments.files.filter {
val existingAttachment = (services.attachments as MockAttachmentStorage).files.filter { Arrays.equals(it.value, data)
Arrays.equals(it.value, data) }
} return if (!existingAttachment.isEmpty()) {
if (!existingAttachment.isEmpty()) { existingAttachment.keys.first()
existingAttachment.keys.first()
} else {
services.attachments.importAttachment(data.inputStream())
}
} else { } else {
throw Exception("MockCordappService only requires MockAttachmentStorage") attachments.importAttachment(data.inputStream())
} }
} }
} }