From ce5b7de71801a557bb57602e8af2e5ec8ba63959 Mon Sep 17 00:00:00 2001 From: Andrzej Cichocki Date: Fri, 13 Oct 2017 12:15:52 +0100 Subject: [PATCH] CORDA-530 Unduplicate code (#1791) --- .../net/corda/core/flows/AttachmentTests.kt | 3 +- .../AttachmentSerializationTest.kt | 4 +- .../network/PersistentNetworkMapCache.kt | 132 ++++++++---------- .../persistence/DBCheckpointStorage.kt | 9 +- .../persistence/NodeAttachmentService.kt | 8 +- .../node/services/vault/NodeVaultService.kt | 100 ++++++------- .../node/utilities/AppendOnlyPersistentMap.kt | 16 ++- .../utilities/DatabaseTransactionManager.kt | 1 + .../net/corda/node/utilities/PersistentMap.kt | 22 +-- .../persistence/NodeAttachmentStorageTest.kt | 24 ++-- 10 files changed, 152 insertions(+), 167 deletions(-) diff --git a/core/src/test/kotlin/net/corda/core/flows/AttachmentTests.kt b/core/src/test/kotlin/net/corda/core/flows/AttachmentTests.kt index ab4936d6fc..c9974d5dc3 100644 --- a/core/src/test/kotlin/net/corda/core/flows/AttachmentTests.kt +++ b/core/src/test/kotlin/net/corda/core/flows/AttachmentTests.kt @@ -12,7 +12,6 @@ import net.corda.core.utilities.getOrThrow import net.corda.node.internal.StartedNode import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.persistence.NodeAttachmentService -import net.corda.node.utilities.DatabaseTransactionManager import net.corda.nodeapi.internal.ServiceInfo import net.corda.testing.ALICE import net.corda.testing.ALICE_NAME @@ -147,7 +146,7 @@ class AttachmentTests { val corruptAttachment = NodeAttachmentService.DBAttachment(attId = id.toString(), content = attachment) aliceNode.database.transaction { - DatabaseTransactionManager.current().session.update(corruptAttachment) + session.update(corruptAttachment) } // Get n1 to fetch the attachment. Should receive corrupted bytes. diff --git a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt index 89bc81dad9..f6d1eec6b2 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt @@ -16,7 +16,7 @@ import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.internal.StartedNode import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.persistence.NodeAttachmentService -import net.corda.node.utilities.DatabaseTransactionManager +import net.corda.node.utilities.currentDBSession import net.corda.nodeapi.internal.ServiceInfo import net.corda.testing.chooseIdentity import net.corda.testing.node.MockNetwork @@ -54,7 +54,7 @@ private fun StartedNode<*>.hackAttachment(attachmentId: SecureHash, content: Str * @see NodeAttachmentService.importAttachment */ private fun updateAttachment(attachmentId: SecureHash, data: ByteArray) { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val attachment = session.get(NodeAttachmentService.DBAttachment::class.java, attachmentId.toString()) attachment?.let { attachment.content = data diff --git a/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapCache.kt b/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapCache.kt index 1f24e27526..eb45e7a0df 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapCache.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapCache.kt @@ -31,7 +31,6 @@ import net.corda.node.services.messaging.sendRequest import net.corda.node.services.network.NetworkMapService.FetchMapResponse import net.corda.node.services.network.NetworkMapService.SubscribeResponse import net.corda.node.utilities.AddOrRemove -import net.corda.node.utilities.DatabaseTransactionManager import net.corda.node.utilities.bufferUntilDatabaseCommit import net.corda.node.utilities.wrapWithDatabaseTransaction import org.hibernate.Session @@ -39,7 +38,6 @@ import rx.Observable import rx.subjects.PublishSubject import java.security.PublicKey import java.security.SignatureException -import java.time.Duration import java.util.* import javax.annotation.concurrent.ThreadSafe import kotlin.collections.HashMap @@ -93,7 +91,7 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) init { loadFromFiles() - serviceHub.database.transaction { loadFromDB() } + serviceHub.database.transaction { loadFromDB(session) } } private fun loadFromFiles() { @@ -102,7 +100,7 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) } override fun getPartyInfo(party: Party): PartyInfo? { - val nodes = serviceHub.database.transaction { queryByIdentityKey(party.owningKey) } + val nodes = serviceHub.database.transaction { queryByIdentityKey(session, party.owningKey) } if (nodes.size == 1 && nodes[0].isLegalIdentity(party)) { return PartyInfo.SingleNode(party, nodes[0].addresses) } @@ -117,9 +115,9 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) } override fun getNodeByLegalName(name: CordaX500Name): NodeInfo? = getNodesByLegalName(name).firstOrNull() - override fun getNodesByLegalName(name: CordaX500Name): List = serviceHub.database.transaction { queryByLegalName(name) } + override fun getNodesByLegalName(name: CordaX500Name): List = serviceHub.database.transaction { queryByLegalName(session, name) } override fun getNodesByLegalIdentityKey(identityKey: PublicKey): List = - serviceHub.database.transaction { queryByIdentityKey(identityKey) } + serviceHub.database.transaction { queryByIdentityKey(session, identityKey) } override fun getNodeByLegalIdentity(party: AbstractParty): NodeInfo? { val wellKnownParty = serviceHub.identityService.wellKnownPartyFromAnonymous(party) @@ -128,9 +126,9 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) } } - override fun getNodeByAddress(address: NetworkHostAndPort): NodeInfo? = serviceHub.database.transaction { queryByAddress(address) } + override fun getNodeByAddress(address: NetworkHostAndPort): NodeInfo? = serviceHub.database.transaction { queryByAddress(session, address) } - override fun getPeerCertificateByLegalName(name: CordaX500Name): PartyAndCertificate? = serviceHub.database.transaction { queryIdentityByLegalName(name) } + override fun getPeerCertificateByLegalName(name: CordaX500Name): PartyAndCertificate? = serviceHub.database.transaction { queryIdentityByLegalName(session, name) } override fun track(): DataFeed, MapChange> { synchronized(_changed) { @@ -204,7 +202,7 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) synchronized(_changed) { registeredNodes.remove(node.legalIdentities.first().owningKey) serviceHub.database.transaction { - removeInfoDB(node) + removeInfoDB(session, node) changePublisher.onNext(MapChange.Removed(node)) } } @@ -238,10 +236,8 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) } override val allNodes: List - get () = serviceHub.database.transaction { - createSession { - getAllInfos(it).map { it.toNodeInfo() } - } + get() = serviceHub.database.transaction { + getAllInfos(session).map { it.toNodeInfo() } } private fun processRegistration(reg: NodeRegistration) { @@ -259,10 +255,6 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) // Changes related to NetworkMap redesign // TODO It will be properly merged into network map cache after services removal. - private inline fun createSession(block: (Session) -> T): T { - return DatabaseTransactionManager.current().session.let { block(it) } - } - private fun getAllInfos(session: Session): List { val criteria = session.criteriaBuilder.createQuery(NodeInfoSchemaV1.PersistentNodeInfo::class.java) criteria.select(criteria.from(NodeInfoSchemaV1.PersistentNodeInfo::class.java)) @@ -272,24 +264,22 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) /** * Load NetworkMap data from the database if present. Node can start without having NetworkMapService configured. */ - private fun loadFromDB() { + private fun loadFromDB(session: Session) { logger.info("Loading network map from database...") - createSession { - val result = getAllInfos(it) - for (nodeInfo in result) { - try { - logger.info("Loaded node info: $nodeInfo") - val node = nodeInfo.toNodeInfo() - addNode(node) - _loadDBSuccess = true // This is used in AbstractNode to indicate that node is ready. - } catch (e: Exception) { - logger.warn("Exception parsing network map from the database.", e) - } - } - if (loadDBSuccess) { - _registrationFuture.set(null) // Useful only if we don't have NetworkMapService configured so StateMachineManager can start. + val result = getAllInfos(session) + for (nodeInfo in result) { + try { + logger.info("Loaded node info: $nodeInfo") + val node = nodeInfo.toNodeInfo() + addNode(node) + _loadDBSuccess = true // This is used in AbstractNode to indicate that node is ready. + } catch (e: Exception) { + logger.warn("Exception parsing network map from the database.", e) } } + if (loadDBSuccess) { + _registrationFuture.set(null) // Useful only if we don't have NetworkMapService configured so StateMachineManager can start. + } } private fun updateInfoDB(nodeInfo: NodeInfo) { @@ -313,11 +303,9 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) } } - private fun removeInfoDB(nodeInfo: NodeInfo) { - createSession { - val info = findByIdentityKey(it, nodeInfo.legalIdentitiesAndCerts.first().owningKey).single() - it.remove(info) - } + private fun removeInfoDB(session: Session, nodeInfo: NodeInfo) { + val info = findByIdentityKey(session, nodeInfo.legalIdentitiesAndCerts.first().owningKey).single() + session.remove(info) } private fun findByIdentityKey(session: Session, identityKey: PublicKey): List { @@ -328,48 +316,40 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) return query.resultList } - private fun queryByIdentityKey(identityKey: PublicKey): List { - createSession { - val result = findByIdentityKey(it, identityKey) - return result.map { it.toNodeInfo() } - } + private fun queryByIdentityKey(session: Session, identityKey: PublicKey): List { + val result = findByIdentityKey(session, identityKey) + return result.map { it.toNodeInfo() } } - private fun queryIdentityByLegalName(name: CordaX500Name): PartyAndCertificate? { - createSession { - val query = it.createQuery( - // We do the JOIN here to restrict results to those present in the network map - "SELECT DISTINCT l FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.legalIdentitiesAndCerts l WHERE l.name = :name", - NodeInfoSchemaV1.DBPartyAndCertificate::class.java) - query.setParameter("name", name.toString()) - val candidates = query.resultList.map { it.toLegalIdentityAndCert() } - // The map is restricted to holding a single identity for any X.500 name, so firstOrNull() is correct here. - return candidates.firstOrNull() - } + private fun queryIdentityByLegalName(session: Session, name: CordaX500Name): PartyAndCertificate? { + val query = session.createQuery( + // We do the JOIN here to restrict results to those present in the network map + "SELECT DISTINCT l FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.legalIdentitiesAndCerts l WHERE l.name = :name", + NodeInfoSchemaV1.DBPartyAndCertificate::class.java) + query.setParameter("name", name.toString()) + val candidates = query.resultList.map { it.toLegalIdentityAndCert() } + // The map is restricted to holding a single identity for any X.500 name, so firstOrNull() is correct here. + return candidates.firstOrNull() } - private fun queryByLegalName(name: CordaX500Name): List { - createSession { - val query = it.createQuery( - "SELECT n FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.legalIdentitiesAndCerts l WHERE l.name = :name", - NodeInfoSchemaV1.PersistentNodeInfo::class.java) - query.setParameter("name", name.toString()) - val result = query.resultList - return result.map { it.toNodeInfo() } - } + private fun queryByLegalName(session: Session, name: CordaX500Name): List { + val query = session.createQuery( + "SELECT n FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.legalIdentitiesAndCerts l WHERE l.name = :name", + NodeInfoSchemaV1.PersistentNodeInfo::class.java) + query.setParameter("name", name.toString()) + val result = query.resultList + return result.map { it.toNodeInfo() } } - private fun queryByAddress(hostAndPort: NetworkHostAndPort): NodeInfo? { - createSession { - val query = it.createQuery( - "SELECT n FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.addresses a WHERE a.pk.host = :host AND a.pk.port = :port", - NodeInfoSchemaV1.PersistentNodeInfo::class.java) - query.setParameter("host", hostAndPort.host) - query.setParameter("port", hostAndPort.port) - val result = query.resultList - return if (result.isEmpty()) null - else result.map { it.toNodeInfo() }.singleOrNull() ?: throw IllegalStateException("More than one node with the same host and port") - } + private fun queryByAddress(session: Session, hostAndPort: NetworkHostAndPort): NodeInfo? { + val query = session.createQuery( + "SELECT n FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.addresses a WHERE a.pk.host = :host AND a.pk.port = :port", + NodeInfoSchemaV1.PersistentNodeInfo::class.java) + query.setParameter("host", hostAndPort.host) + query.setParameter("port", hostAndPort.port) + val result = query.resultList + return if (result.isEmpty()) null + else result.map { it.toNodeInfo() }.singleOrNull() ?: throw IllegalStateException("More than one node with the same host and port") } /** Object Relational Mapping support. */ @@ -388,10 +368,8 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal) override fun clearNetworkMapCache() { serviceHub.database.transaction { - createSession { - val result = getAllInfos(it) - for (nodeInfo in result) it.remove(nodeInfo) - } + val result = getAllInfos(session) + for (nodeInfo in result) session.remove(nodeInfo) } } } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt index a82642adf7..b9b5f0bbdc 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt @@ -3,8 +3,8 @@ package net.corda.node.services.persistence import net.corda.core.serialization.SerializedBytes import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage -import net.corda.node.utilities.DatabaseTransactionManager import net.corda.node.utilities.NODE_DATABASE_PREFIX +import net.corda.node.utilities.currentDBSession import javax.persistence.Column import javax.persistence.Entity import javax.persistence.Id @@ -28,15 +28,14 @@ class DBCheckpointStorage : CheckpointStorage { ) override fun addCheckpoint(checkpoint: Checkpoint) { - val session = DatabaseTransactionManager.current().session - session.save(DBCheckpoint().apply { + currentDBSession().save(DBCheckpoint().apply { checkpointId = checkpoint.id.toString() this.checkpoint = checkpoint.serializedFiber.bytes }) } override fun removeCheckpoint(checkpoint: Checkpoint) { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java) val root = delete.from(DBCheckpoint::class.java) @@ -45,7 +44,7 @@ class DBCheckpointStorage : CheckpointStorage { } override fun forEach(block: (Checkpoint) -> Boolean) { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java) val root = criteriaQuery.from(DBCheckpoint::class.java) criteriaQuery.select(root) diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt index c864521d7b..c95fe2bf94 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt @@ -13,8 +13,8 @@ import net.corda.core.crypto.SecureHash import net.corda.core.node.services.AttachmentStorage import net.corda.core.serialization.* import net.corda.core.utilities.loggerFor -import net.corda.node.utilities.DatabaseTransactionManager import net.corda.node.utilities.NODE_DATABASE_PREFIX +import net.corda.node.utilities.currentDBSession import java.io.* import java.nio.file.Paths import java.util.jar.JarInputStream @@ -50,7 +50,7 @@ class NodeAttachmentService(metrics: MetricRegistry) : AttachmentStorage, Single private val attachmentCount = metrics.counter("Attachments") init { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder val criteriaQuery = criteriaBuilder.createQuery(Long::class.java) criteriaQuery.select(criteriaBuilder.count(criteriaQuery.from(NodeAttachmentService.DBAttachment::class.java))) @@ -140,7 +140,7 @@ class NodeAttachmentService(metrics: MetricRegistry) : AttachmentStorage, Single } override fun openAttachment(id: SecureHash): Attachment? { - val attachment = DatabaseTransactionManager.current().session.get(NodeAttachmentService.DBAttachment::class.java, id.toString()) + val attachment = currentDBSession().get(NodeAttachmentService.DBAttachment::class.java, id.toString()) attachment?.let { return AttachmentImpl(id, { attachment.content }, checkAttachmentsOnLoad) } @@ -161,7 +161,7 @@ class NodeAttachmentService(metrics: MetricRegistry) : AttachmentStorage, Single checkIsAValidJAR(ByteArrayInputStream(bytes)) val id = SecureHash.SHA256(hs.hash().asBytes()) - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder val criteriaQuery = criteriaBuilder.createQuery(Long::class.java) val attachments = criteriaQuery.from(NodeAttachmentService.DBAttachment::class.java) diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 3b5236d97a..864da2bd50 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -29,6 +29,7 @@ import net.corda.node.services.persistence.HibernateConfiguration import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.utilities.DatabaseTransactionManager import net.corda.node.utilities.bufferUntilDatabaseCommit +import net.corda.node.utilities.currentDBSession import net.corda.node.utilities.wrapWithDatabaseTransaction import org.hibernate.Session import rx.Observable @@ -38,6 +39,15 @@ import java.time.Clock import java.time.Instant import java.util.* import javax.persistence.Tuple +import javax.persistence.criteria.CriteriaBuilder +import javax.persistence.criteria.CriteriaUpdate +import javax.persistence.criteria.Predicate +import javax.persistence.criteria.Root + +private fun CriteriaBuilder.executeUpdate(session: Session, configure: Root<*>.(CriteriaUpdate<*>) -> Any?) = createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java).let { update -> + update.from(VaultSchemaV1.VaultStates::class.java).run { configure(update) } + session.createQuery(update).executeUpdate() +} /** * Currently, the node vault service is a very simple RDBMS backed implementation. It will change significantly when @@ -73,7 +83,7 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic val consumedStateRefs = update.consumed.map { it.ref } log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." } - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() producedStateRefsMap.forEach { stateAndRef -> val state = VaultSchemaV1.VaultStates( notary = stateAndRef.value.state.notary, @@ -189,7 +199,7 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic private fun loadStates(refs: Collection): HashSet> { val states = HashSet>() if (refs.isNotEmpty()) { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder val criteriaQuery = criteriaBuilder.createQuery(VaultSchemaV1.VaultStates::class.java) val vaultStates = criteriaQuery.from(VaultSchemaV1.VaultStates::class.java) @@ -223,11 +233,11 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic override fun addNoteToTransaction(txnId: SecureHash, noteText: String) { val txnNoteEntity = VaultSchemaV1.VaultTxnNote(txnId.toString(), noteText) - DatabaseTransactionManager.current().session.save(txnNoteEntity) + currentDBSession().save(txnNoteEntity) } override fun getTransactionNotes(txnId: SecureHash): Iterable { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder val criteriaQuery = criteriaBuilder.createQuery(VaultSchemaV1.VaultTxnNote::class.java) val vaultStates = criteriaQuery.from(VaultSchemaV1.VaultTxnNote::class.java) @@ -241,35 +251,34 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic override fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet) { val softLockTimestamp = clock.instant() try { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder - val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java) - val vaultStates = criteriaUpdate.from(VaultSchemaV1.VaultStates::class.java) - val stateStatusPredication = criteriaBuilder.equal(vaultStates.get(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED) - val lockIdPredicate = criteriaBuilder.or(vaultStates.get(VaultSchemaV1.VaultStates::lockId.name).isNull, - criteriaBuilder.equal(vaultStates.get(VaultSchemaV1.VaultStates::lockId.name), lockId.toString())) - val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) } - val compositeKey = vaultStates.get(VaultSchemaV1.VaultStates::stateRef.name) - val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs)) - criteriaUpdate.set(vaultStates.get(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()) - criteriaUpdate.set(vaultStates.get(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp) - criteriaUpdate.where(stateStatusPredication, lockIdPredicate, stateRefsPredicate) - val updatedRows = session.createQuery(criteriaUpdate).executeUpdate() + fun execute(configure: Root<*>.(CriteriaUpdate<*>, Array) -> Any?) = criteriaBuilder.executeUpdate(session) { update -> + val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) } + val compositeKey = get(VaultSchemaV1.VaultStates::stateRef.name) + val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs)) + configure(update, arrayOf(stateRefsPredicate)) + } + + val updatedRows = execute { update, commonPredicates -> + val stateStatusPredication = criteriaBuilder.equal(get(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED) + val lockIdPredicate = criteriaBuilder.or(get(VaultSchemaV1.VaultStates::lockId.name).isNull, + criteriaBuilder.equal(get(VaultSchemaV1.VaultStates::lockId.name), lockId.toString())) + update.set(get(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()) + update.set(get(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp) + update.where(stateStatusPredication, lockIdPredicate, *commonPredicates) + } if (updatedRows > 0 && updatedRows == stateRefs.size) { log.trace("Reserving soft lock states for $lockId: $stateRefs") FlowStateMachineImpl.currentStateMachine()?.hasSoftLockedStates = true } else { // revert partial soft locks - val criteriaRevertUpdate = criteriaBuilder.createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java) - val vaultStatesRevert = criteriaRevertUpdate.from(VaultSchemaV1.VaultStates::class.java) - val lockIdPredicateRevert = criteriaBuilder.equal(vaultStatesRevert.get(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()) - val lockUpdateTime = criteriaBuilder.equal(vaultStatesRevert.get(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp) - val persistentStateRefsRevert = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) } - val compositeKeyRevert = vaultStatesRevert.get(VaultSchemaV1.VaultStates::stateRef.name) - val stateRefsPredicateRevert = criteriaBuilder.and(compositeKeyRevert.`in`(persistentStateRefsRevert)) - criteriaRevertUpdate.set(vaultStatesRevert.get(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java)) - criteriaRevertUpdate.where(lockUpdateTime, lockIdPredicateRevert, stateRefsPredicateRevert) - val revertUpdatedRows = session.createQuery(criteriaRevertUpdate).executeUpdate() + val revertUpdatedRows = execute { update, commonPredicates -> + val lockIdPredicate = criteriaBuilder.equal(get(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()) + val lockUpdateTime = criteriaBuilder.equal(get(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp) + update.set(get(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java)) + update.where(lockUpdateTime, lockIdPredicate, *commonPredicates) + } if (revertUpdatedRows > 0) { log.trace("Reverting $revertUpdatedRows partially soft locked states for $lockId") } @@ -286,33 +295,30 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet?) { val softLockTimestamp = clock.instant() - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder + fun execute(configure: Root<*>.(CriteriaUpdate<*>, Array) -> Any?) = criteriaBuilder.executeUpdate(session) { update -> + val stateStatusPredication = criteriaBuilder.equal(get(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED) + val lockIdPredicate = criteriaBuilder.equal(get(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()) + update.set(get(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java)) + update.set(get(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp) + configure(update, arrayOf(stateStatusPredication, lockIdPredicate)) + } if (stateRefs == null) { - val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java) - val vaultStates = criteriaUpdate.from(VaultSchemaV1.VaultStates::class.java) - val stateStatusPredication = criteriaBuilder.equal(vaultStates.get(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED) - val lockIdPredicate = criteriaBuilder.equal(vaultStates.get(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()) - criteriaUpdate.set(vaultStates.get(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java)) - criteriaUpdate.set(vaultStates.get(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp) - criteriaUpdate.where(stateStatusPredication, lockIdPredicate) - val update = session.createQuery(criteriaUpdate).executeUpdate() + val update = execute { update, commonPredicates -> + update.where(*commonPredicates) + } if (update > 0) { log.trace("Releasing $update soft locked states for $lockId") } } else { try { - val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java) - val vaultStates = criteriaUpdate.from(VaultSchemaV1.VaultStates::class.java) - val stateStatusPredication = criteriaBuilder.equal(vaultStates.get(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED) - val lockIdPredicate = criteriaBuilder.equal(vaultStates.get(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()) - val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) } - val compositeKey = vaultStates.get(VaultSchemaV1.VaultStates::stateRef.name) - val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs)) - criteriaUpdate.set(vaultStates.get(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java)) - criteriaUpdate.set(vaultStates.get(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp) - criteriaUpdate.where(stateStatusPredication, lockIdPredicate, stateRefsPredicate) - val updatedRows = session.createQuery(criteriaUpdate).executeUpdate() + val updatedRows = execute { update, commonPredicates -> + val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) } + val compositeKey = get(VaultSchemaV1.VaultStates::stateRef.name) + val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs)) + update.where(*commonPredicates, stateRefsPredicate) + } if (updatedRows > 0) { log.trace("Releasing $updatedRows soft locked states for $lockId and stateRefs $stateRefs") } diff --git a/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt index 62508f0e2b..a90f77d171 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt @@ -40,10 +40,11 @@ class AppendOnlyPersistentMap( * Returns all key/value pairs from the underlying storage. */ fun allPersisted(): Sequence> { - val criteriaQuery = DatabaseTransactionManager.current().session.criteriaBuilder.createQuery(persistentEntityClass) + val session = currentDBSession() + val criteriaQuery = session.criteriaBuilder.createQuery(persistentEntityClass) val root = criteriaQuery.from(persistentEntityClass) criteriaQuery.select(root) - val query = DatabaseTransactionManager.current().session.createQuery(criteriaQuery) + val query = session.createQuery(criteriaQuery) val result = query.resultList return result.map { x -> fromPersistentEntity(x) }.asSequence() } @@ -87,7 +88,7 @@ class AppendOnlyPersistentMap( */ operator fun set(key: K, value: V) = set(key, value, logWarning = false) { k, v -> - DatabaseTransactionManager.current().session.save(toPersistentEntity(k, v)) + currentDBSession().save(toPersistentEntity(k, v)) null } @@ -98,9 +99,10 @@ class AppendOnlyPersistentMap( */ fun addWithDuplicatesAllowed(key: K, value: V, logWarning: Boolean = true): Boolean = set(key, value, logWarning) { k, v -> - val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(k)) + val session = currentDBSession() + val existingEntry = session.find(persistentEntityClass, toPersistentEntityKey(k)) if (existingEntry == null) { - DatabaseTransactionManager.current().session.save(toPersistentEntity(k, v)) + session.save(toPersistentEntity(k, v)) null } else { fromPersistentEntity(existingEntry).second @@ -114,7 +116,7 @@ class AppendOnlyPersistentMap( } private fun loadValue(key: K): V? { - val result = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) + val result = currentDBSession().find(persistentEntityClass, toPersistentEntityKey(key)) return result?.let(fromPersistentEntity)?.second } @@ -125,7 +127,7 @@ class AppendOnlyPersistentMap( * WARNING!! The method is not thread safe. */ fun clear() { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val deleteQuery = session.criteriaBuilder.createCriteriaDelete(persistentEntityClass) deleteQuery.from(persistentEntityClass) session.createQuery(deleteQuery).executeUpdate() diff --git a/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt b/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt index 9016112f7b..6c810a7005 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt @@ -62,6 +62,7 @@ class DatabaseTransaction(isolation: Int, val threadLocal: ThreadLocal() diff --git a/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt b/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt index c24ce3c229..11ace024ef 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/PersistentMap.kt @@ -28,7 +28,7 @@ class PersistentMap( removalListener = ExplicitRemoval(toPersistentEntityKey, persistentEntityClass) ).apply { //preload to allow all() to take data only from the cache (cache is unbound) - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaQuery = session.criteriaBuilder.createQuery(persistentEntityClass) criteriaQuery.select(criteriaQuery.from(persistentEntityClass)) getAll(session.createQuery(criteriaQuery).resultList.map { e -> fromPersistentEntity(e as E).first }.asIterable()) @@ -38,7 +38,7 @@ class PersistentMap( override fun onRemoval(notification: RemovalNotification?) { when (notification?.cause) { RemovalCause.EXPLICIT -> { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val elem = session.find(persistentEntityClass, toPersistentEntityKey(notification.key)) if (elem != null) { session.remove(elem) @@ -101,7 +101,7 @@ class PersistentMap( set(key, value, logWarning = false, store = { k: K, v: V -> - DatabaseTransactionManager.current().session.save(toPersistentEntity(k, v)) + currentDBSession().save(toPersistentEntity(k, v)) null }, replace = { _: K, _: V -> Unit } @@ -115,9 +115,10 @@ class PersistentMap( fun addWithDuplicatesAllowed(key: K, value: V) = set(key, value, store = { k, v -> - val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(k)) + val session = currentDBSession() + val existingEntry = session.find(persistentEntityClass, toPersistentEntityKey(k)) if (existingEntry == null) { - DatabaseTransactionManager.current().session.save(toPersistentEntity(k, v)) + session.save(toPersistentEntity(k, v)) null } else { fromPersistentEntity(existingEntry).second @@ -145,18 +146,19 @@ class PersistentMap( } private fun merge(key: K, value: V): V? { - val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) + val session = currentDBSession() + val existingEntry = session.find(persistentEntityClass, toPersistentEntityKey(key)) return if (existingEntry != null) { - DatabaseTransactionManager.current().session.merge(toPersistentEntity(key, value)) + session.merge(toPersistentEntity(key, value)) fromPersistentEntity(existingEntry).second } else { - DatabaseTransactionManager.current().session.save(toPersistentEntity(key, value)) + session.save(toPersistentEntity(key, value)) null } } private fun loadValue(key: K): V? { - val result = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) + val result = currentDBSession().find(persistentEntityClass, toPersistentEntityKey(key)) return result?.let(fromPersistentEntity)?.second } @@ -256,7 +258,7 @@ class PersistentMap( } fun load() { - val session = DatabaseTransactionManager.current().session + val session = currentDBSession() val criteriaQuery = session.criteriaBuilder.createQuery(persistentEntityClass) criteriaQuery.select(criteriaQuery.from(persistentEntityClass)) cache.getAll(session.createQuery(criteriaQuery).resultList.map { e -> fromPersistentEntity(e as E).first }.asIterable()) diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt b/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt index 46e954ecde..d3ee944069 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt @@ -11,7 +11,6 @@ import net.corda.core.internal.write import net.corda.core.internal.writeLines import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.utilities.CordaPersistence -import net.corda.node.utilities.DatabaseTransactionManager import net.corda.node.utilities.configureDatabase import net.corda.testing.LogHelper import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties @@ -98,19 +97,18 @@ class NodeAttachmentStorageTest { @Test fun `corrupt entry throws exception`() { val testJar = makeTestJar() - val id = - database.transaction { - val storage = NodeAttachmentService(MetricRegistry()) - val id = testJar.read { storage.importAttachment(it) } + val id = database.transaction { + val storage = NodeAttachmentService(MetricRegistry()) + val id = testJar.read { storage.importAttachment(it) } - // Corrupt the file in the store. - val bytes = testJar.readAll() - val corruptBytes = "arggghhhh".toByteArray() - System.arraycopy(corruptBytes, 0, bytes, 0, corruptBytes.size) - val corruptAttachment = NodeAttachmentService.DBAttachment(attId = id.toString(), content = bytes) - DatabaseTransactionManager.current().session.merge(corruptAttachment) - id - } + // Corrupt the file in the store. + val bytes = testJar.readAll() + val corruptBytes = "arggghhhh".toByteArray() + System.arraycopy(corruptBytes, 0, bytes, 0, corruptBytes.size) + val corruptAttachment = NodeAttachmentService.DBAttachment(attId = id.toString(), content = bytes) + session.merge(corruptAttachment) + id + } database.transaction { val storage = NodeAttachmentService(MetricRegistry()) val e = assertFailsWith {