CORDA-599 Fix circular dependency between vault and SH (#1630)

Fix circular dependency between the 2 vault objects and SH.
This commit is contained in:
Andrzej Cichocki 2017-10-09 12:49:07 +01:00 committed by Mike Hearn
parent 0c2289de8c
commit f83f1b7010
17 changed files with 148 additions and 122 deletions

View File

@ -7,10 +7,6 @@ import net.corda.core.crypto.SignableData
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.crypto.TransactionSignature
import net.corda.core.flows.ContractUpgradeFlow
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.AnonymousParty
import net.corda.core.identity.Party
import net.corda.core.internal.toMultiMap
import net.corda.core.node.services.*
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.transactions.FilteredTransaction
@ -20,11 +16,24 @@ import java.security.PublicKey
import java.sql.Connection
import java.time.Clock
/**
* Part of [ServiceHub].
*/
interface StateLoader {
/**
* Given a [StateRef] loads the referenced transaction and looks up the specified output [ContractState].
*
* @throws TransactionResolutionException if [stateRef] points to a non-existent transaction.
*/
@Throws(TransactionResolutionException::class)
fun loadState(stateRef: StateRef): TransactionState<*>
}
/**
* Subset of node services that are used for loading transactions from the wire into fully resolved, looked up
* forms ready for verification.
*/
interface ServicesForResolution {
interface ServicesForResolution : StateLoader {
/**
* An identity service maintains a directory of parties by their associated distinguished name/public keys and thus
* supports lookup of a party given its key, or name. The service also manages the certificates linking confidential
@ -37,14 +46,6 @@ interface ServicesForResolution {
/** Provides access to anything relating to cordapps including contract attachment resolution and app context */
val cordappProvider: CordappProvider
/**
* Given a [StateRef] loads the referenced transaction and looks up the specified output [ContractState].
*
* @throws TransactionResolutionException if the [StateRef] points to a non-existent transaction.
*/
@Throws(TransactionResolutionException::class)
fun loadState(stateRef: StateRef): TransactionState<*>
}
/**
@ -155,19 +156,6 @@ interface ServiceHub : ServicesForResolution {
recordTransactions(true, txs)
}
/**
* Given a [StateRef] loads the referenced transaction and looks up the specified output [ContractState].
*
* @throws TransactionResolutionException if [stateRef] points to a non-existent transaction.
*/
@Throws(TransactionResolutionException::class)
override fun loadState(stateRef: StateRef): TransactionState<*> {
val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash)
return if (stx.isNotaryChangeTransaction()) {
stx.resolveNotaryChangeTransaction(this).outputs[stateRef.index]
} else stx.tx.outputs[stateRef.index]
}
/**
* Converts the given [StateRef] into a [StateAndRef] object.
*

View File

@ -7,6 +7,7 @@ import net.corda.core.crypto.serializedHash
import net.corda.core.utilities.toBase58String
import net.corda.core.identity.Party
import net.corda.core.node.ServiceHub
import net.corda.core.node.StateLoader
import net.corda.core.serialization.CordaSerializable
import java.security.PublicKey
@ -39,9 +40,10 @@ data class NotaryChangeWireTransaction(
*/
override val id: SecureHash by lazy { serializedHash(inputs + notary + newNotary) }
fun resolve(services: ServiceHub, sigs: List<TransactionSignature>): NotaryChangeLedgerTransaction {
fun resolve(services: ServiceHub, sigs: List<TransactionSignature>) = resolve(services as StateLoader, sigs)
fun resolve(stateLoader: StateLoader, sigs: List<TransactionSignature>): NotaryChangeLedgerTransaction {
val resolvedInputs = inputs.map { ref ->
services.loadState(ref).let { StateAndRef(it, ref) }
stateLoader.loadState(ref).let { StateAndRef(it, ref) }
}
return NotaryChangeLedgerTransaction(resolvedInputs, notary, newNotary, id, sigs)
}

View File

@ -7,6 +7,7 @@ import net.corda.core.crypto.*
import net.corda.core.identity.Party
import net.corda.core.internal.VisibleForTesting
import net.corda.core.node.ServiceHub
import net.corda.core.node.StateLoader
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
@ -181,10 +182,12 @@ data class SignedTransaction(val txBits: SerializedBytes<CoreTransaction>,
* If [transaction] is a [NotaryChangeWireTransaction], loads the input states and resolves it to a
* [NotaryChangeLedgerTransaction] so the signatures can be verified.
*/
fun resolveNotaryChangeTransaction(services: ServiceHub): NotaryChangeLedgerTransaction {
fun resolveNotaryChangeTransaction(services: ServiceHub) = resolveNotaryChangeTransaction(services as StateLoader)
fun resolveNotaryChangeTransaction(stateLoader: StateLoader): NotaryChangeLedgerTransaction {
val ntx = transaction as? NotaryChangeWireTransaction
?: throw IllegalStateException("Expected a ${NotaryChangeWireTransaction::class.simpleName} but found ${transaction::class.simpleName}")
return ntx.resolve(services, sigs)
return ntx.resolve(stateLoader, sigs)
}
override fun toString(): String = "${javaClass.simpleName}(id=$id)"

View File

@ -22,7 +22,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() {
@Before
fun setup() {
services.mockCordappProvider.addMockCordapp(DummyContract.PROGRAM_ID, services)
services.mockCordappProvider.addMockCordapp(DummyContract.PROGRAM_ID, services.attachments)
}
interface Commands {

View File

@ -19,6 +19,10 @@ import net.corda.core.internal.concurrent.flatMap
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.toX509CertHolder
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.*
import net.corda.core.messaging.*
import net.corda.core.node.AppServiceHub
import net.corda.core.node.NodeInfo
@ -453,7 +457,8 @@ abstract class AbstractNode(config: NodeConfiguration,
private fun makeServices(schemaService: SchemaService): MutableList<Any> {
checkpointStorage = DBCheckpointStorage()
cordappProvider = CordappProviderImpl(cordappLoader)
_services = ServiceHubInternalImpl(schemaService)
val transactionStorage = makeTransactionStorage()
_services = ServiceHubInternalImpl(schemaService, transactionStorage, StateLoaderImpl(transactionStorage))
attachments = NodeAttachmentService(services.monitoringService.metrics)
cordappProvider.start(attachments)
legalIdentity = obtainIdentity(notaryConfig = null)
@ -752,15 +757,18 @@ abstract class AbstractNode(config: NodeConfiguration,
protected open fun generateKeyPair() = cryptoGenerateKeyPair()
private inner class ServiceHubInternalImpl(override val schemaService: SchemaService) : ServiceHubInternal, SingletonSerializeAsToken() {
private inner class ServiceHubInternalImpl(
override val schemaService: SchemaService,
override val validatedTransactions: WritableTransactionStorage,
private val stateLoader: StateLoader
) : SingletonSerializeAsToken(), ServiceHubInternal, StateLoader by stateLoader {
override val rpcFlows = ArrayList<Class<out FlowLogic<*>>>()
override val stateMachineRecordedTransactionMapping = DBTransactionMappingStorage()
override val auditService = DummyAuditService()
override val monitoringService = MonitoringService(MetricRegistry())
override val validatedTransactions = makeTransactionStorage()
override val transactionVerifierService by lazy { makeTransactionVerifierService() }
override val networkMapCache by lazy { PersistentNetworkMapCache(this) }
override val vaultService by lazy { NodeVaultService(this, database.hibernateConfig) }
override val vaultService by lazy { NodeVaultService(platformClock, keyManagementService, stateLoader, this@AbstractNode.database.hibernateConfig) }
override val contractUpgradeService by lazy { ContractUpgradeServiceImpl() }
// Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because

View File

@ -1,8 +1,13 @@
package net.corda.node.internal
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionResolutionException
import net.corda.core.contracts.TransactionState
import net.corda.core.flows.FlowLogic
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.node.NodeInfo
import net.corda.core.node.StateLoader
import net.corda.core.node.services.TransactionStorage
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.MessagingService
@ -25,3 +30,13 @@ interface StartedNode<out N : AbstractNode> {
fun dispose() = internals.stop()
fun <T : FlowLogic<*>> registerInitiatedFlow(initiatedFlowClass: Class<T>) = internals.registerInitiatedFlow(initiatedFlowClass)
}
class StateLoaderImpl(private val validatedTransactions: TransactionStorage) : StateLoader {
@Throws(TransactionResolutionException::class)
override fun loadState(stateRef: StateRef): TransactionState<*> {
val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash)
return if (stx.isNotaryChangeTransaction()) {
stx.resolveNotaryChangeTransaction(this).outputs[stateRef.index]
} else stx.tx.outputs[stateRef.index]
}
}

View File

@ -25,7 +25,6 @@ import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.utilities.CordaPersistence
interface NetworkMapCacheInternal : NetworkMapCache {
@ -73,6 +72,7 @@ interface ServiceHubInternal : ServiceHub {
private val log = loggerFor<ServiceHubInternal>()
}
override val vaultService: VaultServiceInternal
/**
* A map of hash->tx where tx has been signature/contract validated and the states are known to be correct.
* The signatures aren't technically needed after that point, but we keep them around so that we can relay
@ -104,7 +104,7 @@ interface ServiceHubInternal : ServiceHub {
if (notifyVault) {
val toNotify = recordedTransactions.map { if (it.isNotaryChangeTransaction()) it.notaryChangeTx else it.tx }
(vaultService as NodeVaultService).notifyAll(toNotify)
vaultService.notifyAll(toNotify)
}
}

View File

@ -0,0 +1,19 @@
package net.corda.node.services.api
import net.corda.core.node.services.VaultService
import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.WireTransaction
interface VaultServiceInternal : VaultService {
/**
* Splits the provided [txns] into batches of [WireTransaction] and [NotaryChangeWireTransaction].
* This is required because the batches get aggregated into single updates, and we want to be able to
* indicate whether an update consists entirely of regular or notary change transactions, which may require
* different processing logic.
*/
fun notifyAll(txns: Iterable<CoreTransaction>)
/** Same as notifyAll but with a single transaction. */
fun notify(tx: CoreTransaction) = notifyAll(listOf(tx))
}

View File

@ -4,16 +4,16 @@ import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.tee
import net.corda.core.messaging.DataFeed
import net.corda.core.node.ServiceHub
import net.corda.core.internal.*
import net.corda.core.node.StateLoader
import net.corda.core.node.services.*
import net.corda.core.node.services.StatesNotAvailableException
import net.corda.core.node.services.Vault
import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.Sort
import net.corda.core.node.services.vault.SortAttribute
import net.corda.core.messaging.DataFeed
import net.corda.core.node.services.VaultQueryException
import net.corda.core.node.services.VaultService
import net.corda.core.node.services.vault.*
import net.corda.core.schemas.PersistentStateRef
import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT
@ -24,6 +24,7 @@ import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.*
import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.persistence.HibernateConfiguration
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.utilities.DatabaseTransactionManager
@ -33,6 +34,7 @@ import org.hibernate.Session
import rx.Observable
import rx.subjects.PublishSubject
import java.security.PublicKey
import java.time.Clock
import java.time.Instant
import java.util.*
import javax.persistence.Tuple
@ -47,7 +49,7 @@ import javax.persistence.Tuple
* TODO: keep an audit trail with time stamps of previously unconsumed states "as of" a particular point in time.
* TODO: have transaction storage do some caching.
*/
class NodeVaultService(private val services: ServiceHub, private val hibernateConfig: HibernateConfiguration) : SingletonSerializeAsToken(), VaultService {
class NodeVaultService(private val clock: Clock, private val keyManagementService: KeyManagementService, private val stateLoader: StateLoader, private val hibernateConfig: HibernateConfiguration) : SingletonSerializeAsToken(), VaultServiceInternal {
private companion object {
val log = loggerFor<NodeVaultService>()
@ -78,7 +80,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
contractStateClassName = stateAndRef.value.state.data.javaClass.name,
contractState = stateAndRef.value.state.serialize(context = STORAGE_CONTEXT).bytes,
stateStatus = Vault.StateStatus.UNCONSUMED,
recordedTime = services.clock.instant())
recordedTime = clock.instant())
state.stateRef = PersistentStateRef(stateAndRef.key)
session.save(state)
}
@ -86,11 +88,11 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
val state = session.get<VaultSchemaV1.VaultStates>(VaultSchemaV1.VaultStates::class.java, PersistentStateRef(stateRef))
state?.run {
stateStatus = Vault.StateStatus.CONSUMED
consumedTime = services.clock.instant()
consumedTime = clock.instant()
// remove lock (if held)
if (lockId != null) {
lockId = null
lockUpdateTime = services.clock.instant()
lockUpdateTime = clock.instant()
log.trace("Releasing soft lock on consumed state: $stateRef")
}
session.save(state)
@ -106,13 +108,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
override val updates: Observable<Vault.Update<ContractState>>
get() = mutex.locked { _updatesInDbTx }
/**
* Splits the provided [txns] into batches of [WireTransaction] and [NotaryChangeWireTransaction].
* This is required because the batches get aggregated into single updates, and we want to be able to
* indicate whether an update consists entirely of regular or notary change transactions, which may require
* different processing logic.
*/
fun notifyAll(txns: Iterable<CoreTransaction>) {
override fun notifyAll(txns: Iterable<CoreTransaction>) {
// It'd be easier to just group by type, but then we'd lose ordering.
val regularTxns = mutableListOf<WireTransaction>()
val notaryChangeTxns = mutableListOf<NotaryChangeWireTransaction>()
@ -140,12 +136,9 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
if (notaryChangeTxns.isNotEmpty()) notifyNotaryChange(notaryChangeTxns.toList())
}
/** Same as notifyAll but with a single transaction. */
fun notify(tx: CoreTransaction) = notifyAll(listOf(tx))
private fun notifyRegular(txns: Iterable<WireTransaction>) {
fun makeUpdate(tx: WireTransaction): Vault.Update<ContractState> {
val myKeys = services.keyManagementService.filterMyKeys(tx.outputs.flatMap { it.data.participants.map { it.owningKey } })
val myKeys = keyManagementService.filterMyKeys(tx.outputs.flatMap { it.data.participants.map { it.owningKey } })
val ourNewStates = tx.outputs.
filter { isRelevant(it.data, myKeys.toSet()) }.
map { tx.outRef<ContractState>(it.data) }
@ -171,8 +164,8 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
// We need to resolve the full transaction here because outputs are calculated from inputs
// We also can't do filtering beforehand, since output encumbrance pointers get recalculated based on
// input positions
val ltx = tx.resolve(services, emptyList())
val myKeys = services.keyManagementService.filterMyKeys(ltx.outputs.flatMap { it.data.participants.map { it.owningKey } })
val ltx = tx.resolve(stateLoader, emptyList())
val myKeys = keyManagementService.filterMyKeys(ltx.outputs.flatMap { it.data.participants.map { it.owningKey } })
val (consumedStateAndRefs, producedStates) = ltx.inputs.
zip(ltx.outputs).
filter { (_, output) -> isRelevant(output.data, myKeys.toSet()) }.
@ -246,7 +239,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
@Throws(StatesNotAvailableException::class)
override fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet<StateRef>) {
val softLockTimestamp = services.clock.instant()
val softLockTimestamp = clock.instant()
try {
val session = DatabaseTransactionManager.current().session
val criteriaBuilder = session.criteriaBuilder
@ -292,7 +285,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
}
override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet<StateRef>?) {
val softLockTimestamp = services.clock.instant()
val softLockTimestamp = clock.instant()
val session = DatabaseTransactionManager.current().session
val criteriaBuilder = session.criteriaBuilder
if (stateRefs == null) {
@ -440,8 +433,8 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
// pagination checks
if (!paging.isDefault) {
// pagination
if (paging.pageNumber < DEFAULT_PAGE_NUM) throw VaultQueryException("Page specification: invalid page number ${paging.pageNumber} [page numbers start from ${DEFAULT_PAGE_NUM}]")
if (paging.pageSize < 1) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [must be a value between 1 and ${MAX_PAGE_SIZE}]")
if (paging.pageNumber < DEFAULT_PAGE_NUM) throw VaultQueryException("Page specification: invalid page number ${paging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]")
if (paging.pageSize < 1) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [must be a value between 1 and $MAX_PAGE_SIZE]")
}
query.firstResult = (paging.pageNumber - 1) * paging.pageSize
@ -452,7 +445,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
// final pagination check (fail-fast on too many results when no pagination specified)
if (paging.isDefault && results.size > DEFAULT_PAGE_SIZE)
throw VaultQueryException("Please specify a `PageSpecification` as there are more results [${results.size}] than the default page size [${DEFAULT_PAGE_SIZE}]")
throw VaultQueryException("Please specify a `PageSpecification` as there are more results [${results.size}] than the default page size [$DEFAULT_PAGE_SIZE]")
val statesAndRefs: MutableList<StateAndRef<T>> = mutableListOf()
val statesMeta: MutableList<Vault.StateMetadata> = mutableListOf()
@ -495,8 +488,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
override fun <T : ContractState> _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): DataFeed<Vault.Page<T>, Vault.Update<T>> {
return mutex.locked {
val snapshotResults = _queryBy(criteria, paging, sorting, contractStateType)
@Suppress("UNCHECKED_CAST")
val updates = _updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractStateType, snapshotResults.stateTypes) } as Observable<Vault.Update<T>>
val updates: Observable<Vault.Update<T>> = uncheckedCast(_updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractStateType, snapshotResults.stateTypes) })
DataFeed(snapshotResults, updates)
}
}
@ -522,8 +514,7 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
val contractInterfaceToConcreteTypes = mutableMapOf<String, MutableSet<String>>()
distinctTypes.forEach { type ->
@Suppress("UNCHECKED_CAST")
val concreteType = Class.forName(type) as Class<ContractState>
val concreteType: Class<ContractState> = uncheckedCast(Class.forName(type))
val contractInterfaces = deriveContractInterfaces(concreteType)
contractInterfaces.map {
val contractInterface = contractInterfaceToConcreteTypes.getOrPut(it.name, { mutableSetOf() })
@ -537,10 +528,9 @@ class NodeVaultService(private val services: ServiceHub, private val hibernateCo
private fun <T : ContractState> deriveContractInterfaces(clazz: Class<T>): Set<Class<T>> {
val myInterfaces: MutableSet<Class<T>> = mutableSetOf()
clazz.interfaces.forEach {
if (!it.equals(ContractState::class.java)) {
@Suppress("UNCHECKED_CAST")
myInterfaces.add(it as Class<T>)
myInterfaces.addAll(deriveContractInterfaces(it))
if (it != ContractState::class.java) {
myInterfaces.add(uncheckedCast(it))
myInterfaces.addAll(deriveContractInterfaces(uncheckedCast(it)))
}
}
return myInterfaces

View File

@ -9,12 +9,12 @@ import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.VaultService
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.days
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.identity.InMemoryIdentityService
import net.corda.node.services.persistence.DBCheckpointStorage
import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl
@ -96,7 +96,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
overrideClock = testClock,
keyManagement = kms,
network = mockMessagingService), TestReference {
override val vaultService: VaultService = NodeVaultService(this, database.hibernateConfig)
override val vaultService: VaultServiceInternal = NodeVaultService(testClock, kms, stateLoader, database.hibernateConfig)
override val testReference = this@NodeSchedulerServiceTest
override val cordappProvider: CordappProviderImpl = CordappProviderImpl(CordappLoader.createWithTestPackages()).start(attachments)
}

View File

@ -6,18 +6,14 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.crypto.TransactionSignature
import net.corda.core.node.services.VaultService
import net.corda.core.schemas.MappedSchema
import net.corda.core.toFuture
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.finance.schemas.CashSchemaV1
import net.corda.finance.schemas.SampleCashSchemaV2
import net.corda.finance.schemas.SampleCashSchemaV3
import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.schema.HibernateObserver
import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase
import net.corda.testing.*
@ -46,8 +42,8 @@ class DBTransactionStorageTests : TestDependencyInjectionBase() {
database.transaction {
services = object : MockServices(BOB_KEY) {
override val vaultService: VaultService get() {
val vaultService = NodeVaultService(this, database.hibernateConfig)
override val vaultService: VaultServiceInternal get() {
val vaultService = NodeVaultService(clock, keyManagementService, stateLoader, database.hibernateConfig)
hibernatePersister = HibernateObserver(vaultService.rawUpdates, database.hibernateConfig)
return vaultService
}
@ -57,7 +53,7 @@ class DBTransactionStorageTests : TestDependencyInjectionBase() {
validatedTransactions.addTransaction(stx)
}
// Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions.
(vaultService as NodeVaultService).notifyAll(txs.map { it.tx })
vaultService.notifyAll(txs.map { it.tx })
}
}
}

View File

@ -26,7 +26,6 @@ import net.corda.finance.schemas.SampleCashSchemaV3
import net.corda.finance.utils.sumCash
import net.corda.node.services.schema.HibernateObserver
import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase
@ -82,14 +81,13 @@ class HibernateConfigurationTest : TestDependencyInjectionBase() {
database.transaction {
hibernateConfig = database.hibernateConfig
services = object : MockServices(BOB_KEY, BOC_KEY, DUMMY_NOTARY_KEY) {
override val vaultService: VaultService = makeVaultService(database.hibernateConfig)
override val vaultService = makeVaultService(database.hibernateConfig)
override fun recordTransactions(notifyVault: Boolean, txs: Iterable<SignedTransaction>) {
for (stx in txs) {
validatedTransactions.addTransaction(stx)
}
// Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions.
(vaultService as NodeVaultService).notifyAll(txs.map { it.tx })
vaultService.notifyAll(txs.map { it.tx })
}
override fun jdbcSession() = database.createSession()
}

View File

@ -47,7 +47,7 @@ import kotlin.test.assertTrue
class NodeVaultServiceTest : TestDependencyInjectionBase() {
lateinit var services: MockServices
lateinit var issuerServices: MockServices
val vaultService: VaultService get() = services.vaultService
val vaultService get() = services.vaultService as NodeVaultService
lateinit var database: CordaPersistence
@Before
@ -98,7 +98,7 @@ class NodeVaultServiceTest : TestDependencyInjectionBase() {
val originalVault = vaultService
val services2 = object : MockServices() {
override val vaultService: NodeVaultService get() = originalVault as NodeVaultService
override val vaultService: NodeVaultService get() = originalVault
override fun recordTransactions(notifyVault: Boolean, txs: Iterable<SignedTransaction>) {
for (stx in txs) {
validatedTransactions.addTransaction(stx)
@ -473,7 +473,7 @@ class NodeVaultServiceTest : TestDependencyInjectionBase() {
@Test
fun `is ownable state relevant`() {
val service = (services.vaultService as NodeVaultService)
val service = vaultService
val amount = Amount(1000, Issued(BOC.ref(1), GBP))
val wellKnownCash = Cash.State(amount, services.myInfo.chooseIdentity())
val myKeys = services.keyManagementService.filterMyKeys(listOf(wellKnownCash.owner.owningKey))
@ -494,7 +494,7 @@ class NodeVaultServiceTest : TestDependencyInjectionBase() {
@Test
fun `correct updates are generated for general transactions`() {
val service = (services.vaultService as NodeVaultService)
val service = vaultService
val vaultSubscriber = TestSubscriber<Vault.Update<*>>().apply {
service.updates.subscribe(this)
}
@ -527,7 +527,7 @@ class NodeVaultServiceTest : TestDependencyInjectionBase() {
@Test
fun `correct updates are generated when changing notaries`() {
val service = (services.vaultService as NodeVaultService)
val service = vaultService
val notary = services.myInfo.chooseIdentity()
val vaultSubscriber = TestSubscriber<Vault.Update<*>>().apply {

View File

@ -6,9 +6,11 @@ import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.core.node.NodeInfo
import net.corda.core.node.StateLoader
import net.corda.core.node.services.*
import net.corda.core.serialization.SerializeAsToken
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.internal.StateLoaderImpl
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.node.serialization.NodeClock
@ -34,7 +36,7 @@ import java.time.Clock
open class MockServiceHubInternal(
override val database: CordaPersistence,
override val configuration: NodeConfiguration,
val customVault: VaultService? = null,
val customVault: VaultServiceInternal? = null,
val keyManagement: KeyManagementService? = null,
val network: MessagingService? = null,
val identity: IdentityService? = MOCK_IDENTITY_SERVICE,
@ -47,11 +49,12 @@ open class MockServiceHubInternal(
val schemas: SchemaService? = NodeSchemaService(),
val customContractUpgradeService: ContractUpgradeService? = null,
val customTransactionVerifierService: TransactionVerifierService? = InMemoryTransactionVerifierService(2),
override val cordappProvider: CordappProvider = CordappProviderImpl(CordappLoader.createDefault(Paths.get("."))).start(attachments)
) : ServiceHubInternal {
override val cordappProvider: CordappProvider = CordappProviderImpl(CordappLoader.createDefault(Paths.get("."))).start(attachments),
protected val stateLoader: StateLoaderImpl = StateLoaderImpl(validatedTransactions)
) : ServiceHubInternal, StateLoader by stateLoader {
override val transactionVerifierService: TransactionVerifierService
get() = customTransactionVerifierService ?: throw UnsupportedOperationException()
override val vaultService: VaultService
override val vaultService: VaultServiceInternal
get() = customVault ?: throw UnsupportedOperationException()
override val contractUpgradeService: ContractUpgradeService
get() = customContractUpgradeService ?: throw UnsupportedOperationException()

View File

@ -7,14 +7,17 @@ import net.corda.core.identity.PartyAndCertificate
import net.corda.core.messaging.DataFeed
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.node.StateLoader
import net.corda.core.node.services.*
import net.corda.core.schemas.MappedSchema
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.SignedTransaction
import net.corda.node.VersionInfo
import net.corda.node.internal.StateLoaderImpl
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage
import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.services.identity.InMemoryIdentityService
import net.corda.node.services.keys.freshCertificate
@ -42,7 +45,12 @@ import java.util.*
* A singleton utility that only provides a mock identity, key and storage service. However, this is sufficient for
* building chains of transactions and verifying them. It isn't sufficient for testing flows however.
*/
open class MockServices(cordappPackages: List<String> = emptyList(), vararg val keys: KeyPair) : ServiceHub {
open class MockServices(
cordappPackages: List<String>,
override val validatedTransactions: WritableTransactionStorage,
protected val stateLoader: StateLoaderImpl = StateLoaderImpl(validatedTransactions),
vararg val keys: KeyPair
) : ServiceHub, StateLoader by stateLoader {
companion object {
@JvmStatic
@ -105,14 +113,14 @@ open class MockServices(cordappPackages: List<String> = emptyList(), vararg val
val mockService = database.transaction {
object : MockServices(cordappPackages, *(keys.toTypedArray())) {
override val identityService: IdentityService = database.transaction { identityServiceRef }
override val vaultService: VaultService = makeVaultService(database.hibernateConfig)
override val vaultService = makeVaultService(database.hibernateConfig)
override fun recordTransactions(notifyVault: Boolean, txs: Iterable<SignedTransaction>) {
for (stx in txs) {
validatedTransactions.addTransaction(stx)
}
// Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions.
(vaultService as NodeVaultService).notifyAll(txs.map { it.tx })
vaultService.notifyAll(txs.map { it.tx })
}
override fun jdbcSession(): Connection = database.createSession()
@ -122,9 +130,9 @@ open class MockServices(cordappPackages: List<String> = emptyList(), vararg val
}
}
constructor(cordappPackages: List<String>, vararg keys: KeyPair) : this(cordappPackages, MockTransactionStorage(), keys = *keys)
constructor(vararg keys: KeyPair) : this(emptyList(), *keys)
constructor() : this(emptyList(), generateKeyPair())
constructor() : this(generateKeyPair())
val key: KeyPair get() = keys.first()
@ -137,8 +145,7 @@ open class MockServices(cordappPackages: List<String> = emptyList(), vararg val
}
}
final override val attachments: AttachmentStorage = MockAttachmentStorage()
override val validatedTransactions: WritableTransactionStorage = MockTransactionStorage()
final override val attachments = MockAttachmentStorage()
val stateMachineRecordedTransactionMapping: StateMachineRecordedTransactionMappingStorage = MockStateMachineRecordedTransactionMappingStorage()
override val identityService: IdentityService = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DEV_TRUST_ROOT)
override val keyManagementService: KeyManagementService by lazy { MockKeyManagementService(identityService, *keys) }
@ -157,8 +164,8 @@ open class MockServices(cordappPackages: List<String> = emptyList(), vararg val
lateinit var hibernatePersister: HibernateObserver
fun makeVaultService(hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService(), makeTestDatabaseProperties(), { identityService })): VaultService {
val vaultService = NodeVaultService(this, hibernateConfig)
fun makeVaultService(hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService(), makeTestDatabaseProperties(), { identityService })): VaultServiceInternal {
val vaultService = NodeVaultService(Clock.systemUTC(), keyManagementService, stateLoader, hibernateConfig)
hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig)
return vaultService
}

View File

@ -8,10 +8,12 @@ import net.corda.core.flows.FlowException
import net.corda.core.identity.Party
import net.corda.core.internal.uncheckedCast
import net.corda.core.node.ServiceHub
import net.corda.core.node.ServicesForResolution
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.transactions.WireTransaction
import net.corda.testing.contracts.DummyContract
import net.corda.testing.node.MockAttachmentStorage
import net.corda.testing.node.MockCordappProvider
import java.io.InputStream
import java.security.KeyPair
@ -72,7 +74,7 @@ data class TestTransactionDSLInterpreter private constructor(
transactionBuilder: TransactionBuilder
) : this(ledgerInterpreter, transactionBuilder, HashMap())
val services = object : ServiceHub by ledgerInterpreter.services {
val services = object : ServicesForResolution by ledgerInterpreter.services {
override fun loadState(stateRef: StateRef) = ledgerInterpreter.resolveStateRef<ContractState>(stateRef)
override val cordappProvider: CordappProvider = ledgerInterpreter.services.cordappProvider
}
@ -136,7 +138,7 @@ data class TestTransactionDSLInterpreter private constructor(
) = dsl(TransactionDSL(copy()))
override fun _attachment(contractClassName: ContractClassName) {
(services.cordappProvider as MockCordappProvider).addMockCordapp(contractClassName, services)
(services.cordappProvider as MockCordappProvider).addMockCordapp(contractClassName, services.attachments as MockAttachmentStorage)
}
}

View File

@ -3,7 +3,6 @@ package net.corda.testing.node
import net.corda.core.contracts.ContractClassName
import net.corda.core.cordapp.Cordapp
import net.corda.core.internal.cordapp.CordappImpl
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.AttachmentId
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl
@ -13,27 +12,23 @@ import java.util.*
class MockCordappProvider(cordappLoader: CordappLoader) : CordappProviderImpl(cordappLoader) {
val cordappRegistry = mutableListOf<Pair<Cordapp, AttachmentId>>()
fun addMockCordapp(contractClassName: ContractClassName, services: ServiceHub) {
fun addMockCordapp(contractClassName: ContractClassName, attachments: MockAttachmentStorage) {
val cordapp = CordappImpl(listOf(contractClassName), emptyList(), emptyList(), emptyList(), emptyList(), emptyList(), emptyList(), emptySet(), Paths.get(".").toUri().toURL())
if (cordappRegistry.none { it.first.contractClassNames.contains(contractClassName) }) {
cordappRegistry.add(Pair(cordapp, findOrImportAttachment(contractClassName.toByteArray(), services)))
cordappRegistry.add(Pair(cordapp, findOrImportAttachment(contractClassName.toByteArray(), attachments)))
}
}
override fun getContractAttachmentID(contractClassName: ContractClassName): AttachmentId? = cordappRegistry.find { it.first.contractClassNames.contains(contractClassName) }?.second ?: super.getContractAttachmentID(contractClassName)
private fun findOrImportAttachment(data: ByteArray, services: ServiceHub): AttachmentId {
return if (services.attachments is MockAttachmentStorage) {
val existingAttachment = (services.attachments as MockAttachmentStorage).files.filter {
Arrays.equals(it.value, data)
}
if (!existingAttachment.isEmpty()) {
existingAttachment.keys.first()
} else {
services.attachments.importAttachment(data.inputStream())
}
private fun findOrImportAttachment(data: ByteArray, attachments: MockAttachmentStorage): AttachmentId {
val existingAttachment = attachments.files.filter {
Arrays.equals(it.value, data)
}
return if (!existingAttachment.isEmpty()) {
existingAttachment.keys.first()
} else {
throw Exception("MockCordappService only requires MockAttachmentStorage")
attachments.importAttachment(data.inputStream())
}
}
}