ENT-10110 Back-port changes from ENT + additional clean-up (#7532)

This commit is contained in:
Jose Coll
2023-10-17 07:03:49 +01:00
committed by GitHub
parent 1981334921
commit 6a2bad8077
6 changed files with 223 additions and 157 deletions

View File

@ -2580,6 +2580,16 @@ public static final class net.corda.core.flows.DistributionList$SenderDistributi
public int hashCode() public int hashCode()
@NotNull @NotNull
public String toString() public String toString()
@CordaSerializable
public abstract class net.corda.core.flows.DistributionRecord extends java.lang.Object implements net.corda.core.contracts.NamedByHash
public <init>()
@NotNull
public abstract net.corda.core.crypto.SecureHash getPeerPartyId()
@NotNull
public abstract java.time.Instant getTimestamp()
public abstract int getTimestampDiscriminator()
@NotNull
public abstract net.corda.core.crypto.SecureHash getTxId()
## ##
@InitiatingFlow @InitiatingFlow
public final class net.corda.core.flows.FinalityFlow extends net.corda.core.flows.FlowLogic public final class net.corda.core.flows.FinalityFlow extends net.corda.core.flows.FlowLogic

View File

@ -21,7 +21,9 @@ import net.corda.core.flows.NotaryException
import net.corda.core.flows.NotarySigCheck import net.corda.core.flows.NotarySigCheck
import net.corda.core.flows.ReceiveFinalityFlow import net.corda.core.flows.ReceiveFinalityFlow
import net.corda.core.flows.ReceiveTransactionFlow import net.corda.core.flows.ReceiveTransactionFlow
import net.corda.core.flows.ReceiverDistributionRecord
import net.corda.core.flows.SendTransactionFlow import net.corda.core.flows.SendTransactionFlow
import net.corda.core.flows.SenderDistributionRecord
import net.corda.core.flows.StartableByRPC import net.corda.core.flows.StartableByRPC
import net.corda.core.flows.TransactionStatus import net.corda.core.flows.TransactionStatus
import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.flows.UnexpectedFlowEndException
@ -53,8 +55,6 @@ import net.corda.node.services.persistence.DBTransactionStorage
import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord
import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord
import net.corda.node.services.persistence.HashedDistributionList import net.corda.node.services.persistence.HashedDistributionList
import net.corda.node.services.persistence.ReceiverDistributionRecord
import net.corda.node.services.persistence.SenderDistributionRecord
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
@ -361,7 +361,7 @@ class FinalityFlowTests : WithFinality {
assertNotNull(this) assertNotNull(this)
val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService)
assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord)
assertEquals(SecureHash.sha256(aliceNode.info.singleIdentity().name.toString()), this.initiatorPartyId) assertEquals(SecureHash.sha256(aliceNode.info.singleIdentity().name.toString()), this.peerPartyId)
assertEquals(mapOf<SecureHash, StatesToRecord>(SecureHash.sha256(BOB_NAME.toString()) to StatesToRecord.ALL_VISIBLE), hashedDL.peerHashToStatesToRecord) assertEquals(mapOf<SecureHash, StatesToRecord>(SecureHash.sha256(BOB_NAME.toString()) to StatesToRecord.ALL_VISIBLE), hashedDL.peerHashToStatesToRecord)
} }
validateSenderAndReceiverTimestamps(sdrs, rdr!!) validateSenderAndReceiverTimestamps(sdrs, rdr!!)
@ -396,7 +396,7 @@ class FinalityFlowTests : WithFinality {
assertNotNull(this) assertNotNull(this)
val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService)
assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord)
assertEquals(SecureHash.sha256(aliceNode.info.singleIdentity().name.toString()), this.initiatorPartyId) assertEquals(SecureHash.sha256(aliceNode.info.singleIdentity().name.toString()), this.peerPartyId)
// note: Charlie assertion here is using the hinted StatesToRecord value passed to it from Alice // note: Charlie assertion here is using the hinted StatesToRecord value passed to it from Alice
assertEquals(mapOf<SecureHash, StatesToRecord>( assertEquals(mapOf<SecureHash, StatesToRecord>(
SecureHash.sha256(BOB_NAME.toString()) to StatesToRecord.ONLY_RELEVANT, SecureHash.sha256(BOB_NAME.toString()) to StatesToRecord.ONLY_RELEVANT,
@ -458,7 +458,7 @@ class FinalityFlowTests : WithFinality {
assertNotNull(this) assertNotNull(this)
val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService)
assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord)
assertEquals(SecureHash.sha256(aliceNode.info.singleIdentity().name.toString()), this.initiatorPartyId) assertEquals(SecureHash.sha256(aliceNode.info.singleIdentity().name.toString()), this.peerPartyId)
assertEquals(mapOf<SecureHash, StatesToRecord>(SecureHash.sha256(BOB_NAME.toString()) to StatesToRecord.ONLY_RELEVANT), hashedDL.peerHashToStatesToRecord) assertEquals(mapOf<SecureHash, StatesToRecord>(SecureHash.sha256(BOB_NAME.toString()) to StatesToRecord.ONLY_RELEVANT), hashedDL.peerHashToStatesToRecord)
} }
validateSenderAndReceiverTimestamps(sdr, rdr!!) validateSenderAndReceiverTimestamps(sdr, rdr!!)

View File

@ -1,78 +0,0 @@
package net.corda.core.flows
import net.corda.core.identity.CordaX500Name
import net.corda.core.node.StatesToRecord
import net.corda.core.serialization.CordaSerializable
import java.time.Instant
/**
* Flow data object representing key information required for recovery.
*/
@CordaSerializable
data class FlowTransactionInfo(
val stateMachineRunId: StateMachineRunId,
val txId: String,
val status: TransactionStatus,
val timestamp: Instant,
val metadata: TransactionMetadata?
) {
fun isInitiator(myCordaX500Name: CordaX500Name) =
this.metadata?.initiator == myCordaX500Name
}
@CordaSerializable
data class TransactionMetadata(
val initiator: CordaX500Name,
val distributionList: DistributionList
)
@CordaSerializable
sealed class DistributionList {
@CordaSerializable
data class SenderDistributionList(
val senderStatesToRecord: StatesToRecord,
val peersToStatesToRecord: Map<CordaX500Name, StatesToRecord>
) : DistributionList()
@CordaSerializable
data class ReceiverDistributionList(
val opaqueData: ByteArray, // decipherable only by sender
val receiverStatesToRecord: StatesToRecord // inferred or actual
) : DistributionList()
}
@CordaSerializable
enum class TransactionStatus {
UNVERIFIED,
VERIFIED,
IN_FLIGHT;
}
@CordaSerializable
data class RecoveryTimeWindow(val fromTime: Instant, val untilTime: Instant = Instant.now()) {
init {
if (untilTime < fromTime) {
throw IllegalArgumentException("$fromTime must be before $untilTime")
}
}
companion object {
@JvmStatic
fun between(fromTime: Instant, untilTime: Instant): RecoveryTimeWindow {
return RecoveryTimeWindow(fromTime, untilTime)
}
@JvmStatic
fun fromOnly(fromTime: Instant): RecoveryTimeWindow {
return RecoveryTimeWindow(fromTime = fromTime)
}
@JvmStatic
fun untilOnly(untilTime: Instant): RecoveryTimeWindow {
return RecoveryTimeWindow(fromTime = Instant.EPOCH, untilTime = untilTime)
}
}
}

View File

@ -0,0 +1,150 @@
package net.corda.core.flows
import net.corda.core.contracts.NamedByHash
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.core.node.StatesToRecord
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.OpaqueBytes
import java.time.Instant
import java.time.temporal.ChronoUnit
/**
* Transaction recovery type information.
*/
@CordaSerializable
data class FlowTransactionInfo(
val stateMachineRunId: StateMachineRunId,
val txId: String,
val status: TransactionStatus,
val timestamp: Instant,
val metadata: TransactionMetadata?
) {
fun isInitiator(myCordaX500Name: CordaX500Name) =
this.metadata?.initiator == myCordaX500Name
}
@CordaSerializable
data class TransactionMetadata(
val initiator: CordaX500Name,
val distributionList: DistributionList
)
@CordaSerializable
sealed class DistributionList {
@CordaSerializable
data class SenderDistributionList(
val senderStatesToRecord: StatesToRecord,
val peersToStatesToRecord: Map<CordaX500Name, StatesToRecord>
) : DistributionList()
@CordaSerializable
data class ReceiverDistributionList(
val opaqueData: ByteArray, // decipherable only by sender
val receiverStatesToRecord: StatesToRecord // inferred or actual
) : DistributionList()
}
@CordaSerializable
enum class TransactionStatus {
UNVERIFIED,
VERIFIED,
IN_FLIGHT;
}
@CordaSerializable
class DistributionRecords(
val senderRecords: List<SenderDistributionRecord> = emptyList(),
val receiverRecords: List<ReceiverDistributionRecord> = emptyList()
) {
val size = senderRecords.size + receiverRecords.size
}
@CordaSerializable
abstract class DistributionRecord : NamedByHash {
abstract val txId: SecureHash
abstract val peerPartyId: SecureHash
abstract val timestamp: Instant
abstract val timestampDiscriminator: Int
}
@CordaSerializable
data class SenderDistributionRecord(
override val txId: SecureHash,
override val peerPartyId: SecureHash,
override val timestamp: Instant,
override val timestampDiscriminator: Int,
val senderStatesToRecord: StatesToRecord,
val receiverStatesToRecord: StatesToRecord
) : DistributionRecord() {
override val id: SecureHash
get() = this.txId
}
@CordaSerializable
data class ReceiverDistributionRecord(
override val txId: SecureHash,
override val peerPartyId: SecureHash,
override val timestamp: Instant,
override val timestampDiscriminator: Int,
val encryptedDistributionList: OpaqueBytes,
val receiverStatesToRecord: StatesToRecord
) : DistributionRecord() {
override val id: SecureHash
get() = this.txId
}
@CordaSerializable
enum class DistributionRecordType {
SENDER, RECEIVER, ALL
}
@CordaSerializable
data class DistributionRecordKey(
val txnId: SecureHash,
val timestamp: Instant,
val timestampDiscriminator: Int
)
@CordaSerializable
data class RecoveryTimeWindow(val fromTime: Instant, val untilTime: Instant = Instant.now()) {
init {
if (untilTime < fromTime) {
throw IllegalArgumentException("$fromTime must be before $untilTime")
}
}
companion object {
@JvmStatic
fun between(fromTime: Instant, untilTime: Instant): RecoveryTimeWindow {
return RecoveryTimeWindow(fromTime, untilTime)
}
@JvmStatic
fun fromOnly(fromTime: Instant): RecoveryTimeWindow {
return RecoveryTimeWindow(fromTime = fromTime)
}
@JvmStatic
fun untilOnly(untilTime: Instant): RecoveryTimeWindow {
return RecoveryTimeWindow(fromTime = Instant.EPOCH, untilTime = untilTime)
}
}
}
@CordaSerializable
data class ComparableRecoveryTimeWindow(
val fromTime: Instant,
val fromTimestampDiscriminator: Int,
val untilTime: Instant,
val untilTimestampDiscriminator: Int
) {
companion object {
fun from(timeWindow: RecoveryTimeWindow) =
ComparableRecoveryTimeWindow(
timeWindow.fromTime.truncatedTo(ChronoUnit.SECONDS), 0,
timeWindow.untilTime.truncatedTo(ChronoUnit.SECONDS), Int.MAX_VALUE)
}
}

View File

@ -3,14 +3,17 @@ package net.corda.node.services.persistence
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.DistributionList.ReceiverDistributionList import net.corda.core.flows.DistributionList.ReceiverDistributionList
import net.corda.core.flows.DistributionList.SenderDistributionList import net.corda.core.flows.DistributionList.SenderDistributionList
import net.corda.core.flows.DistributionRecordKey
import net.corda.core.flows.DistributionRecordType
import net.corda.core.flows.DistributionRecords
import net.corda.core.flows.ReceiverDistributionRecord
import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.RecoveryTimeWindow
import net.corda.core.flows.SenderDistributionRecord
import net.corda.core.flows.TransactionMetadata import net.corda.core.flows.TransactionMetadata
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.NamedCacheFactory import net.corda.core.internal.NamedCacheFactory
import net.corda.core.internal.VisibleForTesting
import net.corda.core.node.StatesToRecord import net.corda.core.node.StatesToRecord
import net.corda.core.node.services.vault.Sort import net.corda.core.node.services.vault.Sort
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.node.CordaClock import net.corda.node.CordaClock
import net.corda.node.services.EncryptionService import net.corda.node.services.EncryptionService
@ -20,6 +23,7 @@ import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import org.hibernate.annotations.Immutable import org.hibernate.annotations.Immutable
import java.io.Serializable import java.io.Serializable
import java.time.Instant import java.time.Instant
import java.time.temporal.ChronoUnit
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import javax.persistence.Column import javax.persistence.Column
import javax.persistence.Embeddable import javax.persistence.Embeddable
@ -54,7 +58,6 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
constructor(key: Key) : this(key.txId.toString(), key.partyId.toString(), key.timestamp, key.timestampDiscriminator) constructor(key: Key) : this(key.txId.toString(), key.partyId.toString(), key.timestamp, key.timestampDiscriminator)
} }
@CordaSerializable
@Entity @Entity
@Table(name = "${NODE_DATABASE_PREFIX}sender_distr_recs") @Table(name = "${NODE_DATABASE_PREFIX}sender_distr_recs")
data class DBSenderDistributionRecord( data class DBSenderDistributionRecord(
@ -69,17 +72,22 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
@Column(name = "receiver_states_to_record", nullable = false) @Column(name = "receiver_states_to_record", nullable = false)
var receiverStatesToRecord: StatesToRecord var receiverStatesToRecord: StatesToRecord
) { ) {
fun key() = DistributionRecordKey(
SecureHash.parse(this.compositeKey.txId),
this.compositeKey.timestamp,
this.compositeKey.timestampDiscriminator)
fun toSenderDistributionRecord() = fun toSenderDistributionRecord() =
SenderDistributionRecord( SenderDistributionRecord(
SecureHash.parse(this.compositeKey.txId), SecureHash.parse(this.compositeKey.txId),
SecureHash.parse(this.compositeKey.peerPartyId), SecureHash.parse(this.compositeKey.peerPartyId),
this.compositeKey.timestamp,
this.compositeKey.timestampDiscriminator,
this.senderStatesToRecord, this.senderStatesToRecord,
this.receiverStatesToRecord, this.receiverStatesToRecord
this.compositeKey.timestamp
) )
} }
@CordaSerializable
@Entity @Entity
@Table(name = "${NODE_DATABASE_PREFIX}receiver_distr_recs") @Table(name = "${NODE_DATABASE_PREFIX}receiver_distr_recs")
data class DBReceiverDistributionRecord( data class DBReceiverDistributionRecord(
@ -100,13 +108,20 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
distributionList = encryptedDistributionList, distributionList = encryptedDistributionList,
receiverStatesToRecord = receiverStatesToRecord receiverStatesToRecord = receiverStatesToRecord
) )
@VisibleForTesting
fun key() = DistributionRecordKey(
SecureHash.parse(this.compositeKey.txId),
this.compositeKey.timestamp,
this.compositeKey.timestampDiscriminator)
fun toReceiverDistributionRecord(): ReceiverDistributionRecord { fun toReceiverDistributionRecord(): ReceiverDistributionRecord {
return ReceiverDistributionRecord( return ReceiverDistributionRecord(
SecureHash.parse(this.compositeKey.txId), SecureHash.parse(this.compositeKey.txId),
SecureHash.parse(this.compositeKey.peerPartyId), SecureHash.parse(this.compositeKey.peerPartyId),
this.compositeKey.timestamp,
this.compositeKey.timestampDiscriminator,
OpaqueBytes(this.distributionList), OpaqueBytes(this.distributionList),
this.compositeKey.timestamp this.receiverStatesToRecord
) )
} }
} }
@ -137,7 +152,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray { override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray {
return database.transaction { return database.transaction {
val senderRecordingTimestamp = clock.instant() val senderRecordingTimestamp = clock.instant().truncatedTo(ChronoUnit.SECONDS)
val timeDiscriminator = Key.nextDiscriminatorNumber.andIncrement val timeDiscriminator = Key.nextDiscriminatorNumber.andIncrement
val distributionList = metadata.distributionList as? SenderDistributionList ?: throw IllegalStateException("Expecting SenderDistributionList") val distributionList = metadata.distributionList as? SenderDistributionList ?: throw IllegalStateException("Expecting SenderDistributionList")
distributionList.peersToStatesToRecord.map { (peerCordaX500Name, peerStatesToRecord) -> distributionList.peersToStatesToRecord.map { (peerCordaX500Name, peerStatesToRecord) ->
@ -174,7 +189,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
distributionList.opaqueData, distributionList.opaqueData,
distributionList.receiverStatesToRecord distributionList.receiverStatesToRecord
) )
session.save(receiverDistributionRecord) session.saveOrUpdate(receiverDistributionRecord)
} }
} }
else -> throw IllegalStateException("Expecting ReceiverDistributionList") else -> throw IllegalStateException("Expecting ReceiverDistributionList")
@ -224,7 +239,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
peers: Set<CordaX500Name> = emptySet(), peers: Set<CordaX500Name> = emptySet(),
excludingTxnIds: Set<SecureHash> = emptySet(), excludingTxnIds: Set<SecureHash> = emptySet(),
orderByTimestamp: Sort.Direction? = null orderByTimestamp: Sort.Direction? = null
): List<DBSenderDistributionRecord> { ): List<SenderDistributionRecord> {
return database.transaction { return database.transaction {
val criteriaBuilder = session.criteriaBuilder val criteriaBuilder = session.criteriaBuilder
val criteriaQuery = criteriaBuilder.createQuery(DBSenderDistributionRecord::class.java) val criteriaQuery = criteriaBuilder.createQuery(DBSenderDistributionRecord::class.java)
@ -253,7 +268,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
criteriaQuery.orderBy(orderCriteria) criteriaQuery.orderBy(orderCriteria)
} }
session.createQuery(criteriaQuery).resultList session.createQuery(criteriaQuery).resultList
} }.map { it.toSenderDistributionRecord() }
} }
@Suppress("SpreadOperator") @Suppress("SpreadOperator")
@ -261,7 +276,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
initiators: Set<CordaX500Name> = emptySet(), initiators: Set<CordaX500Name> = emptySet(),
excludingTxnIds: Set<SecureHash> = emptySet(), excludingTxnIds: Set<SecureHash> = emptySet(),
orderByTimestamp: Sort.Direction? = null orderByTimestamp: Sort.Direction? = null
): List<DBReceiverDistributionRecord> { ): List<ReceiverDistributionRecord> {
return database.transaction { return database.transaction {
val criteriaBuilder = session.criteriaBuilder val criteriaBuilder = session.criteriaBuilder
val criteriaQuery = criteriaBuilder.createQuery(DBReceiverDistributionRecord::class.java) val criteriaQuery = criteriaBuilder.createQuery(DBReceiverDistributionRecord::class.java)
@ -277,7 +292,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
} }
if (initiators.isNotEmpty()) { if (initiators.isNotEmpty()) {
val initiatorPartyIds = initiators.map { partyInfoCache.getPartyIdByCordaX500Name(it).toString() } val initiatorPartyIds = initiators.map { partyInfoCache.getPartyIdByCordaX500Name(it).toString() }
predicates.add(criteriaBuilder.and(compositeKey.get<Long>(PersistentKey::peerPartyId.name).`in`(initiatorPartyIds))) predicates.add(criteriaBuilder.and(compositeKey.get<String>(PersistentKey::peerPartyId.name).`in`(initiatorPartyIds)))
} }
criteriaQuery.where(*predicates.toTypedArray()) criteriaQuery.where(*predicates.toTypedArray())
// optionally order by timestamp // optionally order by timestamp
@ -290,7 +305,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
criteriaQuery.orderBy(orderCriteria) criteriaQuery.orderBy(orderCriteria)
} }
session.createQuery(criteriaQuery).resultList session.createQuery(criteriaQuery).resultList
} }.map { it.toReceiverDistributionRecord() }
} }
fun decryptHashedDistributionList(encryptedBytes: ByteArray): HashedDistributionList { fun decryptHashedDistributionList(encryptedBytes: ByteArray): HashedDistributionList {
@ -298,43 +313,3 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
} }
} }
@CordaSerializable
class DistributionRecords(
val senderRecords: List<DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord> = emptyList(),
val receiverRecords: List<DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord> = emptyList()
) {
init {
require(senderRecords.isNotEmpty() || receiverRecords.isNotEmpty()) { "Must set senderRecords or receiverRecords or both." }
}
val size = senderRecords.size + receiverRecords.size
}
@CordaSerializable
abstract class DistributionRecord {
abstract val txId: SecureHash
abstract val timestamp: Instant
}
@CordaSerializable
data class SenderDistributionRecord(
override val txId: SecureHash,
val peerPartyId: SecureHash, // CordaX500Name hashCode()
val senderStatesToRecord: StatesToRecord,
val receiverStatesToRecord: StatesToRecord,
override val timestamp: Instant
) : DistributionRecord()
@CordaSerializable
data class ReceiverDistributionRecord(
override val txId: SecureHash,
val initiatorPartyId: SecureHash, // CordaX500Name hashCode()
val encryptedDistributionList: OpaqueBytes,
override val timestamp: Instant
) : DistributionRecord()
@CordaSerializable
enum class DistributionRecordType {
SENDER, RECEIVER, ALL
}

View File

@ -8,7 +8,10 @@ import net.corda.core.crypto.SignatureMetadata
import net.corda.core.crypto.sign import net.corda.core.crypto.sign
import net.corda.core.flows.DistributionList.ReceiverDistributionList import net.corda.core.flows.DistributionList.ReceiverDistributionList
import net.corda.core.flows.DistributionList.SenderDistributionList import net.corda.core.flows.DistributionList.SenderDistributionList
import net.corda.core.flows.DistributionRecordType
import net.corda.core.flows.ReceiverDistributionRecord
import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.RecoveryTimeWindow
import net.corda.core.flows.SenderDistributionRecord
import net.corda.core.flows.TransactionMetadata import net.corda.core.flows.TransactionMetadata
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.StatesToRecord.ALL_VISIBLE import net.corda.core.node.StatesToRecord.ALL_VISIBLE
@ -18,7 +21,6 @@ import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.CordaClock import net.corda.node.CordaClock
import net.corda.node.SimpleClock
import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.identity.InMemoryIdentityService
import net.corda.node.services.network.PersistentNetworkMapCache import net.corda.node.services.network.PersistentNetworkMapCache
import net.corda.node.services.network.PersistentPartyInfoCache import net.corda.node.services.network.PersistentPartyInfoCache
@ -40,6 +42,7 @@ import net.corda.testing.internal.TestingNamedCacheFactory
import net.corda.testing.internal.configureDatabase import net.corda.testing.internal.configureDatabase
import net.corda.testing.internal.createWireTransaction import net.corda.testing.internal.createWireTransaction
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.node.TestClock
import net.corda.testing.node.internal.MockEncryptionService import net.corda.testing.node.internal.MockEncryptionService
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
@ -48,7 +51,8 @@ import org.junit.Rule
import org.junit.Test import org.junit.Test
import java.security.KeyPair import java.security.KeyPair
import java.time.Clock import java.time.Clock
import java.time.Instant.now import java.time.Duration
import java.time.Instant
import java.time.temporal.ChronoUnit import java.time.temporal.ChronoUnit
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith
@ -84,9 +88,13 @@ class DBTransactionStorageLedgerRecoveryTests {
database.close() database.close()
} }
fun now(): Instant {
return transactionRecovery.clock.instant()
}
@Test(timeout = 300_000) @Test(timeout = 300_000)
fun `query local ledger for transactions with recovery peers within time window`() { fun `query local ledger for transactions with recovery peers within time window`() {
val beforeFirstTxn = now() val beforeFirstTxn = now().truncatedTo(ChronoUnit.SECONDS)
val txn = newTransaction() val txn = newTransaction()
transactionRecovery.addUnnotarisedTransaction(txn) transactionRecovery.addUnnotarisedTransaction(txn)
transactionRecovery.addSenderTransactionRecoveryMetadata(txn.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) transactionRecovery.addSenderTransactionRecoveryMetadata(txn.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT))))
@ -94,13 +102,14 @@ class DBTransactionStorageLedgerRecoveryTests {
untilTime = beforeFirstTxn.plus(1, ChronoUnit.MINUTES)) untilTime = beforeFirstTxn.plus(1, ChronoUnit.MINUTES))
val results = transactionRecovery.querySenderDistributionRecords(timeWindow) val results = transactionRecovery.querySenderDistributionRecords(timeWindow)
assertEquals(1, results.size) assertEquals(1, results.size)
(transactionRecovery.clock as TestClock).advanceBy(Duration.ofSeconds(1))
val afterFirstTxn = now() val afterFirstTxn = now().truncatedTo(ChronoUnit.SECONDS)
val txn2 = newTransaction() val txn2 = newTransaction()
transactionRecovery.addUnnotarisedTransaction(txn2) transactionRecovery.addUnnotarisedTransaction(txn2)
transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT))))
assertEquals(2, transactionRecovery.querySenderDistributionRecords(timeWindow).size) assertEquals(2, transactionRecovery.querySenderDistributionRecords(timeWindow).size)
assertEquals(1, transactionRecovery.querySenderDistributionRecords(RecoveryTimeWindow(fromTime = afterFirstTxn)).size) assertEquals(1, transactionRecovery.querySenderDistributionRecords(RecoveryTimeWindow(fromTime = afterFirstTxn,
untilTime = afterFirstTxn.plus(1, ChronoUnit.MINUTES))).size)
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -114,7 +123,7 @@ class DBTransactionStorageLedgerRecoveryTests {
val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS))
val results = transactionRecovery.querySenderDistributionRecords(timeWindow, excludingTxnIds = setOf(transaction1.id)) val results = transactionRecovery.querySenderDistributionRecords(timeWindow, excludingTxnIds = setOf(transaction1.id))
assertEquals(1, results.size) assertEquals(1, results.size)
assertEquals(transaction2.id.toString(), results[0].compositeKey.txId) assertEquals(transaction2.id, results[0].txId)
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -128,7 +137,7 @@ class DBTransactionStorageLedgerRecoveryTests {
val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS))
val results = transactionRecovery.querySenderDistributionRecords(timeWindow, peers = setOf(CHARLIE_NAME)) val results = transactionRecovery.querySenderDistributionRecords(timeWindow, peers = setOf(CHARLIE_NAME))
assertEquals(1, results.size) assertEquals(1, results.size)
assertEquals(transaction2.id.toString(), results[0].compositeKey.txId) assertEquals(transaction2.id, results[0].txId)
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -147,13 +156,13 @@ class DBTransactionStorageLedgerRecoveryTests {
val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS))
transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.SENDER).let { transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.SENDER).let {
assertEquals(2, it.size) assertEquals(2, it.size)
assertEquals(SecureHash.sha256(BOB_NAME.toString()).toString(), it.senderRecords[0].compositeKey.peerPartyId) assertEquals(SecureHash.sha256(BOB_NAME.toString()), it.senderRecords[0].peerPartyId)
assertEquals(ALL_VISIBLE, it.senderRecords[0].senderStatesToRecord) assertEquals(ALL_VISIBLE, it.senderRecords[0].senderStatesToRecord)
} }
transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.RECEIVER).let { transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.RECEIVER).let {
assertEquals(1, it.size) assertEquals(1, it.size)
assertEquals(SecureHash.sha256(BOB_NAME.toString()).toString(), it.receiverRecords[0].compositeKey.peerPartyId) assertEquals(SecureHash.sha256(BOB_NAME.toString()), it.receiverRecords[0].peerPartyId)
assertEquals(ALL_VISIBLE, (HashedDistributionList.decrypt(it.receiverRecords[0].distributionList, encryptionService)).peerHashToStatesToRecord.map { it.value }[0]) assertEquals(ALL_VISIBLE, (HashedDistributionList.decrypt(it.receiverRecords[0].encryptedDistributionList.bytes, encryptionService)).peerHashToStatesToRecord.map { it.value }[0])
} }
val resultsAll = transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.ALL) val resultsAll = transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.ALL)
assertEquals(3, resultsAll.size) assertEquals(3, resultsAll.size)
@ -224,9 +233,9 @@ class DBTransactionStorageLedgerRecoveryTests {
val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS))
transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(ALICE_NAME)).let { transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(ALICE_NAME)).let {
assertEquals(3, it.size) assertEquals(3, it.size)
assertEquals(HashedDistributionList.decrypt(it[0].distributionList, encryptionService).peerHashToStatesToRecord.map { it.value }[0], ALL_VISIBLE) assertEquals(HashedDistributionList.decrypt(it[0].encryptedDistributionList.bytes, encryptionService).peerHashToStatesToRecord.map { it.value }[0], ALL_VISIBLE)
assertEquals(HashedDistributionList.decrypt(it[1].distributionList, encryptionService).peerHashToStatesToRecord.map { it.value }[0], ONLY_RELEVANT) assertEquals(HashedDistributionList.decrypt(it[1].encryptedDistributionList.bytes, encryptionService).peerHashToStatesToRecord.map { it.value }[0], ONLY_RELEVANT)
assertEquals(HashedDistributionList.decrypt(it[2].distributionList, encryptionService).peerHashToStatesToRecord.map { it.value }[0], NONE) assertEquals(HashedDistributionList.decrypt(it[2].encryptedDistributionList.bytes, encryptionService).peerHashToStatesToRecord.map { it.value }[0], NONE)
} }
assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(BOB_NAME)).size) assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(BOB_NAME)).size)
assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(CHARLIE_NAME)).size) assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(CHARLIE_NAME)).size)
@ -266,7 +275,7 @@ class DBTransactionStorageLedgerRecoveryTests {
val distList = transactionRecovery.decryptHashedDistributionList(record.encryptedDistributionList.bytes) val distList = transactionRecovery.decryptHashedDistributionList(record.encryptedDistributionList.bytes)
assertEquals(ONLY_RELEVANT, distList.senderStatesToRecord) assertEquals(ONLY_RELEVANT, distList.senderStatesToRecord)
assertEquals(ALL_VISIBLE, distList.peerHashToStatesToRecord.values.first()) assertEquals(ALL_VISIBLE, distList.peerHashToStatesToRecord.values.first())
assertEquals(ALICE_NAME, partyInfoCache.getCordaX500NameByPartyId(record.initiatorPartyId)) assertEquals(ALICE_NAME, partyInfoCache.getCordaX500NameByPartyId(record.peerPartyId))
assertEquals(setOf(BOB_NAME), distList.peerHashToStatesToRecord.map { (peer) -> partyInfoCache.getCordaX500NameByPartyId(peer) }.toSet() ) assertEquals(setOf(BOB_NAME), distList.peerHashToStatesToRecord.map { (peer) -> partyInfoCache.getCordaX500NameByPartyId(peer) }.toSet() )
} }
} }
@ -364,7 +373,7 @@ class DBTransactionStorageLedgerRecoveryTests {
return fromDb[0].toReceiverDistributionRecord() return fromDb[0].toReceiverDistributionRecord()
} }
private fun newTransactionRecovery(cacheSizeBytesOverride: Long? = null, clock: CordaClock = SimpleClock(Clock.systemUTC())) { private fun newTransactionRecovery(cacheSizeBytesOverride: Long? = null, clock: CordaClock = TestClock(Clock.systemUTC())) {
val networkMapCache = PersistentNetworkMapCache(TestingNamedCacheFactory(), database, InMemoryIdentityService(trustRoot = DEV_ROOT_CA.certificate)) val networkMapCache = PersistentNetworkMapCache(TestingNamedCacheFactory(), database, InMemoryIdentityService(trustRoot = DEV_ROOT_CA.certificate))
val alice = createNodeInfo(listOf(ALICE)) val alice = createNodeInfo(listOf(ALICE))
val bob = createNodeInfo(listOf(BOB)) val bob = createNodeInfo(listOf(BOB))