diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt index fefbac09c1..e34b0a4c2b 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt @@ -354,12 +354,11 @@ class FinalityFlowTests : WithFinality { getSenderRecoveryData(stx.id, aliceNode.database).apply { assertEquals(1, this.size) - assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) + assertEquals(StatesToRecord.ALL_VISIBLE, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } getReceiverRecoveryData(stx.id, bobNode.database).apply { - assertEquals(StatesToRecord.ALL_VISIBLE, this?.statesToRecord) - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord) + assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), this?.peersToStatesToRecord) } @@ -387,12 +386,10 @@ class FinalityFlowTests : WithFinality { assertEquals(2, this.size) assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) - assertEquals(StatesToRecord.ONLY_RELEVANT, this[1].statesToRecord) + assertEquals(StatesToRecord.ALL_VISIBLE, this[1].statesToRecord) assertEquals(CHARLIE_NAME.hashCode().toLong(), this[1].peerPartyId) } getReceiverRecoveryData(stx.id, bobNode.database).apply { - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) // note: Charlie assertion here is using the hinted StatesToRecord value passed to it from Alice assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT, @@ -434,8 +431,6 @@ class FinalityFlowTests : WithFinality { assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } getReceiverRecoveryData(stx.id, bobNode.database).apply { - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT), this?.peersToStatesToRecord) } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt index 8548dc0d80..6a064f6bc8 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt @@ -5,6 +5,7 @@ import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.TransactionMetadata import net.corda.core.identity.CordaX500Name import net.corda.core.internal.NamedCacheFactory +import net.corda.core.internal.VisibleForTesting import net.corda.core.node.StatesToRecord import net.corda.core.node.services.vault.Sort import net.corda.core.serialization.CordaSerializable @@ -20,7 +21,7 @@ import java.io.DataInputStream import java.io.DataOutputStream import java.io.Serializable import java.time.Instant -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.AtomicInteger import javax.persistence.Column import javax.persistence.Embeddable import javax.persistence.EmbeddedId @@ -33,20 +34,26 @@ import kotlin.streams.toList class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, cacheFactory: NamedCacheFactory, val clock: CordaClock, - private val cryptoService: CryptoService, + val cryptoService: CryptoService, private val partyInfoCache: PersistentPartyInfoCache) : DBTransactionStorage(database, cacheFactory, clock) { @Embeddable @Immutable data class PersistentKey( - @Column(name = "sequence_number", nullable = false) - var sequenceNumber: Long, + /** PartyId of flow peer **/ + @Column(name = "peer_party_id", nullable = false) + var peerPartyId: Long, @Column(name = "timestamp", nullable = false) - var timestamp: Instant + var timestamp: Instant, + + @Column(name = "timestamp_discriminator", nullable = false) + var timestampDiscriminator: Int + ) : Serializable { - constructor(key: Key) : this(key.sequenceNumber, key.timestamp) + constructor(key: Key) : this(key.partyId, key.timestamp, key.timestampDiscriminator) } + @CordaSerializable @Entity @Table(name = "${NODE_DATABASE_PREFIX}sender_distribution_records") data class DBSenderDistributionRecord( @@ -56,10 +63,6 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, @Column(name = "transaction_id", length = 144, nullable = false) var txId: String, - /** PartyId of flow peer **/ - @Column(name = "receiver_party_id", nullable = false) - val receiverPartyId: Long, - /** states to record: NONE, ALL_VISIBLE, ONLY_RELEVANT */ @Column(name = "states_to_record", nullable = false) var statesToRecord: StatesToRecord @@ -68,12 +71,13 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, fun toSenderDistributionRecord() = SenderDistributionRecord( SecureHash.parse(this.txId), - this.receiverPartyId, + this.compositeKey.peerPartyId, this.statesToRecord, this.compositeKey.timestamp ) } + @CordaSerializable @Entity @Table(name = "${NODE_DATABASE_PREFIX}receiver_distribution_records") data class DBReceiverDistributionRecord( @@ -83,34 +87,23 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, @Column(name = "transaction_id", length = 144, nullable = false) var txId: String, - /** PartyId of flow initiator **/ - @Column(name = "sender_party_id", nullable = true) - val senderPartyId: Long, - /** Encrypted recovery information for sole use by Sender **/ @Lob @Column(name = "distribution_list", nullable = false) - val distributionList: ByteArray, - - /** states to record: NONE, ALL_VISIBLE, ONLY_RELEVANT */ - @Column(name = "receiver_states_to_record", nullable = false) - val receiverStatesToRecord: StatesToRecord + val distributionList: ByteArray ) { - constructor(key: Key, txId: SecureHash, initiatorPartyId: Long, encryptedDistributionList: ByteArray, receiverStatesToRecord: StatesToRecord) : + constructor(key: Key, txId: SecureHash, encryptedDistributionList: ByteArray) : this(PersistentKey(key), txId = txId.toString(), - senderPartyId = initiatorPartyId, - distributionList = encryptedDistributionList, - receiverStatesToRecord = receiverStatesToRecord + distributionList = encryptedDistributionList ) fun toReceiverDistributionRecord(cryptoService: CryptoService): ReceiverDistributionRecord { val hashedDL = HashedDistributionList.deserialize(cryptoService.decrypt(this.distributionList)) return ReceiverDistributionRecord( SecureHash.parse(this.txId), - this.senderPartyId, + this.compositeKey.peerPartyId, hashedDL.peerHashToStatesToRecord, - this.receiverStatesToRecord, hashedDL.senderStatesToRecord, this.compositeKey.timestamp ) @@ -130,23 +123,29 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, val partyName: String ) + data class TimestampKey(val timestamp: Instant, val timestampDiscriminator: Int) + class Key( + val partyId: Long, val timestamp: Instant, - val sequenceNumber: Long = nextSequenceNumber.andIncrement + val timestampDiscriminator: Int = nextDiscriminatorNumber.andIncrement ) { + constructor(key: TimestampKey, partyId: Long): this(partyId = partyId, timestamp = key.timestamp, timestampDiscriminator = key.timestampDiscriminator) companion object { - private val nextSequenceNumber = AtomicLong() + val nextDiscriminatorNumber = AtomicInteger() } } override fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray { + val senderRecordingTimestamp = clock.instant() return database.transaction { - val senderRecordingTimestamp = clock.instant() - metadata.distributionList.peersToStatesToRecord.forEach { (peer, _) -> - val senderDistributionRecord = DBSenderDistributionRecord(PersistentKey(Key(senderRecordingTimestamp)), + // sender distribution records must be unique per txnId and timestamp + val timeDiscriminator = Key.nextDiscriminatorNumber.andIncrement + metadata.distributionList.peersToStatesToRecord.map { (peerCordaX500Name, peerStatesToRecord) -> + val senderDistributionRecord = DBSenderDistributionRecord( + PersistentKey(Key(TimestampKey(senderRecordingTimestamp, timeDiscriminator), partyInfoCache.getPartyIdByCordaX500Name(peerCordaX500Name))), id.toString(), - partyInfoCache.getPartyIdByCordaX500Name(peer), - metadata.distributionList.senderStatesToRecord) + peerStatesToRecord) session.save(senderDistributionRecord) } val hashedPeersToStatesToRecord = metadata.distributionList.peersToStatesToRecord.map { (peer, statesToRecord) -> @@ -156,19 +155,48 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } } + fun createReceiverTransactionRecoverMetadata(txId: SecureHash, + senderPartyId: Long, + senderStatesToRecord: StatesToRecord, + senderRecords: List): List { + val senderRecordsByTimestampKey = senderRecords.groupBy { TimestampKey(it.compositeKey.timestamp, it.compositeKey.timestampDiscriminator) } + return senderRecordsByTimestampKey.map { + val hashedDistributionList = HashedDistributionList( + senderStatesToRecord = senderStatesToRecord, + peerHashToStatesToRecord = senderRecords.map { it.compositeKey.peerPartyId to it.statesToRecord }.toMap(), + senderRecordedTimestamp = it.key.timestamp + ) + DBReceiverDistributionRecord( + compositeKey = PersistentKey(Key(TimestampKey(it.key.timestamp, it.key.timestampDiscriminator), senderPartyId)), + txId = txId.toString(), + distributionList = cryptoService.encrypt(hashedDistributionList.serialize()) + ) + } + } + + fun addSenderTransactionRecoveryMetadata(record: DBSenderDistributionRecord) { + return database.transaction { + session.save(record) + } + } + override fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) { val senderRecordedTimestamp = HashedDistributionList.deserialize(cryptoService.decrypt(encryptedDistributionList)).senderRecordedTimestamp database.transaction { val receiverDistributionRecord = - DBReceiverDistributionRecord(Key(senderRecordedTimestamp), + DBReceiverDistributionRecord(Key(partyInfoCache.getPartyIdByCordaX500Name(sender), senderRecordedTimestamp), id, - partyInfoCache.getPartyIdByCordaX500Name(sender), - encryptedDistributionList, - receiverStatesToRecord) + encryptedDistributionList) session.save(receiverDistributionRecord) } } + fun addReceiverTransactionRecoveryMetadata(record: DBReceiverDistributionRecord) { + return database.transaction { + session.save(record) + } + } + override fun removeUnnotarisedTransaction(id: SecureHash): Boolean { return database.transaction { super.removeUnnotarisedTransaction(id) @@ -187,27 +215,30 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, fun queryDistributionRecords(timeWindow: RecoveryTimeWindow, recordType: DistributionRecordType = DistributionRecordType.ALL, - excludingTxnIds: Set? = null, + excludingTxnIds: Set = emptySet(), orderByTimestamp: Sort.Direction? = null - ): List { + ): DistributionRecords { return when(recordType) { DistributionRecordType.SENDER -> - querySenderDistributionRecords(timeWindow, excludingTxnIds = excludingTxnIds, orderByTimestamp = orderByTimestamp) + DistributionRecords(senderRecords = + querySenderDistributionRecords(timeWindow, excludingTxnIds = excludingTxnIds, orderByTimestamp = orderByTimestamp)) DistributionRecordType.RECEIVER -> - queryReceiverDistributionRecords(timeWindow, excludingTxnIds = excludingTxnIds, orderByTimestamp = orderByTimestamp) + DistributionRecords(receiverRecords = + queryReceiverDistributionRecords(timeWindow, excludingTxnIds = excludingTxnIds, orderByTimestamp = orderByTimestamp)) DistributionRecordType.ALL -> - querySenderDistributionRecords(timeWindow, excludingTxnIds = excludingTxnIds, orderByTimestamp = orderByTimestamp).plus( - queryReceiverDistributionRecords(timeWindow, excludingTxnIds = excludingTxnIds, orderByTimestamp = orderByTimestamp) - ) + DistributionRecords(senderRecords = + querySenderDistributionRecords(timeWindow, excludingTxnIds = excludingTxnIds, orderByTimestamp = orderByTimestamp), + receiverRecords = + queryReceiverDistributionRecords(timeWindow, excludingTxnIds = excludingTxnIds, orderByTimestamp = orderByTimestamp)) } } @Suppress("SpreadOperator") fun querySenderDistributionRecords(timeWindow: RecoveryTimeWindow, peers: Set = emptySet(), - excludingTxnIds: Set? = null, + excludingTxnIds: Set = emptySet(), orderByTimestamp: Sort.Direction? = null - ): List { + ): List { return database.transaction { val criteriaBuilder = session.criteriaBuilder val criteriaQuery = criteriaBuilder.createQuery(DBSenderDistributionRecord::class.java) @@ -216,13 +247,13 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, val compositeKey = txnMetadata.get("compositeKey") predicates.add(criteriaBuilder.greaterThanOrEqualTo(compositeKey.get(PersistentKey::timestamp.name), timeWindow.fromTime)) predicates.add(criteriaBuilder.and(criteriaBuilder.lessThanOrEqualTo(compositeKey.get(PersistentKey::timestamp.name), timeWindow.untilTime))) - excludingTxnIds?.let { excludingTxnIds -> - predicates.add(criteriaBuilder.and(criteriaBuilder.notEqual(txnMetadata.get(DBSenderDistributionRecord::txId.name), - excludingTxnIds.map { it.toString() }))) + if (excludingTxnIds.isNotEmpty()) { + predicates.add(criteriaBuilder.and(criteriaBuilder.not(txnMetadata.get(DBSenderDistributionRecord::txId.name).`in`( + excludingTxnIds.map { it.toString() })))) } if (peers.isNotEmpty()) { val peerPartyIds = peers.map { partyInfoCache.getPartyIdByCordaX500Name(it) } - predicates.add(criteriaBuilder.and(txnMetadata.get(DBSenderDistributionRecord::receiverPartyId.name).`in`(peerPartyIds))) + predicates.add(criteriaBuilder.and(compositeKey.get(PersistentKey::peerPartyId.name).`in`(peerPartyIds))) } criteriaQuery.where(*predicates.toTypedArray()) // optionally order by timestamp @@ -236,16 +267,27 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, criteriaQuery.orderBy(orderCriteria) } val results = session.createQuery(criteriaQuery).stream() - results.map { it.toSenderDistributionRecord() }.toList() + results.toList() + } + } + + fun querySenderDistributionRecordsByTxId(txId: SecureHash): List { + return database.transaction { + val criteriaBuilder = session.criteriaBuilder + val criteriaQuery = criteriaBuilder.createQuery(DBSenderDistributionRecord::class.java) + val txnMetadata = criteriaQuery.from(DBSenderDistributionRecord::class.java) + criteriaQuery.where(criteriaBuilder.equal(txnMetadata.get(DBSenderDistributionRecord::txId.name), txId.toString())) + val results = session.createQuery(criteriaQuery).stream() + results.toList() } } @Suppress("SpreadOperator") fun queryReceiverDistributionRecords(timeWindow: RecoveryTimeWindow, initiators: Set = emptySet(), - excludingTxnIds: Set? = null, + excludingTxnIds: Set = emptySet(), orderByTimestamp: Sort.Direction? = null - ): List { + ): List { return database.transaction { val criteriaBuilder = session.criteriaBuilder val criteriaQuery = criteriaBuilder.createQuery(DBReceiverDistributionRecord::class.java) @@ -254,13 +296,13 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, val compositeKey = txnMetadata.get("compositeKey") predicates.add(criteriaBuilder.greaterThanOrEqualTo(compositeKey.get(PersistentKey::timestamp.name), timeWindow.fromTime)) predicates.add(criteriaBuilder.and(criteriaBuilder.lessThanOrEqualTo(compositeKey.get(PersistentKey::timestamp.name), timeWindow.untilTime))) - excludingTxnIds?.let { excludingTxnIds -> - predicates.add(criteriaBuilder.and(criteriaBuilder.notEqual(txnMetadata.get(DBReceiverDistributionRecord::txId.name), - excludingTxnIds.map { it.toString() }))) + if (excludingTxnIds.isNotEmpty()) { + predicates.add(criteriaBuilder.and(criteriaBuilder.not(txnMetadata.get(DBSenderDistributionRecord::txId.name).`in`( + excludingTxnIds.map { it.toString() })))) } if (initiators.isNotEmpty()) { val initiatorPartyIds = initiators.map { partyInfoCache.getPartyIdByCordaX500Name(it) } - predicates.add(criteriaBuilder.and(txnMetadata.get(DBReceiverDistributionRecord::senderPartyId.name).`in`(initiatorPartyIds))) + predicates.add(criteriaBuilder.and(compositeKey.get(PersistentKey::peerPartyId.name).`in`(initiatorPartyIds))) } criteriaQuery.where(*predicates.toTypedArray()) // optionally order by timestamp @@ -274,13 +316,14 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, criteriaQuery.orderBy(orderCriteria) } val results = session.createQuery(criteriaQuery).stream() - results.map { it.toReceiverDistributionRecord(cryptoService) }.toList() + results.toList() } } } // TO DO: https://r3-cev.atlassian.net/browse/ENT-9876 -private fun CryptoService.decrypt(bytes: ByteArray): ByteArray { +@VisibleForTesting +fun CryptoService.decrypt(bytes: ByteArray): ByteArray { return bytes } @@ -289,6 +332,18 @@ fun CryptoService.encrypt(bytes: ByteArray): ByteArray { return bytes } +@CordaSerializable +class DistributionRecords( + val senderRecords: List = emptyList(), + val receiverRecords: List = emptyList() +) { + init { + assert(senderRecords.isNotEmpty() || receiverRecords.isNotEmpty()) { "Must set senderRecords or receiverRecords or both." } + } + + val size = senderRecords.size + receiverRecords.size +} + @CordaSerializable open class DistributionRecord( open val txId: SecureHash, @@ -310,7 +365,6 @@ data class ReceiverDistributionRecord( val initiatorPartyId: Long, // CordaX500Name hashCode() val peersToStatesToRecord: Map, // CordaX500Name hashCode() -> StatesToRecord override val statesToRecord: StatesToRecord, - val senderStatesToRecord: StatesToRecord, override val timestamp: Instant ) : DistributionRecord(txId, statesToRecord, timestamp) diff --git a/node/src/main/resources/migration/node-core.changelog-v25.xml b/node/src/main/resources/migration/node-core.changelog-v25.xml index 0a81edd870..a199a65df8 100644 --- a/node/src/main/resources/migration/node-core.changelog-v25.xml +++ b/node/src/main/resources/migration/node-core.changelog-v25.xml @@ -12,16 +12,16 @@ - + - + - + @@ -31,54 +31,35 @@ - - - - - - - - - - + - + - + - - - - - - - - - - - - @@ -94,15 +75,15 @@ - - + - - + diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt index 2bc406601c..39836dfdd1 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt @@ -109,6 +109,21 @@ class DBTransactionStorageLedgerRecoveryTests { val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val results = transactionRecovery.querySenderDistributionRecords(timeWindow, excludingTxnIds = setOf(transaction1.id)) assertEquals(1, results.size) + assertEquals(transaction2.id.toString(), results[0].txId) + } + + @Test(timeout = 300_000) + fun `query local ledger for transactions within timeWindow and for given peers`() { + val transaction1 = newTransaction() + transactionRecovery.addUnnotarisedTransaction(transaction1) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) + val transaction2 = newTransaction() + transactionRecovery.addUnnotarisedTransaction(transaction2) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) + val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) + val results = transactionRecovery.querySenderDistributionRecords(timeWindow, peers = setOf(CHARLIE_NAME)) + assertEquals(1, results.size) + assertEquals(transaction2.id.toString(), results[0].txId) } @Test(timeout = 300_000) @@ -125,13 +140,13 @@ class DBTransactionStorageLedgerRecoveryTests { val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.SENDER).let { assertEquals(1, it.size) - assertEquals(BOB_NAME.hashCode().toLong(), (it[0] as SenderDistributionRecord).peerPartyId) - assertEquals(ALL_VISIBLE, (it[0] as SenderDistributionRecord).statesToRecord) + assertEquals(BOB_NAME.hashCode().toLong(), it.senderRecords[0].compositeKey.peerPartyId) + assertEquals(ALL_VISIBLE, it.senderRecords[0].statesToRecord) } transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.RECEIVER).let { assertEquals(1, it.size) - assertEquals(BOB_NAME.hashCode().toLong(), (it[0] as ReceiverDistributionRecord).initiatorPartyId) - assertEquals(ALL_VISIBLE, (it[0] as ReceiverDistributionRecord).statesToRecord) + assertEquals(BOB_NAME.hashCode().toLong(), it.receiverRecords[0].compositeKey.peerPartyId) + assertEquals(ALL_VISIBLE, (transactionRecovery.decrypt(it.receiverRecords[0].distributionList).peerHashToStatesToRecord.map { it.value }[0])) } val resultsAll = transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.ALL) assertEquals(2, resultsAll.size) @@ -192,9 +207,9 @@ class DBTransactionStorageLedgerRecoveryTests { val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(ALICE_NAME)).let { assertEquals(3, it.size) - assertEquals(it[0].statesToRecord, ALL_VISIBLE) - assertEquals(it[1].statesToRecord, ONLY_RELEVANT) - assertEquals(it[2].statesToRecord, NONE) + assertEquals(transactionRecovery.decrypt(it[0].distributionList).peerHashToStatesToRecord.map { it.value }[0], ALL_VISIBLE) + assertEquals(transactionRecovery.decrypt(it[1].distributionList).peerHashToStatesToRecord.map { it.value }[0], ONLY_RELEVANT) + assertEquals(transactionRecovery.decrypt(it[2].distributionList).peerHashToStatesToRecord.map { it.value }[0], NONE) } assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(BOB_NAME)).size) assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(CHARLIE_NAME)).size) @@ -229,8 +244,8 @@ class DBTransactionStorageLedgerRecoveryTests { DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE)).toWire()) assertEquals(IN_FLIGHT, readTransactionFromDB(receiverTransaction.id).status) readReceiverDistributionRecordFromDB(receiverTransaction.id).let { - assertEquals(ALL_VISIBLE, it.statesToRecord) - assertEquals(ONLY_RELEVANT, it.senderStatesToRecord) + assertEquals(ONLY_RELEVANT, it.statesToRecord) + assertEquals(ALL_VISIBLE, it.peersToStatesToRecord.map { it.value }[0]) assertEquals(ALICE_NAME, partyInfoCache.getCordaX500NameByPartyId(it.initiatorPartyId)) assertEquals(setOf(BOB_NAME), it.peersToStatesToRecord.map { (peer, _) -> partyInfoCache.getCordaX500NameByPartyId(peer) }.toSet() ) } @@ -245,7 +260,7 @@ class DBTransactionStorageLedgerRecoveryTests { assertEquals(VERIFIED, readTransactionFromDB(transaction.id).status) readSenderDistributionRecordFromDB(transaction.id).apply { assertEquals(1, this.size) - assertEquals(ONLY_RELEVANT, this[0].statesToRecord) + assertEquals(ALL_VISIBLE, this[0].statesToRecord) } } @@ -377,3 +392,7 @@ class DBTransactionStorageLedgerRecoveryTests { return cryptoService.encrypt(hashedDistributionList.serialize()) } } + +internal fun DBTransactionStorageLedgerRecovery.decrypt(distributionList: ByteArray): HashedDistributionList { + return HashedDistributionList.deserialize(this.cryptoService.decrypt(distributionList)) +}