ENT-9943: Now use SecureHash to represent CordaX500Name in distributi… (#7501)

This commit is contained in:
Adel El-Beik 2023-09-29 09:03:00 +01:00 committed by GitHub
parent aec91c450e
commit 3a23f60199
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 66 additions and 60 deletions

View File

@ -355,14 +355,14 @@ class FinalityFlowTests : WithFinality {
val sdrs = getSenderRecoveryData(stx.id, aliceNode.database).apply { val sdrs = getSenderRecoveryData(stx.id, aliceNode.database).apply {
assertEquals(1, this.size) assertEquals(1, this.size)
assertEquals(StatesToRecord.ALL_VISIBLE, this[0].statesToRecord) assertEquals(StatesToRecord.ALL_VISIBLE, this[0].statesToRecord)
assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) assertEquals(SecureHash.sha256(BOB_NAME.toString()), this[0].peerPartyId)
} }
val rdr = getReceiverRecoveryData(stx.id, bobNode).apply { val rdr = getReceiverRecoveryData(stx.id, bobNode).apply {
assertNotNull(this) assertNotNull(this)
val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService)
assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord)
assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this.initiatorPartyId) assertEquals(SecureHash.sha256(aliceNode.info.singleIdentity().name.toString()), this.initiatorPartyId)
assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), hashedDL.peerHashToStatesToRecord) assertEquals(mapOf<SecureHash, StatesToRecord>(SecureHash.sha256(BOB_NAME.toString()) to StatesToRecord.ALL_VISIBLE), hashedDL.peerHashToStatesToRecord)
} }
validateSenderAndReceiverTimestamps(sdrs, rdr!!) validateSenderAndReceiverTimestamps(sdrs, rdr!!)
} }
@ -388,19 +388,19 @@ class FinalityFlowTests : WithFinality {
val sdrs = getSenderRecoveryData(stx.id, aliceNode.database).apply { val sdrs = getSenderRecoveryData(stx.id, aliceNode.database).apply {
assertEquals(2, this.size) assertEquals(2, this.size)
assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord)
assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) assertEquals(SecureHash.sha256(BOB_NAME.toString()), this[0].peerPartyId)
assertEquals(StatesToRecord.ALL_VISIBLE, this[1].statesToRecord) assertEquals(StatesToRecord.ALL_VISIBLE, this[1].statesToRecord)
assertEquals(CHARLIE_NAME.hashCode().toLong(), this[1].peerPartyId) assertEquals(SecureHash.sha256(CHARLIE_NAME.toString()), this[1].peerPartyId)
} }
val rdr = getReceiverRecoveryData(stx.id, bobNode).apply { val rdr = getReceiverRecoveryData(stx.id, bobNode).apply {
assertNotNull(this) assertNotNull(this)
val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService)
assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord)
assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this.initiatorPartyId) assertEquals(SecureHash.sha256(aliceNode.info.singleIdentity().name.toString()), this.initiatorPartyId)
// note: Charlie assertion here is using the hinted StatesToRecord value passed to it from Alice // note: Charlie assertion here is using the hinted StatesToRecord value passed to it from Alice
assertEquals(mapOf( assertEquals(mapOf<SecureHash, StatesToRecord>(
BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT, SecureHash.sha256(BOB_NAME.toString()) to StatesToRecord.ONLY_RELEVANT,
CHARLIE_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE SecureHash.sha256(CHARLIE_NAME.toString()) to StatesToRecord.ALL_VISIBLE
), hashedDL.peerHashToStatesToRecord) ), hashedDL.peerHashToStatesToRecord)
} }
validateSenderAndReceiverTimestamps(sdrs, rdr!!) validateSenderAndReceiverTimestamps(sdrs, rdr!!)
@ -452,14 +452,14 @@ class FinalityFlowTests : WithFinality {
val sdr = getSenderRecoveryData(stx.id, aliceNode.database).apply { val sdr = getSenderRecoveryData(stx.id, aliceNode.database).apply {
assertEquals(1, this.size) assertEquals(1, this.size)
assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord)
assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) assertEquals(SecureHash.sha256(BOB_NAME.toString()), this[0].peerPartyId)
} }
val rdr = getReceiverRecoveryData(stx.id, bobNode).apply { val rdr = getReceiverRecoveryData(stx.id, bobNode).apply {
assertNotNull(this) assertNotNull(this)
val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService)
assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord)
assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this.initiatorPartyId) assertEquals(SecureHash.sha256(aliceNode.info.singleIdentity().name.toString()), this.initiatorPartyId)
assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT), hashedDL.peerHashToStatesToRecord) assertEquals(mapOf<SecureHash, StatesToRecord>(SecureHash.sha256(BOB_NAME.toString()) to StatesToRecord.ONLY_RELEVANT), hashedDL.peerHashToStatesToRecord)
} }
validateSenderAndReceiverTimestamps(sdr, rdr!!) validateSenderAndReceiverTimestamps(sdr, rdr!!)
} }

View File

@ -1,5 +1,6 @@
package net.corda.node.services.network package net.corda.node.services.network
import net.corda.core.crypto.SecureHash
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.identity.InMemoryIdentityService
@ -35,9 +36,9 @@ class PersistentPartyInfoCacheTest {
createNodeInfo(listOf(CHARLIE)))) createNodeInfo(listOf(CHARLIE))))
val partyInfoCache = PersistentPartyInfoCache(charlieNetMapCache, TestingNamedCacheFactory(), database) val partyInfoCache = PersistentPartyInfoCache(charlieNetMapCache, TestingNamedCacheFactory(), database)
partyInfoCache.start() partyInfoCache.start()
assertThat(partyInfoCache.getPartyIdByCordaX500Name(ALICE.name)).isEqualTo(ALICE.name.hashCode().toLong()) assertThat(partyInfoCache.getPartyIdByCordaX500Name(ALICE.name)).isEqualTo(SecureHash.sha256(ALICE.name.toString()))
assertThat(partyInfoCache.getPartyIdByCordaX500Name(BOB.name)).isEqualTo(BOB.name.hashCode().toLong()) assertThat(partyInfoCache.getPartyIdByCordaX500Name(BOB.name)).isEqualTo(SecureHash.sha256(BOB.name.toString()))
assertThat(partyInfoCache.getPartyIdByCordaX500Name(CHARLIE.name)).isEqualTo(CHARLIE.name.hashCode().toLong()) assertThat(partyInfoCache.getPartyIdByCordaX500Name(CHARLIE.name)).isEqualTo(SecureHash.sha256(CHARLIE.name.toString()))
} }
@Test(timeout=300_000) @Test(timeout=300_000)
@ -50,9 +51,9 @@ class PersistentPartyInfoCacheTest {
// clear network map cache & bootstrap another PersistentInfoCache // clear network map cache & bootstrap another PersistentInfoCache
charlieNetMapCache.clearNetworkMapCache() charlieNetMapCache.clearNetworkMapCache()
val partyInfoCache = PersistentPartyInfoCache(charlieNetMapCache, TestingNamedCacheFactory(), database) val partyInfoCache = PersistentPartyInfoCache(charlieNetMapCache, TestingNamedCacheFactory(), database)
assertThat(partyInfoCache.getPartyIdByCordaX500Name(ALICE.name)).isEqualTo(ALICE.name.hashCode().toLong()) assertThat(partyInfoCache.getPartyIdByCordaX500Name(ALICE.name)).isEqualTo(SecureHash.sha256(ALICE.name.toString()))
assertThat(partyInfoCache.getPartyIdByCordaX500Name(BOB.name)).isEqualTo(BOB.name.hashCode().toLong()) assertThat(partyInfoCache.getPartyIdByCordaX500Name(BOB.name)).isEqualTo(SecureHash.sha256(BOB.name.toString()))
assertThat(partyInfoCache.getPartyIdByCordaX500Name(CHARLIE.name)).isEqualTo(CHARLIE.name.hashCode().toLong()) assertThat(partyInfoCache.getPartyIdByCordaX500Name(CHARLIE.name)).isEqualTo(SecureHash.sha256(CHARLIE.name.toString()))
} }
@Test(timeout=300_000) @Test(timeout=300_000)
@ -63,9 +64,9 @@ class PersistentPartyInfoCacheTest {
createNodeInfo(listOf(CHARLIE)))) createNodeInfo(listOf(CHARLIE))))
val partyInfoCache = PersistentPartyInfoCache(charlieNetMapCache, TestingNamedCacheFactory(), database) val partyInfoCache = PersistentPartyInfoCache(charlieNetMapCache, TestingNamedCacheFactory(), database)
partyInfoCache.start() partyInfoCache.start()
assertThat(partyInfoCache.getCordaX500NameByPartyId(ALICE.name.hashCode().toLong())).isEqualTo(ALICE.name) assertThat(partyInfoCache.getCordaX500NameByPartyId(SecureHash.sha256(ALICE.name.toString()))).isEqualTo(ALICE.name)
assertThat(partyInfoCache.getCordaX500NameByPartyId(BOB.name.hashCode().toLong())).isEqualTo(BOB.name) assertThat(partyInfoCache.getCordaX500NameByPartyId(SecureHash.sha256(BOB.name.toString()))).isEqualTo(BOB.name)
assertThat(partyInfoCache.getCordaX500NameByPartyId(CHARLIE.name.hashCode().toLong())).isEqualTo(CHARLIE.name) assertThat(partyInfoCache.getCordaX500NameByPartyId(SecureHash.sha256(CHARLIE.name.toString()))).isEqualTo(CHARLIE.name)
} }
@Test(timeout=300_000) @Test(timeout=300_000)
@ -78,9 +79,9 @@ class PersistentPartyInfoCacheTest {
// clear network map cache & bootstrap another PersistentInfoCache // clear network map cache & bootstrap another PersistentInfoCache
charlieNetMapCache.clearNetworkMapCache() charlieNetMapCache.clearNetworkMapCache()
val partyInfoCache = PersistentPartyInfoCache(charlieNetMapCache, TestingNamedCacheFactory(), database) val partyInfoCache = PersistentPartyInfoCache(charlieNetMapCache, TestingNamedCacheFactory(), database)
assertThat(partyInfoCache.getCordaX500NameByPartyId(ALICE.name.hashCode().toLong())).isEqualTo(ALICE.name) assertThat(partyInfoCache.getCordaX500NameByPartyId(SecureHash.sha256(ALICE.name.toString()))).isEqualTo(ALICE.name)
assertThat(partyInfoCache.getCordaX500NameByPartyId(BOB.name.hashCode().toLong())).isEqualTo(BOB.name) assertThat(partyInfoCache.getCordaX500NameByPartyId(SecureHash.sha256(BOB.name.toString()))).isEqualTo(BOB.name)
assertThat(partyInfoCache.getCordaX500NameByPartyId(CHARLIE.name.hashCode().toLong())).isEqualTo(CHARLIE.name) assertThat(partyInfoCache.getCordaX500NameByPartyId(SecureHash.sha256(CHARLIE.name.toString()))).isEqualTo(CHARLIE.name)
} }
private fun createNodeInfo(identities: List<TestIdentity>, private fun createNodeInfo(identities: List<TestIdentity>,

View File

@ -1,5 +1,6 @@
package net.corda.node.services.network package net.corda.node.services.network
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.NamedCacheFactory import net.corda.core.internal.NamedCacheFactory
import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.NetworkMapCache
@ -14,13 +15,13 @@ class PersistentPartyInfoCache(private val networkMapCache: PersistentNetworkMap
private val database: CordaPersistence) { private val database: CordaPersistence) {
// probably better off using a BiMap here: https://www.baeldung.com/guava-bimap // probably better off using a BiMap here: https://www.baeldung.com/guava-bimap
private val cordaX500NameToPartyIdCache = NonInvalidatingCache<CordaX500Name, Long?>( private val cordaX500NameToPartyIdCache = NonInvalidatingCache<CordaX500Name, SecureHash?>(
cacheFactory = cacheFactory, cacheFactory = cacheFactory,
name = "RecoveryPartyInfoCache_byCordaX500Name") { key -> name = "RecoveryPartyInfoCache_byCordaX500Name") { key ->
database.transaction { queryByCordaX500Name(session, key) } database.transaction { queryByCordaX500Name(session, key) }
} }
private val partyIdToCordaX500NameCache = NonInvalidatingCache<Long, CordaX500Name?>( private val partyIdToCordaX500NameCache = NonInvalidatingCache<SecureHash, CordaX500Name?>(
cacheFactory = cacheFactory, cacheFactory = cacheFactory,
name = "RecoveryPartyInfoCache_byPartyId") { key -> name = "RecoveryPartyInfoCache_byPartyId") { key ->
database.transaction { queryByPartyId(session, key) } database.transaction { queryByPartyId(session, key) }
@ -32,48 +33,48 @@ class PersistentPartyInfoCache(private val networkMapCache: PersistentNetworkMap
val (snapshot, updates) = networkMapCache.track() val (snapshot, updates) = networkMapCache.track()
snapshot.map { entry -> snapshot.map { entry ->
entry.legalIdentities.map { party -> entry.legalIdentities.map { party ->
add(party.name.hashCode().toLong(), party.name) add(SecureHash.sha256(party.name.toString()), party.name)
} }
} }
trackNetworkMapUpdates = updates trackNetworkMapUpdates = updates
trackNetworkMapUpdates.cache().forEach { nodeInfo -> trackNetworkMapUpdates.cache().forEach { nodeInfo ->
nodeInfo.node.legalIdentities.map { party -> nodeInfo.node.legalIdentities.map { party ->
add(party.name.hashCode().toLong(), party.name) add(SecureHash.sha256(party.name.toString()), party.name)
} }
} }
} }
fun getPartyIdByCordaX500Name(name: CordaX500Name): Long = cordaX500NameToPartyIdCache[name] ?: throw IllegalStateException("Missing cache entry for $name") fun getPartyIdByCordaX500Name(name: CordaX500Name): SecureHash = cordaX500NameToPartyIdCache[name] ?: throw IllegalStateException("Missing cache entry for $name")
fun getCordaX500NameByPartyId(partyId: Long): CordaX500Name = partyIdToCordaX500NameCache[partyId] ?: throw IllegalStateException("Missing cache entry for $partyId") fun getCordaX500NameByPartyId(partyId: SecureHash): CordaX500Name = partyIdToCordaX500NameCache[partyId] ?: throw IllegalStateException("Missing cache entry for $partyId")
private fun add(partyHashCode: Long, partyName: CordaX500Name) { private fun add(partyHashCode: SecureHash, partyName: CordaX500Name) {
partyIdToCordaX500NameCache.cache.put(partyHashCode, partyName) partyIdToCordaX500NameCache.cache.put(partyHashCode, partyName)
cordaX500NameToPartyIdCache.cache.put(partyName, partyHashCode) cordaX500NameToPartyIdCache.cache.put(partyName, partyHashCode)
updateInfoDB(partyHashCode, partyName) updateInfoDB(partyHashCode, partyName)
} }
private fun updateInfoDB(partyHashCode: Long, partyName: CordaX500Name) { private fun updateInfoDB(partyHashCode: SecureHash, partyName: CordaX500Name) {
database.transaction { database.transaction {
if (queryByPartyId(session, partyHashCode) == null) { if (queryByPartyId(session, partyHashCode) == null) {
session.save(DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo(partyHashCode, partyName.toString())) session.save(DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo(partyHashCode.toString(), partyName.toString()))
} }
} }
} }
private fun queryByCordaX500Name(session: Session, key: CordaX500Name): Long? { private fun queryByCordaX500Name(session: Session, key: CordaX500Name): SecureHash? {
val query = session.createQuery( val query = session.createQuery(
"FROM ${DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java.name} WHERE partyName = :partyName", "FROM ${DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java.name} WHERE partyName = :partyName",
DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java) DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java)
query.setParameter("partyName", key.toString()) query.setParameter("partyName", key.toString())
return query.resultList.singleOrNull()?.partyId return query.resultList.singleOrNull()?.let { SecureHash.parse(it.partyId) }
} }
private fun queryByPartyId(session: Session, key: Long): CordaX500Name? { private fun queryByPartyId(session: Session, key: SecureHash): CordaX500Name? {
val query = session.createQuery( val query = session.createQuery(
"FROM ${DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java.name} WHERE partyId = :partyId", "FROM ${DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java.name} WHERE partyId = :partyId",
DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java) DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java)
query.setParameter("partyId", key) query.setParameter("partyId", key.toString())
return query.resultList.singleOrNull()?.partyName?.let { CordaX500Name.parse(it) } return query.resultList.singleOrNull()?.partyName?.let { CordaX500Name.parse(it) }
} }
} }

View File

@ -39,8 +39,8 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
@Immutable @Immutable
data class PersistentKey( data class PersistentKey(
/** PartyId of flow peer **/ /** PartyId of flow peer **/
@Column(name = "peer_party_id", nullable = false) @Column(name = "peer_party_id", length = 144, nullable = false)
var peerPartyId: Long, var peerPartyId: String,
@Column(name = "timestamp", nullable = false) @Column(name = "timestamp", nullable = false)
var timestamp: Instant, var timestamp: Instant,
@ -49,7 +49,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
var timestampDiscriminator: Int var timestampDiscriminator: Int
) : Serializable { ) : Serializable {
constructor(key: Key) : this(key.partyId, key.timestamp, key.timestampDiscriminator) constructor(key: Key) : this(key.partyId.toString(), key.timestamp, key.timestampDiscriminator)
} }
@CordaSerializable @CordaSerializable
@ -69,7 +69,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
fun toSenderDistributionRecord() = fun toSenderDistributionRecord() =
SenderDistributionRecord( SenderDistributionRecord(
SecureHash.parse(this.txId), SecureHash.parse(this.txId),
this.compositeKey.peerPartyId, SecureHash.parse(this.compositeKey.peerPartyId),
this.statesToRecord, this.statesToRecord,
this.compositeKey.timestamp this.compositeKey.timestamp
) )
@ -104,7 +104,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
fun toReceiverDistributionRecord(): ReceiverDistributionRecord { fun toReceiverDistributionRecord(): ReceiverDistributionRecord {
return ReceiverDistributionRecord( return ReceiverDistributionRecord(
SecureHash.parse(this.txId), SecureHash.parse(this.txId),
this.compositeKey.peerPartyId, SecureHash.parse(this.compositeKey.peerPartyId),
OpaqueBytes(this.distributionList), OpaqueBytes(this.distributionList),
this.compositeKey.timestamp this.compositeKey.timestamp
) )
@ -116,8 +116,8 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
data class DBRecoveryPartyInfo( data class DBRecoveryPartyInfo(
@Id @Id
/** CordaX500Name hashCode() **/ /** CordaX500Name hashCode() **/
@Column(name = "party_id", nullable = false) @Column(name = "party_id", length = 144, nullable = false)
var partyId: Long, var partyId: String,
/** CordaX500Name of party **/ /** CordaX500Name of party **/
@Column(name = "party_name", nullable = false) @Column(name = "party_name", nullable = false)
@ -127,11 +127,11 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
data class TimestampKey(val timestamp: Instant, val timestampDiscriminator: Int) data class TimestampKey(val timestamp: Instant, val timestampDiscriminator: Int)
class Key( class Key(
val partyId: Long, val partyId: SecureHash,
val timestamp: Instant, val timestamp: Instant,
val timestampDiscriminator: Int = nextDiscriminatorNumber.andIncrement val timestampDiscriminator: Int = nextDiscriminatorNumber.andIncrement
) { ) {
constructor(key: TimestampKey, partyId: Long): this(partyId, key.timestamp, key.timestampDiscriminator) constructor(key: TimestampKey, partyId: SecureHash): this(partyId, key.timestamp, key.timestampDiscriminator)
companion object { companion object {
val nextDiscriminatorNumber = AtomicInteger() val nextDiscriminatorNumber = AtomicInteger()
} }
@ -237,7 +237,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
excludingTxnIds.map { it.toString() })))) excludingTxnIds.map { it.toString() }))))
} }
if (peers.isNotEmpty()) { if (peers.isNotEmpty()) {
val peerPartyIds = peers.map { partyInfoCache.getPartyIdByCordaX500Name(it) } val peerPartyIds = peers.map { partyInfoCache.getPartyIdByCordaX500Name(it).toString() }
predicates.add(criteriaBuilder.and(compositeKey.get<Long>(PersistentKey::peerPartyId.name).`in`(peerPartyIds))) predicates.add(criteriaBuilder.and(compositeKey.get<Long>(PersistentKey::peerPartyId.name).`in`(peerPartyIds)))
} }
criteriaQuery.where(*predicates.toTypedArray()) criteriaQuery.where(*predicates.toTypedArray())
@ -275,7 +275,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
predicates.add(criteriaBuilder.and(criteriaBuilder.not(txId.`in`(excludingTxnIds.map { it.toString() })))) predicates.add(criteriaBuilder.and(criteriaBuilder.not(txId.`in`(excludingTxnIds.map { it.toString() }))))
} }
if (initiators.isNotEmpty()) { if (initiators.isNotEmpty()) {
val initiatorPartyIds = initiators.map(partyInfoCache::getPartyIdByCordaX500Name) val initiatorPartyIds = initiators.map { partyInfoCache.getPartyIdByCordaX500Name(it).toString() }
predicates.add(criteriaBuilder.and(compositeKey.get<Long>(PersistentKey::peerPartyId.name).`in`(initiatorPartyIds))) predicates.add(criteriaBuilder.and(compositeKey.get<Long>(PersistentKey::peerPartyId.name).`in`(initiatorPartyIds)))
} }
criteriaQuery.where(*predicates.toTypedArray()) criteriaQuery.where(*predicates.toTypedArray())
@ -319,7 +319,7 @@ abstract class DistributionRecord {
@CordaSerializable @CordaSerializable
data class SenderDistributionRecord( data class SenderDistributionRecord(
override val txId: SecureHash, override val txId: SecureHash,
val peerPartyId: Long, // CordaX500Name hashCode() val peerPartyId: SecureHash, // CordaX500Name hashCode()
val statesToRecord: StatesToRecord, val statesToRecord: StatesToRecord,
override val timestamp: Instant override val timestamp: Instant
) : DistributionRecord() ) : DistributionRecord()
@ -327,7 +327,7 @@ data class SenderDistributionRecord(
@CordaSerializable @CordaSerializable
data class ReceiverDistributionRecord( data class ReceiverDistributionRecord(
override val txId: SecureHash, override val txId: SecureHash,
val initiatorPartyId: Long, // CordaX500Name hashCode() val initiatorPartyId: SecureHash, // CordaX500Name hashCode()
val encryptedDistributionList: OpaqueBytes, val encryptedDistributionList: OpaqueBytes,
override val timestamp: Instant override val timestamp: Instant
) : DistributionRecord() ) : DistributionRecord()

View File

@ -1,5 +1,6 @@
package net.corda.node.services.persistence package net.corda.node.services.persistence
import net.corda.core.crypto.SecureHash
import net.corda.core.node.StatesToRecord import net.corda.core.node.StatesToRecord
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.node.services.EncryptionService import net.corda.node.services.EncryptionService
@ -13,7 +14,7 @@ import java.time.Instant
@CordaSerializable @CordaSerializable
data class HashedDistributionList( data class HashedDistributionList(
val senderStatesToRecord: StatesToRecord, val senderStatesToRecord: StatesToRecord,
val peerHashToStatesToRecord: Map<Long, StatesToRecord>, val peerHashToStatesToRecord: Map<SecureHash, StatesToRecord>,
val publicHeader: PublicHeader val publicHeader: PublicHeader
) { ) {
/** /**
@ -28,7 +29,7 @@ data class HashedDistributionList(
out.writeByte(senderStatesToRecord.ordinal) out.writeByte(senderStatesToRecord.ordinal)
out.writeInt(peerHashToStatesToRecord.size) out.writeInt(peerHashToStatesToRecord.size)
for (entry in peerHashToStatesToRecord) { for (entry in peerHashToStatesToRecord) {
out.writeLong(entry.key) entry.key.writeTo(out)
out.writeByte(entry.value.ordinal) out.writeByte(entry.value.ordinal)
} }
return encryptionService.encrypt(baos.toByteArray(), publicHeader.serialise()) return encryptionService.encrypt(baos.toByteArray(), publicHeader.serialise())
@ -78,6 +79,7 @@ data class HashedDistributionList(
// The version tag is serialised in the header, even though it is separate from the encrypted main body of the distribution list. // The version tag is serialised in the header, even though it is separate from the encrypted main body of the distribution list.
// This is because the header and the dist list are cryptographically coupled and we want to avoid declaring the version field twice. // This is because the header and the dist list are cryptographically coupled and we want to avoid declaring the version field twice.
private const val VERSION_TAG = 1 private const val VERSION_TAG = 1
private const val SECURE_HASH_LENGTH = 32
private val statesToRecordValues = StatesToRecord.values() // Cache the enum values since .values() returns a new array each time. private val statesToRecordValues = StatesToRecord.values() // Cache the enum values since .values() returns a new array each time.
/** /**
@ -91,9 +93,11 @@ data class HashedDistributionList(
try { try {
val senderStatesToRecord = statesToRecordValues[input.readByte().toInt()] val senderStatesToRecord = statesToRecordValues[input.readByte().toInt()]
val numPeerHashToStatesToRecords = input.readInt() val numPeerHashToStatesToRecords = input.readInt()
val peerHashToStatesToRecord = mutableMapOf<Long, StatesToRecord>() val peerHashToStatesToRecord = mutableMapOf<SecureHash, StatesToRecord>()
repeat(numPeerHashToStatesToRecords) { repeat(numPeerHashToStatesToRecords) {
peerHashToStatesToRecord[input.readLong()] = statesToRecordValues[input.readByte().toInt()] val secureHashBytes = ByteArray(SECURE_HASH_LENGTH)
input.readFully(secureHashBytes)
peerHashToStatesToRecord[SecureHash.createSHA256(secureHashBytes)] = statesToRecordValues[input.readByte().toInt()]
} }
return HashedDistributionList(senderStatesToRecord, peerHashToStatesToRecord, publicHeader) return HashedDistributionList(senderStatesToRecord, peerHashToStatesToRecord, publicHeader)
} catch (e: Exception) { } catch (e: Exception) {

View File

@ -21,7 +21,7 @@
<column name="transaction_id" type="NVARCHAR(144)"> <column name="transaction_id" type="NVARCHAR(144)">
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
<column name="peer_party_id" type="BIGINT"> <column name="peer_party_id" type="NVARCHAR(144)">
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
<column name="states_to_record" type="INT"> <column name="states_to_record" type="INT">
@ -46,7 +46,7 @@
<column name="transaction_id" type="NVARCHAR(144)"> <column name="transaction_id" type="NVARCHAR(144)">
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
<column name="peer_party_id" type="BIGINT"> <column name="peer_party_id" type="NVARCHAR(144)">
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
<column name="distribution_list" type="BLOB"> <column name="distribution_list" type="BLOB">
@ -65,7 +65,7 @@
<changeSet author="R3.Corda" id="create_recovery_party_info_table"> <changeSet author="R3.Corda" id="create_recovery_party_info_table">
<createTable tableName="node_recovery_party_info"> <createTable tableName="node_recovery_party_info">
<column name="party_id" type="BIGINT"> <column name="party_id" type="NVARCHAR(144)">
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
<column name="party_name" type="NVARCHAR(255)"> <column name="party_name" type="NVARCHAR(255)">

View File

@ -147,12 +147,12 @@ class DBTransactionStorageLedgerRecoveryTests {
val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS))
transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.SENDER).let { transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.SENDER).let {
assertEquals(2, it.size) assertEquals(2, it.size)
assertEquals(BOB_NAME.hashCode().toLong(), it.senderRecords[0].compositeKey.peerPartyId) assertEquals(SecureHash.sha256(BOB_NAME.toString()).toString(), it.senderRecords[0].compositeKey.peerPartyId)
assertEquals(ALL_VISIBLE, it.senderRecords[0].statesToRecord) assertEquals(ALL_VISIBLE, it.senderRecords[0].statesToRecord)
} }
transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.RECEIVER).let { transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.RECEIVER).let {
assertEquals(1, it.size) assertEquals(1, it.size)
assertEquals(BOB_NAME.hashCode().toLong(), it.receiverRecords[0].compositeKey.peerPartyId) assertEquals(SecureHash.sha256(BOB_NAME.toString()).toString(), it.receiverRecords[0].compositeKey.peerPartyId)
assertEquals(ALL_VISIBLE, (HashedDistributionList.decrypt(it.receiverRecords[0].distributionList, encryptionService)).peerHashToStatesToRecord.map { it.value }[0]) assertEquals(ALL_VISIBLE, (HashedDistributionList.decrypt(it.receiverRecords[0].distributionList, encryptionService)).peerHashToStatesToRecord.map { it.value }[0])
} }
val resultsAll = transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.ALL) val resultsAll = transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.ALL)
@ -319,7 +319,7 @@ class DBTransactionStorageLedgerRecoveryTests {
fun `test lightweight serialization and deserialization of hashed distribution list payload`() { fun `test lightweight serialization and deserialization of hashed distribution list payload`() {
val hashedDistList = HashedDistributionList( val hashedDistList = HashedDistributionList(
ALL_VISIBLE, ALL_VISIBLE,
mapOf(BOB.name.hashCode().toLong() to NONE, CHARLIE_NAME.hashCode().toLong() to ONLY_RELEVANT), mapOf(SecureHash.sha256(BOB.name.toString()) to NONE, SecureHash.sha256(CHARLIE_NAME.toString()) to ONLY_RELEVANT),
HashedDistributionList.PublicHeader(now()) HashedDistributionList.PublicHeader(now())
) )
val roundtrip = HashedDistributionList.decrypt(hashedDistList.encrypt(encryptionService), encryptionService) val roundtrip = HashedDistributionList.decrypt(hashedDistList.encrypt(encryptionService), encryptionService)