diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 10e6fcc297..1c01aec0a6 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -2542,14 +2542,36 @@ public class net.corda.core.flows.DataVendingFlow extends net.corda.core.flows.F public interface net.corda.core.flows.Destination ## @CordaSerializable -public final class net.corda.core.flows.DistributionList extends java.lang.Object +public abstract class net.corda.core.flows.DistributionList extends java.lang.Object + public (kotlin.jvm.internal.DefaultConstructorMarker) +## +@CordaSerializable +public static final class net.corda.core.flows.DistributionList$ReceiverDistributionList extends net.corda.core.flows.DistributionList + public (byte[], net.corda.core.node.StatesToRecord) + @NotNull + public final byte[] component1() + @NotNull + public final net.corda.core.node.StatesToRecord component2() + @NotNull + public final net.corda.core.flows.DistributionList$ReceiverDistributionList copy(byte[], net.corda.core.node.StatesToRecord) + public boolean equals(Object) + @NotNull + public final byte[] getOpaqueData() + @NotNull + public final net.corda.core.node.StatesToRecord getReceiverStatesToRecord() + public int hashCode() + @NotNull + public String toString() +## +@CordaSerializable +public static final class net.corda.core.flows.DistributionList$SenderDistributionList extends net.corda.core.flows.DistributionList public (net.corda.core.node.StatesToRecord, java.util.Map) @NotNull public final net.corda.core.node.StatesToRecord component1() @NotNull public final java.util.Map component2() @NotNull - public final net.corda.core.flows.DistributionList copy(net.corda.core.node.StatesToRecord, java.util.Map) + public final net.corda.core.flows.DistributionList$SenderDistributionList copy(net.corda.core.node.StatesToRecord, java.util.Map) public boolean equals(Object) @NotNull public final java.util.Map getPeersToStatesToRecord() diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt index e65b2825c8..421e1581cf 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt @@ -50,7 +50,6 @@ import net.corda.finance.issuedBy import net.corda.finance.test.flows.CashIssueWithObserversFlow import net.corda.finance.test.flows.CashPaymentWithObserversFlow import net.corda.node.services.persistence.DBTransactionStorage -import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord import net.corda.node.services.persistence.HashedDistributionList @@ -76,8 +75,8 @@ import net.corda.testing.node.internal.enclosedCordapp import net.corda.testing.node.internal.findCordapp import org.assertj.core.api.Assertions.assertThat import org.junit.After +import org.junit.Assert.assertNotNull import org.junit.Test -import org.junit.jupiter.api.assertThrows import java.sql.SQLException import java.util.Random import kotlin.test.assertEquals @@ -353,16 +352,19 @@ class FinalityFlowTests : WithFinality { assertThat(aliceNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull assertThat(bobNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull - getSenderRecoveryData(stx.id, aliceNode.database).apply { + val sdrs = getSenderRecoveryData(stx.id, aliceNode.database).apply { assertEquals(1, this.size) assertEquals(StatesToRecord.ALL_VISIBLE, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } - getReceiverRecoveryData(stx.id, bobNode, aliceNode).let { (record, distList) -> - assertEquals(StatesToRecord.ONLY_RELEVANT, distList.senderStatesToRecord) - assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), record.initiatorPartyId) - assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), distList.peerHashToStatesToRecord) + val rdr = getReceiverRecoveryData(stx.id, bobNode).apply { + assertNotNull(this) + val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) + assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) + assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this.initiatorPartyId) + assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), hashedDL.peerHashToStatesToRecord) } + validateSenderAndReceiverTimestamps(sdrs, rdr!!) } @Test(timeout=300_000) @@ -383,21 +385,25 @@ class FinalityFlowTests : WithFinality { assertThat(bobNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull assertThat(charlieNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull - getSenderRecoveryData(stx.id, aliceNode.database).apply { + val sdrs = getSenderRecoveryData(stx.id, aliceNode.database).apply { assertEquals(2, this.size) assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) assertEquals(StatesToRecord.ALL_VISIBLE, this[1].statesToRecord) assertEquals(CHARLIE_NAME.hashCode().toLong(), this[1].peerPartyId) } - getReceiverRecoveryData(stx.id, bobNode, aliceNode).let { (record, distList) -> - assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), record.initiatorPartyId) + val rdr = getReceiverRecoveryData(stx.id, bobNode).apply { + assertNotNull(this) + val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) + assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) + assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this.initiatorPartyId) // note: Charlie assertion here is using the hinted StatesToRecord value passed to it from Alice assertEquals(mapOf( BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT, CHARLIE_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE - ), distList.peerHashToStatesToRecord) + ), hashedDL.peerHashToStatesToRecord) } + validateSenderAndReceiverTimestamps(sdrs, rdr!!) // exercise the new FinalityFlow observerSessions constructor parameter val stx3 = aliceNode.startFlowAndRunNetwork(CashPaymentWithObserversFlow( @@ -410,9 +416,24 @@ class FinalityFlowTests : WithFinality { assertThat(bobNode.services.validatedTransactions.getTransaction(stx3.id)).isNotNull assertThat(charlieNode.services.validatedTransactions.getTransaction(stx3.id)).isNotNull - assertEquals(2, getSenderRecoveryData(stx3.id, aliceNode.database).size) - assertThat(getReceiverRecoveryData(stx3.id, bobNode, aliceNode)).isNotNull - assertThat(getReceiverRecoveryData(stx3.id, charlieNode, aliceNode)).isNotNull + val senderDistributionRecords = getSenderRecoveryData(stx3.id, aliceNode.database).apply { + assertEquals(2, this.size) + assertEquals(this[0].timestamp, this[1].timestamp) + } + getReceiverRecoveryData(stx3.id, bobNode).apply { + assertThat(this).isNotNull + assertEquals(senderDistributionRecords[0].timestamp, this!!.timestamp) + } + getReceiverRecoveryData(stx3.id, charlieNode).apply { + assertThat(this).isNotNull + assertEquals(senderDistributionRecords[0].timestamp, this!!.timestamp) + } + } + + private fun validateSenderAndReceiverTimestamps(sdrs: List, rdr: ReceiverDistributionRecord) { + sdrs.map { + assertEquals(it.timestamp, rdr.timestamp) + } } @Test(timeout=300_000) @@ -428,15 +449,19 @@ class FinalityFlowTests : WithFinality { assertThat(aliceNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull assertThat(bobNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull - getSenderRecoveryData(stx.id, aliceNode.database).apply { + val sdr = getSenderRecoveryData(stx.id, aliceNode.database).apply { assertEquals(1, this.size) assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } - getReceiverRecoveryData(stx.id, bobNode, aliceNode).let { (record, distList) -> - assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), record.initiatorPartyId) - assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT), distList.peerHashToStatesToRecord) + val rdr = getReceiverRecoveryData(stx.id, bobNode).apply { + assertNotNull(this) + val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) + assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) + assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this.initiatorPartyId) + assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT), hashedDL.peerHashToStatesToRecord) } + validateSenderAndReceiverTimestamps(sdr, rdr!!) } private fun getSenderRecoveryData(id: SecureHash, database: CordaPersistence): List { @@ -446,31 +471,16 @@ class FinalityFlowTests : WithFinality { DBSenderDistributionRecord::class.java ).setParameter("transactionId", id.toString()).resultList } - return fromDb.map { it.toSenderDistributionRecord() }.also { println("SenderDistributionRecord\n$it") } + return fromDb.map { it.toSenderDistributionRecord() } } - private fun getReceiverRecoveryData(txId: SecureHash, - receiver: TestStartedNode, - sender: TestStartedNode): Pair { - val fromDb = receiver.database.transaction { + private fun getReceiverRecoveryData(txId: SecureHash, receiver: TestStartedNode): ReceiverDistributionRecord? { + return receiver.database.transaction { session.createQuery( "from ${DBReceiverDistributionRecord::class.java.name} where txId = :transactionId", DBReceiverDistributionRecord::class.java - ).setParameter("transactionId", txId.toString()).singleResult - } - - // The receiver should not be able to decrypt the distribution list - assertThrows { - receiver.decryptReceiverDistributionRecord(fromDb) - } - - // Only the sender can - return sender.decryptReceiverDistributionRecord(fromDb) - } - - private fun TestStartedNode.decryptReceiverDistributionRecord(dbRecord: DBReceiverDistributionRecord): Pair { - val hashedDistList = (internals.transactionStorage as DBTransactionStorageLedgerRecovery).decryptHashedDistributionList(dbRecord.distributionList) - return Pair(dbRecord.toReceiverDistributionRecord(), hashedDistList) + ).setParameter("transactionId", txId.toString()).resultList + }.singleOrNull()?.toReceiverDistributionRecord() } @StartableByRPC diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowTransaction.kt b/core/src/main/kotlin/net/corda/core/flows/FlowTransaction.kt index 05539f7480..b213c6dbd0 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowTransaction.kt @@ -23,15 +23,25 @@ data class FlowTransactionInfo( @CordaSerializable data class TransactionMetadata( - val initiator: CordaX500Name, - val distributionList: DistributionList + val initiator: CordaX500Name, + val distributionList: DistributionList ) @CordaSerializable -data class DistributionList( - val senderStatesToRecord: StatesToRecord, - val peersToStatesToRecord: Map -) +sealed class DistributionList { + + @CordaSerializable + data class SenderDistributionList( + val senderStatesToRecord: StatesToRecord, + val peersToStatesToRecord: Map + ) : DistributionList() + + @CordaSerializable + data class ReceiverDistributionList( + val opaqueData: ByteArray, // decipherable only by sender + val receiverStatesToRecord: StatesToRecord // inferred or actual + ) : DistributionList() +} @CordaSerializable enum class TransactionStatus { diff --git a/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt b/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt index c21485e041..7b45d972af 100644 --- a/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt @@ -145,7 +145,8 @@ open class ReceiveTransactionFlow constructor(private val otherSideSession: Flow open fun resolvePayload(payload: Any): SignedTransaction { return if (payload is SignedTransactionWithDistributionList) { if (checkSufficientSignatures) { - (serviceHub as ServiceHubCoreInternal).recordReceiverTransactionRecoveryMetadata(payload.stx.id, otherSideSession.counterparty.name, ourIdentity.name, statesToRecord, payload.distributionList) + (serviceHub as ServiceHubCoreInternal).recordReceiverTransactionRecoveryMetadata(payload.stx.id, otherSideSession.counterparty.name, + TransactionMetadata(otherSideSession.counterparty.name, DistributionList.ReceiverDistributionList(payload.distributionList, statesToRecord))) payload.stx } else payload.stx } else payload as SignedTransaction diff --git a/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt b/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt index 1217a6a085..3323d2d743 100644 --- a/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt @@ -16,6 +16,7 @@ import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.trace import net.corda.core.utilities.unwrap import kotlin.collections.toSet +import net.corda.core.flows.DistributionList.SenderDistributionList /** * In the words of Matt working code is more important then pretty code. This class that contains code that may @@ -91,9 +92,9 @@ open class SendTransactionFlow(val stx: SignedTransaction, fun makeMetaData(stx: SignedTransaction, recordMetaDataEvenIfNotFullySigned: Boolean, senderStatesToRecord: StatesToRecord, participantSessions: Set, observerSessions: Set): TransactionMetadata? { return if (recordMetaDataEvenIfNotFullySigned || isFullySigned(stx)) TransactionMetadata(DUMMY_PARTICIPANT_NAME, - DistributionList(senderStatesToRecord, - (participantSessions.map { it.counterparty.name to StatesToRecord.ONLY_RELEVANT}).toMap() + - (observerSessions.map { it.counterparty.name to StatesToRecord.ALL_VISIBLE}).toMap())) + SenderDistributionList(senderStatesToRecord, + (participantSessions.map { it.counterparty.name to StatesToRecord.ONLY_RELEVANT }).toMap() + + (observerSessions.map { it.counterparty.name to StatesToRecord.ALL_VISIBLE }).toMap())) else null } diff --git a/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt b/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt index 27f05c9f2f..d752eb3b15 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt @@ -76,15 +76,11 @@ interface ServiceHubCoreInternal : ServiceHub { * * @param txnId The SecureHash of a transaction. * @param sender The sender of the transaction. - * @param receiver The receiver of the transaction. - * @param receiverStatesToRecord The StatesToRecord value of the receiver. - * @param encryptedDistributionList encrypted distribution list (hashed peers -> StatesToRecord values) + * @param txnMetadata The recovery metadata associated with a transaction. */ fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) + txnMetadata: TransactionMetadata) } interface TransactionsResolver { diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index 962f7a0664..9ed76b15c8 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -198,8 +198,8 @@ interface ServiceHubInternal : ServiceHubCoreInternal { override fun recordSenderTransactionRecoveryMetadata(txnId: SecureHash, txnMetadata: TransactionMetadata) = validatedTransactions.addSenderTransactionRecoveryMetadata(txnId, txnMetadata) - override fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) = - validatedTransactions.addReceiverTransactionRecoveryMetadata(txnId, sender, receiver, receiverStatesToRecord, encryptedDistributionList) + override fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, txnMetadata: TransactionMetadata) = + validatedTransactions.addReceiverTransactionRecoveryMetadata(txnId, sender, txnMetadata) @Suppress("NestedBlockDepth") @VisibleForTesting @@ -383,15 +383,11 @@ interface WritableTransactionStorage : TransactionStorage { * * @param txId The SecureHash of a transaction. * @param sender The sender of the transaction. - * @param receiver The receiver of the transaction. - * @param receiverStatesToRecord The StatesToRecord value of the receiver. - * @param encryptedDistributionList encrypted distribution list (hashed peers -> StatesToRecord values) + * @param metadata The recovery metadata associated with a transaction. */ fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) + metadata: TransactionMetadata) /** * Removes an un-notarised transaction (with a status of *MISSING_TRANSACTION_SIG*) from the data store. diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index 1973f9e7c1..6905d6f7c1 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -11,7 +11,6 @@ import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.concurrent.doneFuture import net.corda.core.messaging.DataFeed -import net.corda.core.node.StatesToRecord import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes @@ -219,9 +218,8 @@ open class DBTransactionStorage(private val database: CordaPersistence, cacheFac override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) { } + metadata: TransactionMetadata + ) { } override fun finalizeTransaction(transaction: SignedTransaction) = addTransaction(transaction) { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt index f828537ea9..69df1fe9b3 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt @@ -1,10 +1,13 @@ package net.corda.node.services.persistence import net.corda.core.crypto.SecureHash +import net.corda.core.flows.DistributionList.ReceiverDistributionList +import net.corda.core.flows.DistributionList.SenderDistributionList import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.TransactionMetadata import net.corda.core.identity.CordaX500Name import net.corda.core.internal.NamedCacheFactory +import net.corda.core.internal.VisibleForTesting import net.corda.core.node.StatesToRecord import net.corda.core.node.services.vault.Sort import net.corda.core.serialization.CordaSerializable @@ -85,14 +88,19 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, /** Encrypted recovery information for sole use by Sender **/ @Lob @Column(name = "distribution_list", nullable = false) - val distributionList: ByteArray - ) { - constructor(key: Key, txId: SecureHash, encryptedDistributionList: ByteArray) : this( - PersistentKey(key), - txId.toString(), - encryptedDistributionList - ) + val distributionList: ByteArray, + /** states to record: NONE, ALL_VISIBLE, ONLY_RELEVANT */ + @Column(name = "receiver_states_to_record", nullable = false) + val receiverStatesToRecord: StatesToRecord +) { + constructor(key: Key, txId: SecureHash, encryptedDistributionList: ByteArray, receiverStatesToRecord: StatesToRecord) : + this(PersistentKey(key), + txId = txId.toString(), + distributionList = encryptedDistributionList, + receiverStatesToRecord = receiverStatesToRecord + ) + @VisibleForTesting fun toReceiverDistributionRecord(): ReceiverDistributionRecord { return ReceiverDistributionRecord( SecureHash.parse(this.txId), @@ -130,27 +138,22 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray { - val senderRecordingTimestamp = clock.instant() return database.transaction { - // sender distribution records must be unique per txnId and timestamp + val senderRecordingTimestamp = clock.instant() val timeDiscriminator = Key.nextDiscriminatorNumber.andIncrement - metadata.distributionList.peersToStatesToRecord.forEach { peerCordaX500Name, peerStatesToRecord -> + val distributionList = metadata.distributionList as? SenderDistributionList ?: throw IllegalStateException("Expecting SenderDistributionList") + distributionList.peersToStatesToRecord.map { (peerCordaX500Name, peerStatesToRecord) -> val senderDistributionRecord = DBSenderDistributionRecord( - PersistentKey(Key( - TimestampKey(senderRecordingTimestamp, timeDiscriminator), - partyInfoCache.getPartyIdByCordaX500Name(peerCordaX500Name) - )), + PersistentKey(Key(TimestampKey(senderRecordingTimestamp, timeDiscriminator), partyInfoCache.getPartyIdByCordaX500Name(peerCordaX500Name))), txId.toString(), - peerStatesToRecord - ) + peerStatesToRecord) session.save(senderDistributionRecord) } - - val hashedPeersToStatesToRecord = metadata.distributionList.peersToStatesToRecord.mapKeys { (peer) -> + val hashedPeersToStatesToRecord = distributionList.peersToStatesToRecord.mapKeys { (peer) -> partyInfoCache.getPartyIdByCordaX500Name(peer) } val hashedDistributionList = HashedDistributionList( - metadata.distributionList.senderStatesToRecord, + distributionList.senderStatesToRecord, hashedPeersToStatesToRecord, HashedDistributionList.PublicHeader(senderRecordingTimestamp) ) @@ -158,50 +161,24 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } } - fun createReceiverTransactionRecoverMetadata(txId: SecureHash, - senderPartyId: Long, - senderStatesToRecord: StatesToRecord, - senderRecords: List): List { - val senderRecordsByTimestampKey = senderRecords.groupBy { TimestampKey(it.compositeKey.timestamp, it.compositeKey.timestampDiscriminator) } - return senderRecordsByTimestampKey.map { (key) -> - val hashedDistributionList = HashedDistributionList( - senderStatesToRecord, - senderRecords.associate { it.compositeKey.peerPartyId to it.statesToRecord }, - HashedDistributionList.PublicHeader(key.timestamp) - ) - DBReceiverDistributionRecord( - PersistentKey(Key(TimestampKey(key.timestamp, key.timestampDiscriminator), senderPartyId)), - txId.toString(), - hashedDistributionList.encrypt(encryptionService) - ) - } - } - - fun addSenderTransactionRecoveryMetadata(record: DBSenderDistributionRecord) { - return database.transaction { - session.save(record) - } - } - override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) { - val publicHeader = HashedDistributionList.PublicHeader.unauthenticatedDeserialise(encryptedDistributionList, encryptionService) - database.transaction { - val receiverDistributionRecord = DBReceiverDistributionRecord( - Key(partyInfoCache.getPartyIdByCordaX500Name(sender), publicHeader.senderRecordedTimestamp), - txId, - encryptedDistributionList - ) - session.save(receiverDistributionRecord) - } - } - - fun addReceiverTransactionRecoveryMetadata(record: DBReceiverDistributionRecord) { - return database.transaction { - session.save(record) + metadata: TransactionMetadata) { + when (metadata.distributionList) { + is ReceiverDistributionList -> { + val distributionList = metadata.distributionList as ReceiverDistributionList + val publicHeader = HashedDistributionList.PublicHeader.unauthenticatedDeserialise(distributionList.opaqueData, encryptionService) + database.transaction { + val receiverDistributionRecord = DBReceiverDistributionRecord( + Key(partyInfoCache.getPartyIdByCordaX500Name(sender), publicHeader.senderRecordedTimestamp), + txId, + distributionList.opaqueData, + distributionList.receiverStatesToRecord + ) + session.save(receiverDistributionRecord) + } + } + else -> throw IllegalStateException("Expecting ReceiverDistributionList") } } @@ -278,16 +255,6 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } } - fun querySenderDistributionRecordsByTxId(txId: SecureHash): List { - return database.transaction { - val criteriaBuilder = session.criteriaBuilder - val criteriaQuery = criteriaBuilder.createQuery(DBSenderDistributionRecord::class.java) - val txnMetadata = criteriaQuery.from(DBSenderDistributionRecord::class.java) - criteriaQuery.where(criteriaBuilder.equal(txnMetadata.get(DBSenderDistributionRecord::txId.name), txId.toString())) - session.createQuery(criteriaQuery).resultList - } - } - @Suppress("SpreadOperator") fun queryReceiverDistributionRecords(timeWindow: RecoveryTimeWindow, initiators: Set = emptySet(), diff --git a/node/src/main/resources/migration/node-core.changelog-v25.xml b/node/src/main/resources/migration/node-core.changelog-v25.xml index a199a65df8..9ea40bada9 100644 --- a/node/src/main/resources/migration/node-core.changelog-v25.xml +++ b/node/src/main/resources/migration/node-core.changelog-v25.xml @@ -52,6 +52,9 @@ + + + diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 2cf36d35f9..451bc813a9 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -31,7 +31,6 @@ import net.corda.core.internal.concurrent.map import net.corda.core.internal.rootCause import net.corda.core.messaging.DataFeed import net.corda.core.messaging.StateMachineTransactionMapping -import net.corda.core.node.StatesToRecord import net.corda.core.node.services.Vault import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SingletonSerializeAsToken @@ -819,11 +818,9 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) { + metadata: TransactionMetadata) { database.transaction { - delegate.addReceiverTransactionRecoveryMetadata(txId, sender, receiver, receiverStatesToRecord, encryptedDistributionList) + delegate.addReceiverTransactionRecoveryMetadata(txId, sender, metadata) } } diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt index 352398e2ab..3c81be2ab8 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt @@ -6,7 +6,8 @@ import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SignableData import net.corda.core.crypto.SignatureMetadata import net.corda.core.crypto.sign -import net.corda.core.flows.DistributionList +import net.corda.core.flows.DistributionList.ReceiverDistributionList +import net.corda.core.flows.DistributionList.SenderDistributionList import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.TransactionMetadata import net.corda.core.node.NodeInfo @@ -88,7 +89,7 @@ class DBTransactionStorageLedgerRecoveryTests { val beforeFirstTxn = now() val txn = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) val timeWindow = RecoveryTimeWindow(fromTime = beforeFirstTxn, untilTime = beforeFirstTxn.plus(1, ChronoUnit.MINUTES)) val results = transactionRecovery.querySenderDistributionRecords(timeWindow) @@ -97,7 +98,7 @@ class DBTransactionStorageLedgerRecoveryTests { val afterFirstTxn = now() val txn2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn2) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) assertEquals(2, transactionRecovery.querySenderDistributionRecords(timeWindow).size) assertEquals(1, transactionRecovery.querySenderDistributionRecords(RecoveryTimeWindow(fromTime = afterFirstTxn)).size) } @@ -106,10 +107,10 @@ class DBTransactionStorageLedgerRecoveryTests { fun `query local ledger for transactions within timeWindow and excluding remoteTransactionIds`() { val transaction1 = newTransaction() transactionRecovery.addUnnotarisedTransaction(transaction1) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) val transaction2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(transaction2) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val results = transactionRecovery.querySenderDistributionRecords(timeWindow, excludingTxnIds = setOf(transaction1.id)) assertEquals(1, results.size) @@ -120,10 +121,10 @@ class DBTransactionStorageLedgerRecoveryTests { fun `query local ledger for transactions within timeWindow and for given peers`() { val transaction1 = newTransaction() transactionRecovery.addUnnotarisedTransaction(transaction1) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) val transaction2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(transaction2) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val results = transactionRecovery.querySenderDistributionRecords(timeWindow, peers = setOf(CHARLIE_NAME)) assertEquals(1, results.size) @@ -135,44 +136,46 @@ class DBTransactionStorageLedgerRecoveryTests { val transaction1 = newTransaction() // sender txn transactionRecovery.addUnnotarisedTransaction(transaction1) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) val transaction2 = newTransaction() // receiver txn transactionRecovery.addUnnotarisedTransaction(transaction2) - transactionRecovery.addReceiverTransactionRecoveryMetadata(transaction2.id, BOB_NAME, ALICE_NAME, ALL_VISIBLE, - DistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)).encrypt()) + val encryptedDL = transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, + TransactionMetadata(BOB_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(transaction2.id, BOB_NAME, + TransactionMetadata(BOB_NAME, ReceiverDistributionList(encryptedDL, ALL_VISIBLE))) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.SENDER).let { - assertEquals(1, it.size) + assertEquals(2, it.size) assertEquals(BOB_NAME.hashCode().toLong(), it.senderRecords[0].compositeKey.peerPartyId) assertEquals(ALL_VISIBLE, it.senderRecords[0].statesToRecord) } transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.RECEIVER).let { assertEquals(1, it.size) assertEquals(BOB_NAME.hashCode().toLong(), it.receiverRecords[0].compositeKey.peerPartyId) - assertEquals(ALL_VISIBLE, transactionRecovery.decryptHashedDistributionList(it.receiverRecords[0].distributionList).peerHashToStatesToRecord.values.first()) + assertEquals(ALL_VISIBLE, (HashedDistributionList.decrypt(it.receiverRecords[0].distributionList, encryptionService)).peerHashToStatesToRecord.map { it.value }[0]) } val resultsAll = transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.ALL) - assertEquals(2, resultsAll.size) + assertEquals(3, resultsAll.size) } @Test(timeout = 300_000) fun `query for sender distribution records by peers`() { val txn1 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn1) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn1.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn1.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) val txn2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn2) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) val txn3 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn3) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn3.id, TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT, CHARLIE_NAME to ALL_VISIBLE)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn3.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT, CHARLIE_NAME to ALL_VISIBLE)))) val txn4 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn4) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn4.id, TransactionMetadata(BOB_NAME, DistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn4.id, TransactionMetadata(BOB_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ONLY_RELEVANT)))) val txn5 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn5) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn5.id, TransactionMetadata(CHARLIE_NAME, DistributionList(ONLY_RELEVANT, emptyMap()))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn5.id, TransactionMetadata(CHARLIE_NAME, SenderDistributionList(ONLY_RELEVANT, emptyMap()))) assertEquals(5, readSenderDistributionRecordFromDB().size) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) @@ -189,31 +192,41 @@ class DBTransactionStorageLedgerRecoveryTests { fun `query for receiver distribution records by initiator`() { val txn1 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn1) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn1.id, ALICE_NAME, BOB_NAME, ALL_VISIBLE, - DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE, CHARLIE_NAME to ALL_VISIBLE)).encrypt()) + val encryptedDL1 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn1.id, + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE, CHARLIE_NAME to ALL_VISIBLE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn1.id, ALICE_NAME, + TransactionMetadata(ALICE_NAME, ReceiverDistributionList(encryptedDL1, ALL_VISIBLE))) val txn2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn2) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn2.id, ALICE_NAME, BOB_NAME, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)).encrypt()) + val encryptedDL2 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn2.id, ALICE_NAME, + TransactionMetadata(ALICE_NAME, ReceiverDistributionList(encryptedDL2, ONLY_RELEVANT))) val txn3 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn3) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn3.id, ALICE_NAME, CHARLIE_NAME, NONE, - DistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to NONE)).encrypt()) + val encryptedDL3 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn3.id, + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to NONE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn3.id, ALICE_NAME, + TransactionMetadata(ALICE_NAME, ReceiverDistributionList(encryptedDL3, NONE))) val txn4 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn4) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn4.id, BOB_NAME, ALICE_NAME, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)).encrypt()) + val encryptedDL4 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn4.id, + TransactionMetadata(BOB_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn4.id, BOB_NAME, + TransactionMetadata(BOB_NAME, ReceiverDistributionList(encryptedDL4, ALL_VISIBLE))) val txn5 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn5) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn5.id, CHARLIE_NAME, BOB_NAME, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)).encrypt()) + val encryptedDL5 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn5.id, + TransactionMetadata(CHARLIE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn5.id, CHARLIE_NAME, + TransactionMetadata(CHARLIE_NAME, ReceiverDistributionList(encryptedDL5, ONLY_RELEVANT))) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(ALICE_NAME)).let { assertEquals(3, it.size) - assertEquals(transactionRecovery.decryptHashedDistributionList(it[0].distributionList).peerHashToStatesToRecord.values.first(), ALL_VISIBLE) - assertEquals(transactionRecovery.decryptHashedDistributionList(it[1].distributionList).peerHashToStatesToRecord.values.first(), ONLY_RELEVANT) - assertEquals(transactionRecovery.decryptHashedDistributionList(it[2].distributionList).peerHashToStatesToRecord.values.first(), NONE) + assertEquals(HashedDistributionList.decrypt(it[0].distributionList, encryptionService).peerHashToStatesToRecord.map { it.value }[0], ALL_VISIBLE) + assertEquals(HashedDistributionList.decrypt(it[1].distributionList, encryptionService).peerHashToStatesToRecord.map { it.value }[0], ONLY_RELEVANT) + assertEquals(HashedDistributionList.decrypt(it[2].distributionList, encryptionService).peerHashToStatesToRecord.map { it.value }[0], NONE) } assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(BOB_NAME)).size) assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(CHARLIE_NAME)).size) @@ -224,7 +237,7 @@ class DBTransactionStorageLedgerRecoveryTests { fun `transaction without peers does not store recovery metadata in database`() { val senderTransaction = newTransaction() transactionRecovery.addUnnotarisedTransaction(senderTransaction) - transactionRecovery.addSenderTransactionRecoveryMetadata(senderTransaction.id, TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, emptyMap()))) + transactionRecovery.addSenderTransactionRecoveryMetadata(senderTransaction.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, emptyMap()))) assertEquals(IN_FLIGHT, readTransactionFromDB(senderTransaction.id).status) assertEquals(0, readSenderDistributionRecordFromDB(senderTransaction.id).size) } @@ -234,7 +247,7 @@ class DBTransactionStorageLedgerRecoveryTests { val senderTransaction = newTransaction() transactionRecovery.addUnnotarisedTransaction(senderTransaction) transactionRecovery.addSenderTransactionRecoveryMetadata(senderTransaction.id, - TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) + TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) assertEquals(IN_FLIGHT, readTransactionFromDB(senderTransaction.id).status) readSenderDistributionRecordFromDB(senderTransaction.id).let { assertEquals(1, it.size) @@ -244,8 +257,10 @@ class DBTransactionStorageLedgerRecoveryTests { val receiverTransaction = newTransaction() transactionRecovery.addUnnotarisedTransaction(receiverTransaction) - transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, ALICE_NAME, BOB_NAME, ALL_VISIBLE, - DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE)).encrypt()) + val encryptedDL = transactionRecovery.addSenderTransactionRecoveryMetadata(receiverTransaction.id, + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, ALICE_NAME, + TransactionMetadata(ALICE_NAME, ReceiverDistributionList(encryptedDL, ALL_VISIBLE))) assertEquals(IN_FLIGHT, readTransactionFromDB(receiverTransaction.id).status) readReceiverDistributionRecordFromDB(receiverTransaction.id).let { record -> val distList = transactionRecovery.decryptHashedDistributionList(record.encryptedDistributionList.bytes) @@ -261,7 +276,7 @@ class DBTransactionStorageLedgerRecoveryTests { val transaction = newTransaction(notarySig = false) transactionRecovery.finalizeTransaction(transaction) transactionRecovery.addSenderTransactionRecoveryMetadata(transaction.id, - TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ALL_VISIBLE)))) + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ALL_VISIBLE)))) assertEquals(VERIFIED, readTransactionFromDB(transaction.id).status) readSenderDistributionRecordFromDB(transaction.id).apply { assertEquals(1, this.size) @@ -273,8 +288,10 @@ class DBTransactionStorageLedgerRecoveryTests { fun `remove un-notarised transaction and associated recovery metadata`() { val senderTransaction = newTransaction(notarySig = false) transactionRecovery.addUnnotarisedTransaction(senderTransaction) - transactionRecovery.addReceiverTransactionRecoveryMetadata(senderTransaction.id, ALICE.name, BOB.name, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT, CHARLIE_NAME to ONLY_RELEVANT)).encrypt()) + val encryptedDL1 = transactionRecovery.addSenderTransactionRecoveryMetadata(senderTransaction.id, + TransactionMetadata(ALICE.name, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT, CHARLIE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(senderTransaction.id, BOB.name, + TransactionMetadata(ALICE.name, ReceiverDistributionList(encryptedDL1, ONLY_RELEVANT))) assertNull(transactionRecovery.getTransaction(senderTransaction.id)) assertEquals(IN_FLIGHT, readTransactionFromDB(senderTransaction.id).status) @@ -285,8 +302,10 @@ class DBTransactionStorageLedgerRecoveryTests { val receiverTransaction = newTransaction(notarySig = false) transactionRecovery.addUnnotarisedTransaction(receiverTransaction) - transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, ALICE.name, BOB.name, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT)).encrypt()) + val encryptedDL2 = transactionRecovery.addSenderTransactionRecoveryMetadata(receiverTransaction.id, + TransactionMetadata(ALICE.name, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, BOB.name, + TransactionMetadata(ALICE.name, ReceiverDistributionList(encryptedDL2, ONLY_RELEVANT))) assertNull(transactionRecovery.getTransaction(receiverTransaction.id)) assertEquals(IN_FLIGHT, readTransactionFromDB(receiverTransaction.id).status) @@ -396,14 +415,5 @@ class DBTransactionStorageLedgerRecoveryTests { private fun notarySig(txId: SecureHash) = DUMMY_NOTARY.keyPair.sign(SignableData(txId, SignatureMetadata(1, Crypto.findSignatureScheme(DUMMY_NOTARY.publicKey).schemeNumberID))) - - private fun DistributionList.encrypt(): ByteArray { - val hashedPeersToStatesToRecord = this.peersToStatesToRecord.mapKeys { (peer) -> partyInfoCache.getPartyIdByCordaX500Name(peer) } - val hashedDistributionList = HashedDistributionList( - this.senderStatesToRecord, - hashedPeersToStatesToRecord, - HashedDistributionList.PublicHeader(now()) - ) - return hashedDistributionList.encrypt(encryptionService) - } } + diff --git a/testing/cordapps/cashobservers/src/main/kotlin/net/corda/finance/test/flows/CashIssueWithObserversFlow.kt b/testing/cordapps/cashobservers/src/main/kotlin/net/corda/finance/test/flows/CashIssueWithObserversFlow.kt index c860617078..282b49c3cb 100644 --- a/testing/cordapps/cashobservers/src/main/kotlin/net/corda/finance/test/flows/CashIssueWithObserversFlow.kt +++ b/testing/cordapps/cashobservers/src/main/kotlin/net/corda/finance/test/flows/CashIssueWithObserversFlow.kt @@ -42,9 +42,9 @@ class CashIssueWithObserversFlow(private val amount: Amount, } @Suspendable - private fun finalise(tx: SignedTransaction, sessions: Collection, message: String): SignedTransaction { + private fun finalise(tx: SignedTransaction, observerSessions: Collection, message: String): SignedTransaction { try { - return subFlow(FinalityFlow(tx, sessions)) + return subFlow(FinalityFlow(tx, sessions = emptySet(), observerSessions = observerSessions)) } catch (e: NotaryException) { throw CashException(message, e) } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt index 9f23bf6beb..0a52c09f56 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt @@ -12,7 +12,6 @@ import net.corda.node.services.api.WritableTransactionStorage import net.corda.core.flows.TransactionMetadata import net.corda.core.flows.TransactionStatus import net.corda.core.identity.CordaX500Name -import net.corda.core.node.StatesToRecord import net.corda.testing.node.MockServices import rx.Observable import rx.subjects.PublishSubject @@ -65,9 +64,7 @@ open class MockTransactionStorage : WritableTransactionStorage, SingletonSeriali override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) { } + metadata: TransactionMetadata) { } override fun removeUnnotarisedTransaction(id: SecureHash): Boolean { return txns.remove(id) != null diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/dsl/TestDSL.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/dsl/TestDSL.kt index fd55d3645d..b460d02f30 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/dsl/TestDSL.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/dsl/TestDSL.kt @@ -150,7 +150,7 @@ data class TestTransactionDSLInterpreter private constructor( override fun recordSenderTransactionRecoveryMetadata(txnId: SecureHash, txnMetadata: TransactionMetadata): ByteArray? { return null } - override fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) {} + override fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, txnMetadata: TransactionMetadata) {} } private fun copy(): TestTransactionDSLInterpreter =