CORDA-530 Unduplicate code (#1791)

This commit is contained in:
Andrzej Cichocki 2017-10-13 12:15:52 +01:00 committed by GitHub
parent 7b10e92819
commit ce5b7de718
10 changed files with 152 additions and 167 deletions

View File

@ -12,7 +12,6 @@ import net.corda.core.utilities.getOrThrow
import net.corda.node.internal.StartedNode
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.node.utilities.DatabaseTransactionManager
import net.corda.nodeapi.internal.ServiceInfo
import net.corda.testing.ALICE
import net.corda.testing.ALICE_NAME
@ -147,7 +146,7 @@ class AttachmentTests {
val corruptAttachment = NodeAttachmentService.DBAttachment(attId = id.toString(), content = attachment)
aliceNode.database.transaction {
DatabaseTransactionManager.current().session.update(corruptAttachment)
session.update(corruptAttachment)
}
// Get n1 to fetch the attachment. Should receive corrupted bytes.

View File

@ -16,7 +16,7 @@ import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.internal.StartedNode
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.node.utilities.DatabaseTransactionManager
import net.corda.node.utilities.currentDBSession
import net.corda.nodeapi.internal.ServiceInfo
import net.corda.testing.chooseIdentity
import net.corda.testing.node.MockNetwork
@ -54,7 +54,7 @@ private fun StartedNode<*>.hackAttachment(attachmentId: SecureHash, content: Str
* @see NodeAttachmentService.importAttachment
*/
private fun updateAttachment(attachmentId: SecureHash, data: ByteArray) {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val attachment = session.get<NodeAttachmentService.DBAttachment>(NodeAttachmentService.DBAttachment::class.java, attachmentId.toString())
attachment?.let {
attachment.content = data

View File

@ -31,7 +31,6 @@ import net.corda.node.services.messaging.sendRequest
import net.corda.node.services.network.NetworkMapService.FetchMapResponse
import net.corda.node.services.network.NetworkMapService.SubscribeResponse
import net.corda.node.utilities.AddOrRemove
import net.corda.node.utilities.DatabaseTransactionManager
import net.corda.node.utilities.bufferUntilDatabaseCommit
import net.corda.node.utilities.wrapWithDatabaseTransaction
import org.hibernate.Session
@ -39,7 +38,6 @@ import rx.Observable
import rx.subjects.PublishSubject
import java.security.PublicKey
import java.security.SignatureException
import java.time.Duration
import java.util.*
import javax.annotation.concurrent.ThreadSafe
import kotlin.collections.HashMap
@ -93,7 +91,7 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
init {
loadFromFiles()
serviceHub.database.transaction { loadFromDB() }
serviceHub.database.transaction { loadFromDB(session) }
}
private fun loadFromFiles() {
@ -102,7 +100,7 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
}
override fun getPartyInfo(party: Party): PartyInfo? {
val nodes = serviceHub.database.transaction { queryByIdentityKey(party.owningKey) }
val nodes = serviceHub.database.transaction { queryByIdentityKey(session, party.owningKey) }
if (nodes.size == 1 && nodes[0].isLegalIdentity(party)) {
return PartyInfo.SingleNode(party, nodes[0].addresses)
}
@ -117,9 +115,9 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
}
override fun getNodeByLegalName(name: CordaX500Name): NodeInfo? = getNodesByLegalName(name).firstOrNull()
override fun getNodesByLegalName(name: CordaX500Name): List<NodeInfo> = serviceHub.database.transaction { queryByLegalName(name) }
override fun getNodesByLegalName(name: CordaX500Name): List<NodeInfo> = serviceHub.database.transaction { queryByLegalName(session, name) }
override fun getNodesByLegalIdentityKey(identityKey: PublicKey): List<NodeInfo> =
serviceHub.database.transaction { queryByIdentityKey(identityKey) }
serviceHub.database.transaction { queryByIdentityKey(session, identityKey) }
override fun getNodeByLegalIdentity(party: AbstractParty): NodeInfo? {
val wellKnownParty = serviceHub.identityService.wellKnownPartyFromAnonymous(party)
@ -128,9 +126,9 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
}
}
override fun getNodeByAddress(address: NetworkHostAndPort): NodeInfo? = serviceHub.database.transaction { queryByAddress(address) }
override fun getNodeByAddress(address: NetworkHostAndPort): NodeInfo? = serviceHub.database.transaction { queryByAddress(session, address) }
override fun getPeerCertificateByLegalName(name: CordaX500Name): PartyAndCertificate? = serviceHub.database.transaction { queryIdentityByLegalName(name) }
override fun getPeerCertificateByLegalName(name: CordaX500Name): PartyAndCertificate? = serviceHub.database.transaction { queryIdentityByLegalName(session, name) }
override fun track(): DataFeed<List<NodeInfo>, MapChange> {
synchronized(_changed) {
@ -204,7 +202,7 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
synchronized(_changed) {
registeredNodes.remove(node.legalIdentities.first().owningKey)
serviceHub.database.transaction {
removeInfoDB(node)
removeInfoDB(session, node)
changePublisher.onNext(MapChange.Removed(node))
}
}
@ -238,10 +236,8 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
}
override val allNodes: List<NodeInfo>
get () = serviceHub.database.transaction {
createSession {
getAllInfos(it).map { it.toNodeInfo() }
}
get() = serviceHub.database.transaction {
getAllInfos(session).map { it.toNodeInfo() }
}
private fun processRegistration(reg: NodeRegistration) {
@ -259,10 +255,6 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
// Changes related to NetworkMap redesign
// TODO It will be properly merged into network map cache after services removal.
private inline fun <T> createSession(block: (Session) -> T): T {
return DatabaseTransactionManager.current().session.let { block(it) }
}
private fun getAllInfos(session: Session): List<NodeInfoSchemaV1.PersistentNodeInfo> {
val criteria = session.criteriaBuilder.createQuery(NodeInfoSchemaV1.PersistentNodeInfo::class.java)
criteria.select(criteria.from(NodeInfoSchemaV1.PersistentNodeInfo::class.java))
@ -272,24 +264,22 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
/**
* Load NetworkMap data from the database if present. Node can start without having NetworkMapService configured.
*/
private fun loadFromDB() {
private fun loadFromDB(session: Session) {
logger.info("Loading network map from database...")
createSession {
val result = getAllInfos(it)
for (nodeInfo in result) {
try {
logger.info("Loaded node info: $nodeInfo")
val node = nodeInfo.toNodeInfo()
addNode(node)
_loadDBSuccess = true // This is used in AbstractNode to indicate that node is ready.
} catch (e: Exception) {
logger.warn("Exception parsing network map from the database.", e)
}
}
if (loadDBSuccess) {
_registrationFuture.set(null) // Useful only if we don't have NetworkMapService configured so StateMachineManager can start.
val result = getAllInfos(session)
for (nodeInfo in result) {
try {
logger.info("Loaded node info: $nodeInfo")
val node = nodeInfo.toNodeInfo()
addNode(node)
_loadDBSuccess = true // This is used in AbstractNode to indicate that node is ready.
} catch (e: Exception) {
logger.warn("Exception parsing network map from the database.", e)
}
}
if (loadDBSuccess) {
_registrationFuture.set(null) // Useful only if we don't have NetworkMapService configured so StateMachineManager can start.
}
}
private fun updateInfoDB(nodeInfo: NodeInfo) {
@ -313,11 +303,9 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
}
}
private fun removeInfoDB(nodeInfo: NodeInfo) {
createSession {
val info = findByIdentityKey(it, nodeInfo.legalIdentitiesAndCerts.first().owningKey).single()
it.remove(info)
}
private fun removeInfoDB(session: Session, nodeInfo: NodeInfo) {
val info = findByIdentityKey(session, nodeInfo.legalIdentitiesAndCerts.first().owningKey).single()
session.remove(info)
}
private fun findByIdentityKey(session: Session, identityKey: PublicKey): List<NodeInfoSchemaV1.PersistentNodeInfo> {
@ -328,48 +316,40 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
return query.resultList
}
private fun queryByIdentityKey(identityKey: PublicKey): List<NodeInfo> {
createSession {
val result = findByIdentityKey(it, identityKey)
return result.map { it.toNodeInfo() }
}
private fun queryByIdentityKey(session: Session, identityKey: PublicKey): List<NodeInfo> {
val result = findByIdentityKey(session, identityKey)
return result.map { it.toNodeInfo() }
}
private fun queryIdentityByLegalName(name: CordaX500Name): PartyAndCertificate? {
createSession {
val query = it.createQuery(
// We do the JOIN here to restrict results to those present in the network map
"SELECT DISTINCT l FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.legalIdentitiesAndCerts l WHERE l.name = :name",
NodeInfoSchemaV1.DBPartyAndCertificate::class.java)
query.setParameter("name", name.toString())
val candidates = query.resultList.map { it.toLegalIdentityAndCert() }
// The map is restricted to holding a single identity for any X.500 name, so firstOrNull() is correct here.
return candidates.firstOrNull()
}
private fun queryIdentityByLegalName(session: Session, name: CordaX500Name): PartyAndCertificate? {
val query = session.createQuery(
// We do the JOIN here to restrict results to those present in the network map
"SELECT DISTINCT l FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.legalIdentitiesAndCerts l WHERE l.name = :name",
NodeInfoSchemaV1.DBPartyAndCertificate::class.java)
query.setParameter("name", name.toString())
val candidates = query.resultList.map { it.toLegalIdentityAndCert() }
// The map is restricted to holding a single identity for any X.500 name, so firstOrNull() is correct here.
return candidates.firstOrNull()
}
private fun queryByLegalName(name: CordaX500Name): List<NodeInfo> {
createSession {
val query = it.createQuery(
"SELECT n FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.legalIdentitiesAndCerts l WHERE l.name = :name",
NodeInfoSchemaV1.PersistentNodeInfo::class.java)
query.setParameter("name", name.toString())
val result = query.resultList
return result.map { it.toNodeInfo() }
}
private fun queryByLegalName(session: Session, name: CordaX500Name): List<NodeInfo> {
val query = session.createQuery(
"SELECT n FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.legalIdentitiesAndCerts l WHERE l.name = :name",
NodeInfoSchemaV1.PersistentNodeInfo::class.java)
query.setParameter("name", name.toString())
val result = query.resultList
return result.map { it.toNodeInfo() }
}
private fun queryByAddress(hostAndPort: NetworkHostAndPort): NodeInfo? {
createSession {
val query = it.createQuery(
"SELECT n FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.addresses a WHERE a.pk.host = :host AND a.pk.port = :port",
NodeInfoSchemaV1.PersistentNodeInfo::class.java)
query.setParameter("host", hostAndPort.host)
query.setParameter("port", hostAndPort.port)
val result = query.resultList
return if (result.isEmpty()) null
else result.map { it.toNodeInfo() }.singleOrNull() ?: throw IllegalStateException("More than one node with the same host and port")
}
private fun queryByAddress(session: Session, hostAndPort: NetworkHostAndPort): NodeInfo? {
val query = session.createQuery(
"SELECT n FROM ${NodeInfoSchemaV1.PersistentNodeInfo::class.java.name} n JOIN n.addresses a WHERE a.pk.host = :host AND a.pk.port = :port",
NodeInfoSchemaV1.PersistentNodeInfo::class.java)
query.setParameter("host", hostAndPort.host)
query.setParameter("port", hostAndPort.port)
val result = query.resultList
return if (result.isEmpty()) null
else result.map { it.toNodeInfo() }.singleOrNull() ?: throw IllegalStateException("More than one node with the same host and port")
}
/** Object Relational Mapping support. */
@ -388,10 +368,8 @@ open class PersistentNetworkMapCache(private val serviceHub: ServiceHubInternal)
override fun clearNetworkMapCache() {
serviceHub.database.transaction {
createSession {
val result = getAllInfos(it)
for (nodeInfo in result) it.remove(nodeInfo)
}
val result = getAllInfos(session)
for (nodeInfo in result) session.remove(nodeInfo)
}
}
}

View File

@ -3,8 +3,8 @@ package net.corda.node.services.persistence
import net.corda.core.serialization.SerializedBytes
import net.corda.node.services.api.Checkpoint
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.utilities.DatabaseTransactionManager
import net.corda.node.utilities.NODE_DATABASE_PREFIX
import net.corda.node.utilities.currentDBSession
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
@ -28,15 +28,14 @@ class DBCheckpointStorage : CheckpointStorage {
)
override fun addCheckpoint(checkpoint: Checkpoint) {
val session = DatabaseTransactionManager.current().session
session.save(DBCheckpoint().apply {
currentDBSession().save(DBCheckpoint().apply {
checkpointId = checkpoint.id.toString()
this.checkpoint = checkpoint.serializedFiber.bytes
})
}
override fun removeCheckpoint(checkpoint: Checkpoint) {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java)
val root = delete.from(DBCheckpoint::class.java)
@ -45,7 +44,7 @@ class DBCheckpointStorage : CheckpointStorage {
}
override fun forEach(block: (Checkpoint) -> Boolean) {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java)
val root = criteriaQuery.from(DBCheckpoint::class.java)
criteriaQuery.select(root)

View File

@ -13,8 +13,8 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.*
import net.corda.core.utilities.loggerFor
import net.corda.node.utilities.DatabaseTransactionManager
import net.corda.node.utilities.NODE_DATABASE_PREFIX
import net.corda.node.utilities.currentDBSession
import java.io.*
import java.nio.file.Paths
import java.util.jar.JarInputStream
@ -50,7 +50,7 @@ class NodeAttachmentService(metrics: MetricRegistry) : AttachmentStorage, Single
private val attachmentCount = metrics.counter("Attachments")
init {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val criteriaQuery = criteriaBuilder.createQuery(Long::class.java)
criteriaQuery.select(criteriaBuilder.count(criteriaQuery.from(NodeAttachmentService.DBAttachment::class.java)))
@ -140,7 +140,7 @@ class NodeAttachmentService(metrics: MetricRegistry) : AttachmentStorage, Single
}
override fun openAttachment(id: SecureHash): Attachment? {
val attachment = DatabaseTransactionManager.current().session.get(NodeAttachmentService.DBAttachment::class.java, id.toString())
val attachment = currentDBSession().get(NodeAttachmentService.DBAttachment::class.java, id.toString())
attachment?.let {
return AttachmentImpl(id, { attachment.content }, checkAttachmentsOnLoad)
}
@ -161,7 +161,7 @@ class NodeAttachmentService(metrics: MetricRegistry) : AttachmentStorage, Single
checkIsAValidJAR(ByteArrayInputStream(bytes))
val id = SecureHash.SHA256(hs.hash().asBytes())
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val criteriaQuery = criteriaBuilder.createQuery(Long::class.java)
val attachments = criteriaQuery.from(NodeAttachmentService.DBAttachment::class.java)

View File

@ -29,6 +29,7 @@ import net.corda.node.services.persistence.HibernateConfiguration
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.utilities.DatabaseTransactionManager
import net.corda.node.utilities.bufferUntilDatabaseCommit
import net.corda.node.utilities.currentDBSession
import net.corda.node.utilities.wrapWithDatabaseTransaction
import org.hibernate.Session
import rx.Observable
@ -38,6 +39,15 @@ import java.time.Clock
import java.time.Instant
import java.util.*
import javax.persistence.Tuple
import javax.persistence.criteria.CriteriaBuilder
import javax.persistence.criteria.CriteriaUpdate
import javax.persistence.criteria.Predicate
import javax.persistence.criteria.Root
private fun CriteriaBuilder.executeUpdate(session: Session, configure: Root<*>.(CriteriaUpdate<*>) -> Any?) = createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java).let { update ->
update.from(VaultSchemaV1.VaultStates::class.java).run { configure(update) }
session.createQuery(update).executeUpdate()
}
/**
* Currently, the node vault service is a very simple RDBMS backed implementation. It will change significantly when
@ -73,7 +83,7 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic
val consumedStateRefs = update.consumed.map { it.ref }
log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." }
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
producedStateRefsMap.forEach { stateAndRef ->
val state = VaultSchemaV1.VaultStates(
notary = stateAndRef.value.state.notary,
@ -189,7 +199,7 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic
private fun loadStates(refs: Collection<StateRef>): HashSet<StateAndRef<ContractState>> {
val states = HashSet<StateAndRef<ContractState>>()
if (refs.isNotEmpty()) {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val criteriaQuery = criteriaBuilder.createQuery(VaultSchemaV1.VaultStates::class.java)
val vaultStates = criteriaQuery.from(VaultSchemaV1.VaultStates::class.java)
@ -223,11 +233,11 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic
override fun addNoteToTransaction(txnId: SecureHash, noteText: String) {
val txnNoteEntity = VaultSchemaV1.VaultTxnNote(txnId.toString(), noteText)
DatabaseTransactionManager.current().session.save(txnNoteEntity)
currentDBSession().save(txnNoteEntity)
}
override fun getTransactionNotes(txnId: SecureHash): Iterable<String> {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val criteriaQuery = criteriaBuilder.createQuery(VaultSchemaV1.VaultTxnNote::class.java)
val vaultStates = criteriaQuery.from(VaultSchemaV1.VaultTxnNote::class.java)
@ -241,35 +251,34 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic
override fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet<StateRef>) {
val softLockTimestamp = clock.instant()
try {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java)
val vaultStates = criteriaUpdate.from(VaultSchemaV1.VaultStates::class.java)
val stateStatusPredication = criteriaBuilder.equal(vaultStates.get<Vault.StateStatus>(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED)
val lockIdPredicate = criteriaBuilder.or(vaultStates.get<String>(VaultSchemaV1.VaultStates::lockId.name).isNull,
criteriaBuilder.equal(vaultStates.get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()))
val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) }
val compositeKey = vaultStates.get<PersistentStateRef>(VaultSchemaV1.VaultStates::stateRef.name)
val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs))
criteriaUpdate.set(vaultStates.get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString())
criteriaUpdate.set(vaultStates.get<Instant>(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp)
criteriaUpdate.where(stateStatusPredication, lockIdPredicate, stateRefsPredicate)
val updatedRows = session.createQuery(criteriaUpdate).executeUpdate()
fun execute(configure: Root<*>.(CriteriaUpdate<*>, Array<Predicate>) -> Any?) = criteriaBuilder.executeUpdate(session) { update ->
val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) }
val compositeKey = get<PersistentStateRef>(VaultSchemaV1.VaultStates::stateRef.name)
val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs))
configure(update, arrayOf(stateRefsPredicate))
}
val updatedRows = execute { update, commonPredicates ->
val stateStatusPredication = criteriaBuilder.equal(get<Vault.StateStatus>(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED)
val lockIdPredicate = criteriaBuilder.or(get<String>(VaultSchemaV1.VaultStates::lockId.name).isNull,
criteriaBuilder.equal(get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()))
update.set(get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString())
update.set(get<Instant>(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp)
update.where(stateStatusPredication, lockIdPredicate, *commonPredicates)
}
if (updatedRows > 0 && updatedRows == stateRefs.size) {
log.trace("Reserving soft lock states for $lockId: $stateRefs")
FlowStateMachineImpl.currentStateMachine()?.hasSoftLockedStates = true
} else {
// revert partial soft locks
val criteriaRevertUpdate = criteriaBuilder.createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java)
val vaultStatesRevert = criteriaRevertUpdate.from(VaultSchemaV1.VaultStates::class.java)
val lockIdPredicateRevert = criteriaBuilder.equal(vaultStatesRevert.get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString())
val lockUpdateTime = criteriaBuilder.equal(vaultStatesRevert.get<Instant>(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp)
val persistentStateRefsRevert = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) }
val compositeKeyRevert = vaultStatesRevert.get<PersistentStateRef>(VaultSchemaV1.VaultStates::stateRef.name)
val stateRefsPredicateRevert = criteriaBuilder.and(compositeKeyRevert.`in`(persistentStateRefsRevert))
criteriaRevertUpdate.set(vaultStatesRevert.get<String>(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java))
criteriaRevertUpdate.where(lockUpdateTime, lockIdPredicateRevert, stateRefsPredicateRevert)
val revertUpdatedRows = session.createQuery(criteriaRevertUpdate).executeUpdate()
val revertUpdatedRows = execute { update, commonPredicates ->
val lockIdPredicate = criteriaBuilder.equal(get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString())
val lockUpdateTime = criteriaBuilder.equal(get<Instant>(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp)
update.set(get<String>(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java))
update.where(lockUpdateTime, lockIdPredicate, *commonPredicates)
}
if (revertUpdatedRows > 0) {
log.trace("Reverting $revertUpdatedRows partially soft locked states for $lockId")
}
@ -286,33 +295,30 @@ class NodeVaultService(private val clock: Clock, private val keyManagementServic
override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet<StateRef>?) {
val softLockTimestamp = clock.instant()
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
fun execute(configure: Root<*>.(CriteriaUpdate<*>, Array<Predicate>) -> Any?) = criteriaBuilder.executeUpdate(session) { update ->
val stateStatusPredication = criteriaBuilder.equal(get<Vault.StateStatus>(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED)
val lockIdPredicate = criteriaBuilder.equal(get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString())
update.set<String>(get<String>(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java))
update.set(get<Instant>(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp)
configure(update, arrayOf(stateStatusPredication, lockIdPredicate))
}
if (stateRefs == null) {
val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java)
val vaultStates = criteriaUpdate.from(VaultSchemaV1.VaultStates::class.java)
val stateStatusPredication = criteriaBuilder.equal(vaultStates.get<Vault.StateStatus>(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED)
val lockIdPredicate = criteriaBuilder.equal(vaultStates.get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString())
criteriaUpdate.set<String>(vaultStates.get<String>(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java))
criteriaUpdate.set(vaultStates.get<Instant>(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp)
criteriaUpdate.where(stateStatusPredication, lockIdPredicate)
val update = session.createQuery(criteriaUpdate).executeUpdate()
val update = execute { update, commonPredicates ->
update.where(*commonPredicates)
}
if (update > 0) {
log.trace("Releasing $update soft locked states for $lockId")
}
} else {
try {
val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java)
val vaultStates = criteriaUpdate.from(VaultSchemaV1.VaultStates::class.java)
val stateStatusPredication = criteriaBuilder.equal(vaultStates.get<Vault.StateStatus>(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED)
val lockIdPredicate = criteriaBuilder.equal(vaultStates.get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString())
val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) }
val compositeKey = vaultStates.get<PersistentStateRef>(VaultSchemaV1.VaultStates::stateRef.name)
val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs))
criteriaUpdate.set<String>(vaultStates.get<String>(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java))
criteriaUpdate.set(vaultStates.get<Instant>(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp)
criteriaUpdate.where(stateStatusPredication, lockIdPredicate, stateRefsPredicate)
val updatedRows = session.createQuery(criteriaUpdate).executeUpdate()
val updatedRows = execute { update, commonPredicates ->
val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) }
val compositeKey = get<PersistentStateRef>(VaultSchemaV1.VaultStates::stateRef.name)
val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs))
update.where(*commonPredicates, stateRefsPredicate)
}
if (updatedRows > 0) {
log.trace("Releasing $updatedRows soft locked states for $lockId and stateRefs $stateRefs")
}

View File

@ -40,10 +40,11 @@ class AppendOnlyPersistentMap<K, V, E, out EK>(
* Returns all key/value pairs from the underlying storage.
*/
fun allPersisted(): Sequence<Pair<K, V>> {
val criteriaQuery = DatabaseTransactionManager.current().session.criteriaBuilder.createQuery(persistentEntityClass)
val session = currentDBSession()
val criteriaQuery = session.criteriaBuilder.createQuery(persistentEntityClass)
val root = criteriaQuery.from(persistentEntityClass)
criteriaQuery.select(root)
val query = DatabaseTransactionManager.current().session.createQuery(criteriaQuery)
val query = session.createQuery(criteriaQuery)
val result = query.resultList
return result.map { x -> fromPersistentEntity(x) }.asSequence()
}
@ -87,7 +88,7 @@ class AppendOnlyPersistentMap<K, V, E, out EK>(
*/
operator fun set(key: K, value: V) =
set(key, value, logWarning = false) { k, v ->
DatabaseTransactionManager.current().session.save(toPersistentEntity(k, v))
currentDBSession().save(toPersistentEntity(k, v))
null
}
@ -98,9 +99,10 @@ class AppendOnlyPersistentMap<K, V, E, out EK>(
*/
fun addWithDuplicatesAllowed(key: K, value: V, logWarning: Boolean = true): Boolean =
set(key, value, logWarning) { k, v ->
val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(k))
val session = currentDBSession()
val existingEntry = session.find(persistentEntityClass, toPersistentEntityKey(k))
if (existingEntry == null) {
DatabaseTransactionManager.current().session.save(toPersistentEntity(k, v))
session.save(toPersistentEntity(k, v))
null
} else {
fromPersistentEntity(existingEntry).second
@ -114,7 +116,7 @@ class AppendOnlyPersistentMap<K, V, E, out EK>(
}
private fun loadValue(key: K): V? {
val result = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key))
val result = currentDBSession().find(persistentEntityClass, toPersistentEntityKey(key))
return result?.let(fromPersistentEntity)?.second
}
@ -125,7 +127,7 @@ class AppendOnlyPersistentMap<K, V, E, out EK>(
* WARNING!! The method is not thread safe.
*/
fun clear() {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val deleteQuery = session.criteriaBuilder.createCriteriaDelete(persistentEntityClass)
deleteQuery.from(persistentEntityClass)
session.createQuery(deleteQuery).executeUpdate()

View File

@ -62,6 +62,7 @@ class DatabaseTransaction(isolation: Int, val threadLocal: ThreadLocal<DatabaseT
}
}
fun currentDBSession() = DatabaseTransactionManager.current().session
class DatabaseTransactionManager(initDataSource: CordaPersistence) {
companion object {
private val threadLocalDb = ThreadLocal<CordaPersistence>()

View File

@ -28,7 +28,7 @@ class PersistentMap<K, V, E, out EK>(
removalListener = ExplicitRemoval(toPersistentEntityKey, persistentEntityClass)
).apply {
//preload to allow all() to take data only from the cache (cache is unbound)
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaQuery = session.criteriaBuilder.createQuery(persistentEntityClass)
criteriaQuery.select(criteriaQuery.from(persistentEntityClass))
getAll(session.createQuery(criteriaQuery).resultList.map { e -> fromPersistentEntity(e as E).first }.asIterable())
@ -38,7 +38,7 @@ class PersistentMap<K, V, E, out EK>(
override fun onRemoval(notification: RemovalNotification<K, V>?) {
when (notification?.cause) {
RemovalCause.EXPLICIT -> {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val elem = session.find(persistentEntityClass, toPersistentEntityKey(notification.key))
if (elem != null) {
session.remove(elem)
@ -101,7 +101,7 @@ class PersistentMap<K, V, E, out EK>(
set(key, value,
logWarning = false,
store = { k: K, v: V ->
DatabaseTransactionManager.current().session.save(toPersistentEntity(k, v))
currentDBSession().save(toPersistentEntity(k, v))
null
},
replace = { _: K, _: V -> Unit }
@ -115,9 +115,10 @@ class PersistentMap<K, V, E, out EK>(
fun addWithDuplicatesAllowed(key: K, value: V) =
set(key, value,
store = { k, v ->
val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(k))
val session = currentDBSession()
val existingEntry = session.find(persistentEntityClass, toPersistentEntityKey(k))
if (existingEntry == null) {
DatabaseTransactionManager.current().session.save(toPersistentEntity(k, v))
session.save(toPersistentEntity(k, v))
null
} else {
fromPersistentEntity(existingEntry).second
@ -145,18 +146,19 @@ class PersistentMap<K, V, E, out EK>(
}
private fun merge(key: K, value: V): V? {
val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key))
val session = currentDBSession()
val existingEntry = session.find(persistentEntityClass, toPersistentEntityKey(key))
return if (existingEntry != null) {
DatabaseTransactionManager.current().session.merge(toPersistentEntity(key, value))
session.merge(toPersistentEntity(key, value))
fromPersistentEntity(existingEntry).second
} else {
DatabaseTransactionManager.current().session.save(toPersistentEntity(key, value))
session.save(toPersistentEntity(key, value))
null
}
}
private fun loadValue(key: K): V? {
val result = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key))
val result = currentDBSession().find(persistentEntityClass, toPersistentEntityKey(key))
return result?.let(fromPersistentEntity)?.second
}
@ -256,7 +258,7 @@ class PersistentMap<K, V, E, out EK>(
}
fun load() {
val session = DatabaseTransactionManager.current().session
val session = currentDBSession()
val criteriaQuery = session.criteriaBuilder.createQuery(persistentEntityClass)
criteriaQuery.select(criteriaQuery.from(persistentEntityClass))
cache.getAll(session.createQuery(criteriaQuery).resultList.map { e -> fromPersistentEntity(e as E).first }.asIterable())

View File

@ -11,7 +11,6 @@ import net.corda.core.internal.write
import net.corda.core.internal.writeLines
import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.DatabaseTransactionManager
import net.corda.node.utilities.configureDatabase
import net.corda.testing.LogHelper
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
@ -98,19 +97,18 @@ class NodeAttachmentStorageTest {
@Test
fun `corrupt entry throws exception`() {
val testJar = makeTestJar()
val id =
database.transaction {
val storage = NodeAttachmentService(MetricRegistry())
val id = testJar.read { storage.importAttachment(it) }
val id = database.transaction {
val storage = NodeAttachmentService(MetricRegistry())
val id = testJar.read { storage.importAttachment(it) }
// Corrupt the file in the store.
val bytes = testJar.readAll()
val corruptBytes = "arggghhhh".toByteArray()
System.arraycopy(corruptBytes, 0, bytes, 0, corruptBytes.size)
val corruptAttachment = NodeAttachmentService.DBAttachment(attId = id.toString(), content = bytes)
DatabaseTransactionManager.current().session.merge(corruptAttachment)
id
}
// Corrupt the file in the store.
val bytes = testJar.readAll()
val corruptBytes = "arggghhhh".toByteArray()
System.arraycopy(corruptBytes, 0, bytes, 0, corruptBytes.size)
val corruptAttachment = NodeAttachmentService.DBAttachment(attId = id.toString(), content = bytes)
session.merge(corruptAttachment)
id
}
database.transaction {
val storage = NodeAttachmentService(MetricRegistry())
val e = assertFailsWith<NodeAttachmentService.HashMismatchException> {