From de67ab7377451e16b40136a5aa76f17454af94f7 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Tue, 25 Jul 2023 11:58:32 +0100 Subject: [PATCH 01/10] ENT-9876: Encrypting the ledger recovery participant distribution list --- .../coretests/flows/FinalityFlowTests.kt | 37 +++-- .../core/internal/ServiceHubCoreInternal.kt | 6 +- .../nodeapi/internal/crypto/AesEncryption.kt | 65 ++++++++ .../internal/crypto/AesEncryptionTest.kt | 73 +++++++++ .../net/corda/node/internal/AbstractNode.kt | 5 +- .../corda/node/services/EncryptionService.kt | 42 +++++ .../node/services/api/ServiceHubInternal.kt | 12 +- .../persistence/AesDbEncryptionService.kt | 152 ++++++++++++++++++ .../persistence/DBTransactionStorage.kt | 8 +- .../DBTransactionStorageLedgerRecovery.kt | 142 ++++++---------- .../persistence/HashedDistributionList.kt | 104 ++++++++++++ .../node/services/schema/NodeSchemaService.kt | 6 +- .../migration/node-core.changelog-master.xml | 1 + .../migration/node-core.changelog-v26.xml | 28 ++++ .../node/messaging/TwoPartyTradeFlowTests.kt | 12 +- .../persistence/AesDbEncryptionServiceTest.kt | 134 +++++++++++++++ ...DBTransactionStorageLedgerRecoveryTests.kt | 54 ++++--- .../node/internal/MockEncryptionService.kt | 39 +++++ .../node/internal/MockTransactionStorage.kt | 8 +- 19 files changed, 785 insertions(+), 143 deletions(-) create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/AesEncryption.kt create mode 100644 node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/AesEncryptionTest.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/EncryptionService.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/persistence/AesDbEncryptionService.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/persistence/HashedDistributionList.kt create mode 100644 node/src/main/resources/migration/node-core.changelog-v26.xml create mode 100644 node/src/test/kotlin/net/corda/node/services/persistence/AesDbEncryptionServiceTest.kt create mode 100644 testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockEncryptionService.kt diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt index c120a1b620..1fa3b1516d 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt @@ -66,7 +66,6 @@ import net.corda.testing.node.internal.FINANCE_WORKFLOWS_CORDAPP import net.corda.testing.node.internal.InternalMockNetwork import net.corda.testing.node.internal.InternalMockNodeParameters import net.corda.testing.node.internal.MOCK_VERSION_INFO -import net.corda.testing.node.internal.MockCryptoService import net.corda.testing.node.internal.TestCordappInternal import net.corda.testing.node.internal.TestStartedNode import net.corda.testing.node.internal.cordappWithPackages @@ -75,6 +74,7 @@ import net.corda.testing.node.internal.findCordapp import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Test +import org.junit.jupiter.api.assertThrows import java.sql.SQLException import java.util.Random import kotlin.test.assertEquals @@ -239,9 +239,9 @@ class FinalityFlowTests : WithFinality { private fun assertTxnRemovedFromDatabase(node: TestStartedNode, stxId: SecureHash) { val fromDb = node.database.transaction { session.createQuery( - "from ${DBTransactionStorage.DBTransaction::class.java.name} where tx_id = :transactionId", + "from ${DBTransactionStorage.DBTransaction::class.java.name} where txId = :transactionId", DBTransactionStorage.DBTransaction::class.java - ).setParameter("transactionId", stxId.toString()).resultList.map { it } + ).setParameter("transactionId", stxId.toString()).resultList } assertEquals(0, fromDb.size) } @@ -357,7 +357,7 @@ class FinalityFlowTests : WithFinality { assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } - getReceiverRecoveryData(stx.id, bobNode.database).apply { + getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { assertEquals(StatesToRecord.ALL_VISIBLE, this?.statesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) @@ -390,7 +390,7 @@ class FinalityFlowTests : WithFinality { assertEquals(StatesToRecord.ONLY_RELEVANT, this[1].statesToRecord) assertEquals(CHARLIE_NAME.hashCode().toLong(), this[1].peerPartyId) } - getReceiverRecoveryData(stx.id, bobNode.database).apply { + getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) @@ -411,8 +411,8 @@ class FinalityFlowTests : WithFinality { assertThat(charlieNode.services.validatedTransactions.getTransaction(stx3.id)).isNotNull assertEquals(2, getSenderRecoveryData(stx3.id, aliceNode.database).size) - assertThat(getReceiverRecoveryData(stx3.id, bobNode.database)).isNotNull - assertThat(getReceiverRecoveryData(stx3.id, charlieNode.database)).isNotNull + assertThat(getReceiverRecoveryData(stx3.id, bobNode, aliceNode)).isNotNull + assertThat(getReceiverRecoveryData(stx3.id, charlieNode, aliceNode)).isNotNull } @Test(timeout=300_000) @@ -433,7 +433,7 @@ class FinalityFlowTests : WithFinality { assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } - getReceiverRecoveryData(stx.id, bobNode.database).apply { + getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) @@ -444,21 +444,28 @@ class FinalityFlowTests : WithFinality { private fun getSenderRecoveryData(id: SecureHash, database: CordaPersistence): List { val fromDb = database.transaction { session.createQuery( - "from ${DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java.name} where tx_id = :transactionId", + "from ${DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java.name} where txId = :transactionId", DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java - ).setParameter("transactionId", id.toString()).resultList.map { it } + ).setParameter("transactionId", id.toString()).resultList } return fromDb.map { it.toSenderDistributionRecord() }.also { println("SenderDistributionRecord\n$it") } } - private fun getReceiverRecoveryData(id: SecureHash, database: CordaPersistence): ReceiverDistributionRecord? { - val fromDb = database.transaction { + private fun getReceiverRecoveryData(txId: SecureHash, receiver: TestStartedNode, sender: TestStartedNode): ReceiverDistributionRecord? { + val fromDb = receiver.database.transaction { session.createQuery( - "from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where tx_id = :transactionId", + "from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where txId = :transactionId", DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java - ).setParameter("transactionId", id.toString()).resultList.map { it } + ).setParameter("transactionId", txId.toString()).resultList + }.singleOrNull() + + // The receiver should not be able to decrypt the distribution list + assertThrows { + fromDb?.toReceiverDistributionRecord(receiver.internals.encryptionService) } - return fromDb.singleOrNull()?.toReceiverDistributionRecord(MockCryptoService(emptyMap())).also { println("ReceiverDistributionRecord\n$it") } + + // Only the sender can + return fromDb?.toReceiverDistributionRecord(sender.internals.encryptionService) } @StartableByRPC diff --git a/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt b/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt index f5febbad97..a880ecb152 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt @@ -82,7 +82,11 @@ interface ServiceHubCoreInternal : ServiceHub { * @param receiverStatesToRecord The StatesToRecord value of the receiver. * @param encryptedDistributionList encrypted distribution list (hashed peers -> StatesToRecord values) */ - fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) + fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, + sender: CordaX500Name, + receiver: CordaX500Name, + receiverStatesToRecord: StatesToRecord, + encryptedDistributionList: ByteArray) } interface TransactionsResolver { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/AesEncryption.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/AesEncryption.kt new file mode 100644 index 0000000000..f9b36ffd07 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/AesEncryption.kt @@ -0,0 +1,65 @@ +package net.corda.nodeapi.internal.crypto + +import net.corda.core.crypto.secureRandomBytes +import java.nio.ByteBuffer +import javax.crypto.Cipher +import javax.crypto.SecretKey +import javax.crypto.spec.GCMParameterSpec +import javax.crypto.spec.SecretKeySpec + +object AesEncryption { + const val KEY_SIZE_BYTES = 16 + internal const val IV_SIZE_BYTES = 12 + private const val TAG_SIZE_BYTES = 16 + private const val TAG_SIZE_BITS = TAG_SIZE_BYTES * 8 + + /** + * Generates a random 128-bit AES key. + */ + fun randomKey(): SecretKey { + return SecretKeySpec(secureRandomBytes(KEY_SIZE_BYTES), "AES") + } + + /** + * Encrypt the given [plaintext] with AES using the given [aesKey]. + * + * An optional public [additionalData] bytes can also be provided which will be authenticated alongside the ciphertext but not encrypted. + * This may be metadata for example. The same authenticated data bytes must be provided to [decrypt] to be able to decrypt the + * ciphertext. Typically these bytes are serialised alongside the ciphertext. Since it's authenticated in the ciphertext, it cannot be + * modified undetected. + */ + fun encrypt(aesKey: SecretKey, plaintext: ByteArray, additionalData: ByteArray? = null): ByteArray { + val cipher = Cipher.getInstance("AES/GCM/NoPadding") + val iv = secureRandomBytes(IV_SIZE_BYTES) // Never use the same IV with the same key! + cipher.init(Cipher.ENCRYPT_MODE, aesKey, GCMParameterSpec(TAG_SIZE_BITS, iv)) + val buffer = ByteBuffer.allocate(IV_SIZE_BYTES + plaintext.size + TAG_SIZE_BYTES) + buffer.put(iv) + if (additionalData != null) { + cipher.updateAAD(additionalData) + } + cipher.doFinal(ByteBuffer.wrap(plaintext), buffer) + return buffer.array() + } + + fun encrypt(aesKey: ByteArray, plaintext: ByteArray, additionalData: ByteArray? = null): ByteArray { + return encrypt(SecretKeySpec(aesKey, "AES"), plaintext, additionalData) + } + + /** + * Decrypt ciphertext that was encrypted with the same key using [encrypt]. + * + * If additional data was used for the encryption then it must also be provided. If doesn't match then the decryption will fail. + */ + fun decrypt(aesKey: SecretKey, ciphertext: ByteArray, additionalData: ByteArray? = null): ByteArray { + val cipher = Cipher.getInstance("AES/GCM/NoPadding") + cipher.init(Cipher.DECRYPT_MODE, aesKey, GCMParameterSpec(TAG_SIZE_BITS, ciphertext, 0, IV_SIZE_BYTES)) + if (additionalData != null) { + cipher.updateAAD(additionalData) + } + return cipher.doFinal(ciphertext, IV_SIZE_BYTES, ciphertext.size - IV_SIZE_BYTES) + } + + fun decrypt(aesKey: ByteArray, ciphertext: ByteArray, additionalData: ByteArray? = null): ByteArray { + return decrypt(SecretKeySpec(aesKey, "AES"), ciphertext, additionalData) + } +} diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/AesEncryptionTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/AesEncryptionTest.kt new file mode 100644 index 0000000000..d3b1ded638 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/AesEncryptionTest.kt @@ -0,0 +1,73 @@ +package net.corda.nodeapi.internal.crypto + +import net.corda.core.crypto.secureRandomBytes +import net.corda.nodeapi.internal.crypto.AesEncryption.IV_SIZE_BYTES +import net.corda.nodeapi.internal.crypto.AesEncryption.KEY_SIZE_BYTES +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatExceptionOfType +import org.junit.Test +import java.security.GeneralSecurityException + +class AesEncryptionTest { + private val aesKey = secureRandomBytes(KEY_SIZE_BYTES) + private val plaintext = secureRandomBytes(257) // Intentionally not a power of 2 + + @Test(timeout = 300_000) + fun `ciphertext can be decrypted using the same key`() { + val ciphertext = AesEncryption.encrypt(aesKey, plaintext) + assertThat(String(ciphertext)).doesNotContain(String(plaintext)) + val decrypted = AesEncryption.decrypt(aesKey, ciphertext) + assertThat(decrypted).isEqualTo(plaintext) + } + + @Test(timeout = 300_000) + fun `ciphertext with authenticated data can be decrypted using the same key`() { + val ciphertext = AesEncryption.encrypt(aesKey, plaintext, "Extra public data".toByteArray()) + assertThat(String(ciphertext)).doesNotContain(String(plaintext)) + val decrypted = AesEncryption.decrypt(aesKey, ciphertext, "Extra public data".toByteArray()) + assertThat(decrypted).isEqualTo(plaintext) + } + + @Test(timeout = 300_000) + fun `ciphertext cannot be decrypted with different authenticated data`() { + val ciphertext = AesEncryption.encrypt(aesKey, plaintext, "Extra public data".toByteArray()) + assertThat(String(ciphertext)).doesNotContain(String(plaintext)) + assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy { + AesEncryption.decrypt(aesKey, ciphertext, "Different public data".toByteArray()) + } + } + + @Test(timeout = 300_000) + fun `ciphertext cannot be decrypted with different key`() { + val ciphertext = AesEncryption.encrypt(aesKey, plaintext) + for (index in aesKey.indices) { + aesKey[index]-- + assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy { + AesEncryption.decrypt(aesKey, ciphertext) + } + aesKey[index]++ + } + } + + @Test(timeout = 300_000) + fun `corrupted ciphertext cannot be decrypted`() { + val ciphertext = AesEncryption.encrypt(aesKey, plaintext) + for (index in ciphertext.indices) { + ciphertext[index]-- + assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy { + AesEncryption.decrypt(aesKey, ciphertext) + } + ciphertext[index]++ + } + } + + @Test(timeout = 300_000) + fun `encrypting same plainttext twice with same key does not produce same ciphertext`() { + val first = AesEncryption.encrypt(aesKey, plaintext) + val second = AesEncryption.encrypt(aesKey, plaintext) + // The IV should be different + assertThat(first.take(IV_SIZE_BYTES)).isNotEqualTo(second.take(IV_SIZE_BYTES)) + // Which should cause the encrypted bytes to be different as well + assertThat(first.drop(IV_SIZE_BYTES)).isNotEqualTo(second.drop(IV_SIZE_BYTES)) + } +} diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 198b158d24..cbae60071e 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -128,6 +128,7 @@ import net.corda.node.services.persistence.AbstractPartyToX500NameAsStringConver import net.corda.node.services.persistence.AttachmentStorageInternal import net.corda.node.services.persistence.DBCheckpointPerformanceRecorder import net.corda.node.services.persistence.DBCheckpointStorage +import net.corda.node.services.persistence.AesDbEncryptionService import net.corda.node.services.persistence.DBTransactionMappingStorage import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery import net.corda.node.services.persistence.NodeAttachmentService @@ -286,6 +287,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMapCache = PersistentNetworkMapCache(cacheFactory, database, identityService).tokenize() val partyInfoCache = PersistentPartyInfoCache(networkMapCache, cacheFactory, database) + val encryptionService = AesDbEncryptionService(database) @Suppress("LeakingThis") val cryptoService = makeCryptoService() @Suppress("LeakingThis") @@ -658,6 +660,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, verifyCheckpointsCompatible(frozenTokenizableServices) partyInfoCache.start() + encryptionService.start(nodeInfo.legalIdentities[0]) /* Note the .get() at the end of the distributeEvent call, below. This will block until all Corda Services have returned from processing the event, allowing a service to prevent the @@ -1080,7 +1083,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } protected open fun makeTransactionStorage(transactionCacheSizeBytes: Long): WritableTransactionStorage { - return DBTransactionStorageLedgerRecovery(database, cacheFactory, platformClock, cryptoService, partyInfoCache) + return DBTransactionStorageLedgerRecovery(database, cacheFactory, platformClock, encryptionService, partyInfoCache) } protected open fun makeNetworkParametersStorage(): NetworkParametersStorage { diff --git a/node/src/main/kotlin/net/corda/node/services/EncryptionService.kt b/node/src/main/kotlin/net/corda/node/services/EncryptionService.kt new file mode 100644 index 0000000000..85dea166e0 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/EncryptionService.kt @@ -0,0 +1,42 @@ +package net.corda.node.services + +/** + * A service for encrypting data. This abstraction does not mandate any security properties except the same service instance will be + * able to decrypt ciphertext encrypted by it. Further security properties are defined by the implementations. This includes the encryption + * protocol used. + */ +interface EncryptionService { + /** + * Encrypt the given [plaintext]. The encryption key used is dependent on the implementation. The returned ciphertext can be decrypted + * using [decrypt]. + * + * An optional public [additionalData] bytes can also be provided which will be authenticated (thus tamperproof) alongside the + * ciphertext but not encrypted. It will be incorporated into the returned bytes in an implementation dependent fashion. + */ + fun encrypt(plaintext: ByteArray, additionalData: ByteArray? = null): ByteArray + + /** + * Decrypt ciphertext that was encrypted using [encrypt] and return the original plaintext plus the additional data authenticated (if + * present). The service will select the correct encryption key to use. + */ + fun decrypt(ciphertext: ByteArray): PlaintextAndAAD + + /** + * Extracts the (unauthenticated) additional data, if present, from the given [ciphertext]. This is the public data that would have been + * given at encryption time. + * + * Note, this method does not verify if the data was tampered with, and hence is unauthenticated. To have it authenticated requires + * calling [decrypt]. This is still useful however, as it doesn't require the encryption key, and so a third-party can view the + * additional data without needing access to the key. + */ + fun extractUnauthenticatedAdditionalData(ciphertext: ByteArray): ByteArray? + + + /** + * Represents the decrypted plaintext and the optional authenticated additional data bytes. + */ + class PlaintextAndAAD(val plaintext: ByteArray, val authenticatedAdditionalData: ByteArray?) { + operator fun component1() = plaintext + operator fun component2() = authenticatedAdditionalData + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index a5ef5f054a..962f7a0664 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -372,22 +372,26 @@ interface WritableTransactionStorage : TransactionStorage { /** * Records Sender [TransactionMetadata] for a given txnId. * - * @param id The SecureHash of a transaction. + * @param txId The SecureHash of a transaction. * @param metadata The recovery metadata associated with a transaction. * @return encrypted distribution list (hashed peers -> StatesToRecord values). */ - fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray? + fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray? /** * Records Received [TransactionMetadata] for a given txnId. * - * @param id The SecureHash of a transaction. + * @param txId The SecureHash of a transaction. * @param sender The sender of the transaction. * @param receiver The receiver of the transaction. * @param receiverStatesToRecord The StatesToRecord value of the receiver. * @param encryptedDistributionList encrypted distribution list (hashed peers -> StatesToRecord values) */ - fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) + fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, + sender: CordaX500Name, + receiver: CordaX500Name, + receiverStatesToRecord: StatesToRecord, + encryptedDistributionList: ByteArray) /** * Removes an un-notarised transaction (with a status of *MISSING_TRANSACTION_SIG*) from the data store. diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/AesDbEncryptionService.kt b/node/src/main/kotlin/net/corda/node/services/persistence/AesDbEncryptionService.kt new file mode 100644 index 0000000000..924ef48c7f --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/persistence/AesDbEncryptionService.kt @@ -0,0 +1,152 @@ +package net.corda.node.services.persistence + +import net.corda.core.crypto.newSecureRandom +import net.corda.core.identity.Party +import net.corda.core.internal.copyBytes +import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.node.services.EncryptionService +import net.corda.nodeapi.internal.crypto.AesEncryption +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX +import org.hibernate.annotations.Type +import java.nio.ByteBuffer +import java.security.Key +import java.security.MessageDigest +import java.util.UUID +import javax.crypto.Cipher +import javax.crypto.SecretKey +import javax.crypto.spec.SecretKeySpec +import javax.persistence.Column +import javax.persistence.Entity +import javax.persistence.Id +import javax.persistence.Table + +/** + * [EncryptionService] which uses AES keys stored in the node database. A random key is chosen for encryption, and the resultant ciphertext + * encodes the key used so that it can be decrypted without needing further information. + * + * **Storing encryption keys in a database is not secure, and so only use this service if the data being encrypted is also stored + * unencrypted in the same database.** + * + * To obfuscate the keys, they are stored wrapped using another AES key (called the wrapping key or key-encryption-key) derived from the + * node's legal identity. This is not a security measure; it's only meant to reduce the impact of accidental leakage. + */ +// TODO Add support for key expiry +class AesDbEncryptionService(private val database: CordaPersistence) : EncryptionService, SingletonSerializeAsToken() { + companion object { + private const val INITIAL_KEY_COUNT = 10 + private const val UUID_BYTES = 16 + } + + private val aesKeys = ArrayList>() + + fun start(ourIdentity: Party) { + database.transaction { + val criteria = session.criteriaBuilder.createQuery(EncryptionKeyRecord::class.java) + criteria.select(criteria.from(EncryptionKeyRecord::class.java)) + val dbKeyRecords = session.createQuery(criteria).resultList + val keyWrapper = Cipher.getInstance("AESWrap") + if (dbKeyRecords.isEmpty()) { + repeat(INITIAL_KEY_COUNT) { + val keyId = UUID.randomUUID() + val aesKey = AesEncryption.randomKey() + aesKeys += Pair(keyId, aesKey) + val wrappedKey = with(keyWrapper) { + init(Cipher.WRAP_MODE, createKEK(ourIdentity, keyId)) + wrap(aesKey) + } + session.save(EncryptionKeyRecord(keyId = keyId, keyMaterial = wrappedKey)) + } + } else { + for (dbKeyRecord in dbKeyRecords) { + val aesKey = with(keyWrapper) { + init(Cipher.UNWRAP_MODE, createKEK(ourIdentity, dbKeyRecord.keyId)) + unwrap(dbKeyRecord.keyMaterial, "AES", Cipher.SECRET_KEY) as SecretKey + } + aesKeys += Pair(dbKeyRecord.keyId, aesKey) + } + } + } + } + + override fun encrypt(plaintext: ByteArray, additionalData: ByteArray?): ByteArray { + val (keyId, aesKey) = aesKeys[newSecureRandom().nextInt(aesKeys.size)] + val ciphertext = AesEncryption.encrypt(aesKey, plaintext, additionalData) + val buffer = ByteBuffer.allocate(1 + UUID_BYTES + Integer.BYTES + (additionalData?.size ?: 0) + ciphertext.size) + buffer.put(1) // Version tag + // Prepend the key ID to the returned ciphertext. It's OK that this is not included in the authenticated additional data because + // changing this value will lead to either an non-existent key or an another key which will not be able decrypt the ciphertext. + buffer.putUUID(keyId) + if (additionalData != null) { + buffer.putInt(additionalData.size) + buffer.put(additionalData) + } else { + buffer.putInt(0) + } + buffer.put(ciphertext) + return buffer.array() + } + + override fun decrypt(ciphertext: ByteArray): EncryptionService.PlaintextAndAAD { + val buffer = ByteBuffer.wrap(ciphertext) + val version = buffer.get().toInt() + require(version == 1) + val keyId = buffer.getUUID() + val aesKey = requireNotNull(aesKeys.find { it.first == keyId }?.second) { "Unable to decrypt" } + val additionalData = buffer.getAdditionaData() + val plaintext = AesEncryption.decrypt(aesKey, buffer.copyBytes(), additionalData) + // Only now is the additional data authenticated + return EncryptionService.PlaintextAndAAD(plaintext, additionalData) + } + + override fun extractUnauthenticatedAdditionalData(ciphertext: ByteArray): ByteArray? { + val buffer = ByteBuffer.wrap(ciphertext) + buffer.position(1 + UUID_BYTES) + return buffer.getAdditionaData() + } + + private fun ByteBuffer.getAdditionaData(): ByteArray? { + val additionalDataSize = getInt() + return if (additionalDataSize > 0) ByteArray(additionalDataSize).also { get(it) } else null + } + + private fun UUID.toByteArray(): ByteArray { + val buffer = ByteBuffer.allocate(UUID_BYTES) + buffer.putUUID(this) + return buffer.array() + } + + /** + * Derive the key-encryption-key (KEK) from the the node's identity and the persisted key's ID. + */ + private fun createKEK(ourIdentity: Party, keyId: UUID): Key { + val digest = MessageDigest.getInstance("SHA-256") + digest.update(ourIdentity.name.x500Principal.encoded) + digest.update(keyId.toByteArray()) + return SecretKeySpec(digest.digest(), 0, AesEncryption.KEY_SIZE_BYTES, "AES") + } + + + @Entity + @Table(name = "${NODE_DATABASE_PREFIX}aes_encryption_keys") + class EncryptionKeyRecord( + @Id + @Type(type = "uuid-char") + @Column(name = "key_id", nullable = false) + val keyId: UUID, + + @Column(name = "key_material", nullable = false) + val keyMaterial: ByteArray + ) +} + +internal fun ByteBuffer.putUUID(uuid: UUID) { + putLong(uuid.mostSignificantBits) + putLong(uuid.leastSignificantBits) +} + +internal fun ByteBuffer.getUUID(): UUID { + val mostSigBits = getLong() + val leastSigBits = getLong() + return UUID(mostSigBits, leastSigBits) +} diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index 56acdfd61b..1973f9e7c1 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -215,9 +215,13 @@ open class DBTransactionStorage(private val database: CordaPersistence, cacheFac false } - override fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray? { return null } + override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray? { return null } - override fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) { } + override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, + sender: CordaX500Name, + receiver: CordaX500Name, + receiverStatesToRecord: StatesToRecord, + encryptedDistributionList: ByteArray) { } override fun finalizeTransaction(transaction: SignedTransaction) = addTransaction(transaction) { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt index 0d00344742..6c101a8401 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt @@ -9,15 +9,11 @@ import net.corda.core.node.StatesToRecord import net.corda.core.node.services.vault.Sort import net.corda.core.serialization.CordaSerializable import net.corda.node.CordaClock +import net.corda.node.services.EncryptionService import net.corda.node.services.network.PersistentPartyInfoCache -import net.corda.nodeapi.internal.cryptoservice.CryptoService import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import org.hibernate.annotations.Immutable -import java.io.ByteArrayInputStream -import java.io.ByteArrayOutputStream -import java.io.DataInputStream -import java.io.DataOutputStream import java.io.Serializable import java.time.Instant import java.util.concurrent.atomic.AtomicLong @@ -31,9 +27,10 @@ import javax.persistence.Table import javax.persistence.criteria.Predicate import kotlin.streams.toList -class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, cacheFactory: NamedCacheFactory, +class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, + cacheFactory: NamedCacheFactory, val clock: CordaClock, - private val cryptoService: CryptoService, + private val encryptionService: EncryptionService, private val partyInfoCache: PersistentPartyInfoCache) : DBTransactionStorage(database, cacheFactory, clock) { @Embeddable @Immutable @@ -63,7 +60,6 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, /** states to record: NONE, ALL_VISIBLE, ONLY_RELEVANT */ @Column(name = "states_to_record", nullable = false) var statesToRecord: StatesToRecord - ) { fun toSenderDistributionRecord() = SenderDistributionRecord( @@ -76,7 +72,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, @Entity @Table(name = "${NODE_DATABASE_PREFIX}receiver_distribution_records") - data class DBReceiverDistributionRecord( + class DBReceiverDistributionRecord( @EmbeddedId var compositeKey: PersistentKey, @@ -95,17 +91,18 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, /** states to record: NONE, ALL_VISIBLE, ONLY_RELEVANT */ @Column(name = "receiver_states_to_record", nullable = false) val receiverStatesToRecord: StatesToRecord -) { + ) { constructor(key: Key, txId: SecureHash, initiatorPartyId: Long, encryptedDistributionList: ByteArray, receiverStatesToRecord: StatesToRecord) : - this(PersistentKey(key), - txId = txId.toString(), - senderPartyId = initiatorPartyId, - distributionList = encryptedDistributionList, - receiverStatesToRecord = receiverStatesToRecord - ) + this( + PersistentKey(key), + txId = txId.toString(), + senderPartyId = initiatorPartyId, + distributionList = encryptedDistributionList, + receiverStatesToRecord = receiverStatesToRecord + ) - fun toReceiverDistributionRecord(cryptoService: CryptoService): ReceiverDistributionRecord { - val hashedDL = HashedDistributionList.deserialize(cryptoService.decrypt(this.distributionList)) + fun toReceiverDistributionRecord(encryptionService: EncryptionService): ReceiverDistributionRecord { + val hashedDL = HashedDistributionList.decrypt(this.distributionList, encryptionService) return ReceiverDistributionRecord( SecureHash.parse(this.txId), this.senderPartyId, @@ -139,32 +136,45 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } } - override fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray { + override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray { return database.transaction { val senderRecordingTimestamp = clock.instant() - metadata.distributionList.peersToStatesToRecord.forEach { (peer, _) -> - val senderDistributionRecord = DBSenderDistributionRecord(PersistentKey(Key(senderRecordingTimestamp)), - id.toString(), + for (peer in metadata.distributionList.peersToStatesToRecord.keys) { + val senderDistributionRecord = DBSenderDistributionRecord( + PersistentKey(Key(senderRecordingTimestamp)), + txId.toString(), partyInfoCache.getPartyIdByCordaX500Name(peer), - metadata.distributionList.senderStatesToRecord) + metadata.distributionList.senderStatesToRecord + ) session.save(senderDistributionRecord) } - val hashedPeersToStatesToRecord = metadata.distributionList.peersToStatesToRecord.map { (peer, statesToRecord) -> - partyInfoCache.getPartyIdByCordaX500Name(peer) to statesToRecord }.toMap() - val hashedDistributionList = HashedDistributionList(metadata.distributionList.senderStatesToRecord, hashedPeersToStatesToRecord, senderRecordingTimestamp) - cryptoService.encrypt(hashedDistributionList.serialize()) + + val hashedPeersToStatesToRecord = metadata.distributionList.peersToStatesToRecord.mapKeys { (peer) -> + partyInfoCache.getPartyIdByCordaX500Name(peer) + } + val hashedDistributionList = HashedDistributionList( + metadata.distributionList.senderStatesToRecord, + hashedPeersToStatesToRecord, + HashedDistributionList.PublicHeader(senderRecordingTimestamp) + ) + hashedDistributionList.encrypt(encryptionService) } } - override fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) { - val senderRecordedTimestamp = HashedDistributionList.deserialize(cryptoService.decrypt(encryptedDistributionList)).senderRecordedTimestamp + override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, + sender: CordaX500Name, + receiver: CordaX500Name, + receiverStatesToRecord: StatesToRecord, + encryptedDistributionList: ByteArray) { + val publicHeader = HashedDistributionList.PublicHeader.unauthenticatedDeserialise(encryptedDistributionList, encryptionService) database.transaction { - val receiverDistributionRecord = - DBReceiverDistributionRecord(Key(senderRecordedTimestamp), - id, - partyInfoCache.getPartyIdByCordaX500Name(sender), - encryptedDistributionList, - receiverStatesToRecord) + val receiverDistributionRecord = DBReceiverDistributionRecord( + Key(publicHeader.senderRecordedTimestamp), + txId, + partyInfoCache.getPartyIdByCordaX500Name(sender), + encryptedDistributionList, + receiverStatesToRecord + ) session.save(receiverDistributionRecord) } } @@ -235,8 +245,9 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } criteriaQuery.orderBy(orderCriteria) } - val results = session.createQuery(criteriaQuery).stream() - results.map { it.toSenderDistributionRecord() }.toList() + session.createQuery(criteriaQuery).stream().use { results -> + results.map { it.toSenderDistributionRecord() }.toList() + } } } @@ -273,21 +284,13 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } criteriaQuery.orderBy(orderCriteria) } - val results = session.createQuery(criteriaQuery).stream() - results.map { it.toReceiverDistributionRecord(cryptoService) }.toList() + session.createQuery(criteriaQuery).stream().use { results -> + results.map { it.toReceiverDistributionRecord(encryptionService) }.toList() + } } } } -// TO DO: https://r3-cev.atlassian.net/browse/ENT-9876 -private fun CryptoService.decrypt(bytes: ByteArray): ByteArray { - return bytes -} - -// TO DO: https://r3-cev.atlassian.net/browse/ENT-9876 -fun CryptoService.encrypt(bytes: ByteArray): ByteArray { - return bytes -} @CordaSerializable open class DistributionRecord( @@ -318,46 +321,3 @@ data class ReceiverDistributionRecord( enum class DistributionRecordType { SENDER, RECEIVER, ALL } - -@CordaSerializable -data class HashedDistributionList( - val senderStatesToRecord: StatesToRecord, - val peerHashToStatesToRecord: Map, - val senderRecordedTimestamp: Instant -) { - fun serialize(): ByteArray { - val baos = ByteArrayOutputStream() - val out = DataOutputStream(baos) - out.use { - out.writeByte(SERIALIZER_VERSION_ID) - out.writeByte(senderStatesToRecord.ordinal) - out.writeInt(peerHashToStatesToRecord.size) - for(entry in peerHashToStatesToRecord) { - out.writeLong(entry.key) - out.writeByte(entry.value.ordinal) - } - out.writeLong(senderRecordedTimestamp.toEpochMilli()) - out.flush() - return baos.toByteArray() - } - } - companion object { - const val SERIALIZER_VERSION_ID = 1 - fun deserialize(bytes: ByteArray): HashedDistributionList { - val input = DataInputStream(ByteArrayInputStream(bytes)) - input.use { - assert(input.readByte().toInt() == SERIALIZER_VERSION_ID) { "Serialization version conflict." } - val senderStatesToRecord = StatesToRecord.values()[input.readByte().toInt()] - val numPeerHashToStatesToRecords = input.readInt() - val peerHashToStatesToRecord = mutableMapOf() - repeat (numPeerHashToStatesToRecords) { - peerHashToStatesToRecord[input.readLong()] = StatesToRecord.values()[input.readByte().toInt()] - } - val senderRecordedTimestamp = Instant.ofEpochMilli(input.readLong()) - return HashedDistributionList(senderStatesToRecord, peerHashToStatesToRecord, senderRecordedTimestamp) - } - } - } -} - - diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/HashedDistributionList.kt b/node/src/main/kotlin/net/corda/node/services/persistence/HashedDistributionList.kt new file mode 100644 index 0000000000..910a00ce74 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/persistence/HashedDistributionList.kt @@ -0,0 +1,104 @@ +package net.corda.node.services.persistence + +import net.corda.core.node.StatesToRecord +import net.corda.core.serialization.CordaSerializable +import net.corda.node.services.EncryptionService +import java.io.ByteArrayOutputStream +import java.io.DataInputStream +import java.io.DataOutputStream +import java.nio.ByteBuffer +import java.time.Instant + +@Suppress("TooGenericExceptionCaught") +@CordaSerializable +data class HashedDistributionList( + val senderStatesToRecord: StatesToRecord, + val peerHashToStatesToRecord: Map, + val publicHeader: PublicHeader +) { + /** + * Encrypt this hashed distribution list using the given [EncryptionService]. The [publicHeader] is not encrypted but is instead + * authenticated so that it is tamperproof. + * + * The same [EncryptionService] instance needs to be used with [decrypt] for decryption. + */ + fun encrypt(encryptionService: EncryptionService): ByteArray { + val baos = ByteArrayOutputStream() + val out = DataOutputStream(baos) + out.writeByte(senderStatesToRecord.ordinal) + out.writeInt(peerHashToStatesToRecord.size) + for (entry in peerHashToStatesToRecord) { + out.writeLong(entry.key) + out.writeByte(entry.value.ordinal) + } + return encryptionService.encrypt(baos.toByteArray(), publicHeader.serialise()) + } + + + @CordaSerializable + data class PublicHeader( + val senderRecordedTimestamp: Instant + ) { + fun serialise(): ByteArray { + val buffer = ByteBuffer.allocate(1 + java.lang.Long.BYTES) + buffer.put(VERSION_TAG.toByte()) + buffer.putLong(senderRecordedTimestamp.toEpochMilli()) + return buffer.array() + } + + companion object { + /** + * Deserialise a [PublicHeader] from the given [encryptedBytes]. The bytes is expected is to be a valid encrypted blob that can + * be decrypted by [HashedDistributionList.decrypt] using the same [EncryptionService]. + * + * Because this method does not actually decrypt the bytes, the header returned is not authenticated and any modifications to it + * will not be detected. That can only be done by the encrypting party with [HashedDistributionList.decrypt]. + */ + fun unauthenticatedDeserialise(encryptedBytes: ByteArray, encryptionService: EncryptionService): PublicHeader { + val additionalData = encryptionService.extractUnauthenticatedAdditionalData(encryptedBytes) + requireNotNull(additionalData) { "Missing additional data field" } + return deserialise(additionalData!!) + } + + fun deserialise(bytes: ByteArray): PublicHeader { + val buffer = ByteBuffer.wrap(bytes) + try { + val version = buffer.get().toInt() + require(version == VERSION_TAG) { "Unknown distribution list format $version" } + val senderRecordedTimestamp = Instant.ofEpochMilli(buffer.getLong()) + return PublicHeader(senderRecordedTimestamp) + } catch (e: Exception) { + throw IllegalArgumentException("Corrupt or not a distribution list header", e) + } + } + } + } + + companion object { + // The version tag is serialised in the header, even though it is separate from the encrypted main body of the distribution list. + // This is because the header and the dist list are cryptographically coupled and we want to avoid declaring the version field twice. + private const val VERSION_TAG = 1 + private val statesToRecordValues = StatesToRecord.values() // Cache the enum values since .values() returns a new array each time. + + /** + * Decrypt a [HashedDistributionList] from the given [encryptedBytes] using the same [EncryptionService] that was used in [encrypt]. + */ + fun decrypt(encryptedBytes: ByteArray, encryptionService: EncryptionService): HashedDistributionList { + val (plaintext, authenticatedAdditionalData) = encryptionService.decrypt(encryptedBytes) + requireNotNull(authenticatedAdditionalData) { "Missing authenticated header" } + val publicHeader = PublicHeader.deserialise(authenticatedAdditionalData!!) + val input = DataInputStream(plaintext.inputStream()) + try { + val senderStatesToRecord = statesToRecordValues[input.readByte().toInt()] + val numPeerHashToStatesToRecords = input.readInt() + val peerHashToStatesToRecord = mutableMapOf() + repeat(numPeerHashToStatesToRecords) { + peerHashToStatesToRecord[input.readLong()] = statesToRecordValues[input.readByte().toInt()] + } + return HashedDistributionList(senderStatesToRecord, peerHashToStatesToRecord, publicHeader) + } catch (e: Exception) { + throw IllegalArgumentException("Corrupt or not a distribution list", e) + } + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt index 760544758d..68dc445e29 100644 --- a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt +++ b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt @@ -16,6 +16,7 @@ import net.corda.node.services.keys.BasicHSMKeyManagementService import net.corda.node.services.messaging.P2PMessageDeduplicator import net.corda.node.services.network.PersistentNetworkMapCache import net.corda.node.services.persistence.DBCheckpointStorage +import net.corda.node.services.persistence.AesDbEncryptionService import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery import net.corda.node.services.persistence.DBTransactionStorage import net.corda.node.services.persistence.NodeAttachmentService @@ -30,7 +31,7 @@ import net.corda.node.services.vault.VaultSchemaV1 * TODO: support plugins for schema version upgrading or custom mapping not supported by original [QueryableState]. * TODO: create whitelisted tables when a CorDapp is first installed */ -class NodeSchemaService(private val extraSchemas: Set = emptySet()) : SchemaService, SingletonSerializeAsToken() { +class NodeSchemaService(extraSchemas: Set = emptySet()) : SchemaService, SingletonSerializeAsToken() { // Core Entities used by a Node object NodeCore @@ -55,7 +56,8 @@ class NodeSchemaService(private val extraSchemas: Set = emptySet() PersistentNetworkMapCache.PersistentPartyToPublicKeyHash::class.java, DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java, DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java, - DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java + DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java, + AesDbEncryptionService.EncryptionKeyRecord::class.java )) { override val migrationResource = "node-core.changelog-master" } diff --git a/node/src/main/resources/migration/node-core.changelog-master.xml b/node/src/main/resources/migration/node-core.changelog-master.xml index 0ebf26bdc1..ef9116aade 100644 --- a/node/src/main/resources/migration/node-core.changelog-master.xml +++ b/node/src/main/resources/migration/node-core.changelog-master.xml @@ -31,6 +31,7 @@ + diff --git a/node/src/main/resources/migration/node-core.changelog-v26.xml b/node/src/main/resources/migration/node-core.changelog-v26.xml new file mode 100644 index 0000000000..b0d4925c7a --- /dev/null +++ b/node/src/main/resources/migration/node-core.changelog-v26.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index fa516073fb..d0f0de3d60 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -810,15 +810,19 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { return true } - override fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray? { + override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray? { return database.transaction { - delegate.addSenderTransactionRecoveryMetadata(id, metadata) + delegate.addSenderTransactionRecoveryMetadata(txId, metadata) } } - override fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) { + override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, + sender: CordaX500Name, + receiver: CordaX500Name, + receiverStatesToRecord: StatesToRecord, + encryptedDistributionList: ByteArray) { database.transaction { - delegate.addReceiverTransactionRecoveryMetadata(id, sender, receiver, receiverStatesToRecord, encryptedDistributionList) + delegate.addReceiverTransactionRecoveryMetadata(txId, sender, receiver, receiverStatesToRecord, encryptedDistributionList) } } diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/AesDbEncryptionServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/persistence/AesDbEncryptionServiceTest.kt new file mode 100644 index 0000000000..806357627a --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/persistence/AesDbEncryptionServiceTest.kt @@ -0,0 +1,134 @@ +package net.corda.node.services.persistence + +import net.corda.node.services.persistence.AesDbEncryptionService.EncryptionKeyRecord +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.persistence.DatabaseConfig +import net.corda.testing.core.TestIdentity +import net.corda.testing.internal.configureDatabase +import net.corda.testing.node.MockServices +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatExceptionOfType +import org.assertj.core.api.Assertions.assertThatIllegalArgumentException +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.nio.ByteBuffer +import java.security.GeneralSecurityException +import java.util.UUID + +class AesDbEncryptionServiceTest { + private val identity = TestIdentity.fresh("me").party + private lateinit var database: CordaPersistence + private lateinit var encryptionService: AesDbEncryptionService + + @Before + fun setUp() { + val dataSourceProps = MockServices.makeTestDataSourceProperties() + database = configureDatabase(dataSourceProps, DatabaseConfig(), { null }, { null }) + encryptionService = AesDbEncryptionService(database) + encryptionService.start(identity) + } + + @After + fun cleanUp() { + database.close() + } + + @Test(timeout = 300_000) + fun `same instance can decrypt ciphertext`() { + val ciphertext = encryptionService.encrypt("Hello World".toByteArray()) + val (plaintext, authenticatedData) = encryptionService.decrypt(ciphertext) + assertThat(String(plaintext)).isEqualTo("Hello World") + assertThat(authenticatedData).isNull() + } + + @Test(timeout = 300_000) + fun `encypting twice produces different ciphertext`() { + val plaintext = "Hello".toByteArray() + assertThat(encryptionService.encrypt(plaintext)).isNotEqualTo(encryptionService.encrypt(plaintext)) + } + + @Test(timeout = 300_000) + fun `ciphertext can be decrypted after restart`() { + val ciphertext = encryptionService.encrypt("Hello World".toByteArray()) + encryptionService = AesDbEncryptionService(database) + encryptionService.start(identity) + val plaintext = encryptionService.decrypt(ciphertext).plaintext + assertThat(String(plaintext)).isEqualTo("Hello World") + } + + @Test(timeout = 300_000) + fun `encrypting with authenticated data`() { + val ciphertext = encryptionService.encrypt("Hello World".toByteArray(), "Additional data".toByteArray()) + val (plaintext, authenticatedData) = encryptionService.decrypt(ciphertext) + assertThat(String(plaintext)).isEqualTo("Hello World") + assertThat(authenticatedData?.let { String(it) }).isEqualTo("Additional data") + } + + @Test(timeout = 300_000) + fun extractUnauthenticatedAdditionalData() { + val ciphertext = encryptionService.encrypt("Hello World".toByteArray(), "Additional data".toByteArray()) + val additionalData = encryptionService.extractUnauthenticatedAdditionalData(ciphertext) + assertThat(additionalData?.let { String(it) }).isEqualTo("Additional data") + } + + @Test(timeout = 300_000) + fun `ciphertext cannot be decrypted if the authenticated data is modified`() { + val ciphertext = ByteBuffer.wrap(encryptionService.encrypt("Hello World".toByteArray(), "1234".toByteArray())) + + ciphertext.position(21) + ciphertext.put("4321".toByteArray()) // Use same length for the modified AAD + + assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy { + encryptionService.decrypt(ciphertext.array()) + } + } + + @Test(timeout = 300_000) + fun `ciphertext cannot be decrypted if the key used is deleted`() { + val ciphertext = encryptionService.encrypt("Hello World".toByteArray()) + val keyId = ByteBuffer.wrap(ciphertext).getKeyId() + val deletedCount = database.transaction { + session.createQuery("DELETE FROM ${EncryptionKeyRecord::class.java.name} k WHERE k.keyId = :keyId") + .setParameter("keyId", keyId) + .executeUpdate() + } + assertThat(deletedCount).isEqualTo(1) + + encryptionService = AesDbEncryptionService(database) + encryptionService.start(identity) + assertThatIllegalArgumentException().isThrownBy { + encryptionService.decrypt(ciphertext) + } + } + + @Test(timeout = 300_000) + fun `ciphertext cannot be decrypted if forced to use a different key`() { + val ciphertext = ByteBuffer.wrap(encryptionService.encrypt("Hello World".toByteArray())) + val keyId = ciphertext.getKeyId() + val anotherKeyId = database.transaction { + session.createQuery("SELECT keyId FROM ${EncryptionKeyRecord::class.java.name} k WHERE k.keyId != :keyId", UUID::class.java) + .setParameter("keyId", keyId) + .setMaxResults(1) + .singleResult + } + + ciphertext.putKeyId(anotherKeyId) + + encryptionService = AesDbEncryptionService(database) + encryptionService.start(identity) + assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy { + encryptionService.decrypt(ciphertext.array()) + } + } + + private fun ByteBuffer.getKeyId(): UUID { + position(1) + return getUUID() + } + + private fun ByteBuffer.putKeyId(keyId: UUID) { + position(1) + putUUID(keyId) + } +} diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt index 5f52c0849f..e38079536d 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt @@ -24,7 +24,6 @@ import net.corda.node.services.network.PersistentPartyInfoCache import net.corda.node.services.persistence.DBTransactionStorage.TransactionStatus.IN_FLIGHT import net.corda.node.services.persistence.DBTransactionStorage.TransactionStatus.VERIFIED import net.corda.nodeapi.internal.DEV_ROOT_CA -import net.corda.nodeapi.internal.cryptoservice.CryptoService import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.testing.core.ALICE_NAME @@ -38,7 +37,8 @@ import net.corda.testing.internal.TestingNamedCacheFactory import net.corda.testing.internal.configureDatabase import net.corda.testing.internal.createWireTransaction import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties -import net.corda.testing.node.internal.MockCryptoService +import net.corda.testing.node.internal.MockEncryptionService +import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Before import org.junit.Rule @@ -67,6 +67,8 @@ class DBTransactionStorageLedgerRecoveryTests { private lateinit var transactionRecovery: DBTransactionStorageLedgerRecovery private lateinit var partyInfoCache: PersistentPartyInfoCache + private val encryptionService = MockEncryptionService() + @Before fun setUp() { val dataSourceProps = makeTestDataSourceProperties() @@ -278,17 +280,21 @@ class DBTransactionStorageLedgerRecoveryTests { @Test(timeout = 300_000) fun `test lightweight serialization and deserialization of hashed distribution list payload`() { - val dl = HashedDistributionList(ALL_VISIBLE, - mapOf(BOB.name.hashCode().toLong() to NONE, CHARLIE_NAME.hashCode().toLong() to ONLY_RELEVANT), now()) - assertEquals(dl, dl.serialize().let { HashedDistributionList.deserialize(it) }) + val hashedDistList = HashedDistributionList( + ALL_VISIBLE, + mapOf(BOB.name.hashCode().toLong() to NONE, CHARLIE_NAME.hashCode().toLong() to ONLY_RELEVANT), + HashedDistributionList.PublicHeader(now()) + ) + val roundtrip = HashedDistributionList.decrypt(hashedDistList.encrypt(encryptionService), encryptionService) + assertThat(roundtrip).isEqualTo(hashedDistList) } private fun readTransactionFromDB(id: SecureHash): DBTransactionStorage.DBTransaction { val fromDb = database.transaction { session.createQuery( - "from ${DBTransactionStorage.DBTransaction::class.java.name} where tx_id = :transactionId", + "from ${DBTransactionStorage.DBTransaction::class.java.name} where txId = :transactionId", DBTransactionStorage.DBTransaction::class.java - ).setParameter("transactionId", id.toString()).resultList.map { it } + ).setParameter("transactionId", id.toString()).resultList } assertEquals(1, fromDb.size) return fromDb[0] @@ -298,7 +304,7 @@ class DBTransactionStorageLedgerRecoveryTests { return database.transaction { if (id != null) session.createQuery( - "from ${DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java.name} where tx_id = :transactionId", + "from ${DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java.name} where txId = :transactionId", DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java ).setParameter("transactionId", id.toString()).resultList.map { it.toSenderDistributionRecord() } else @@ -312,17 +318,15 @@ class DBTransactionStorageLedgerRecoveryTests { private fun readReceiverDistributionRecordFromDB(id: SecureHash): ReceiverDistributionRecord { val fromDb = database.transaction { session.createQuery( - "from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where tx_id = :transactionId", + "from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where txId = :transactionId", DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java - ).setParameter("transactionId", id.toString()).resultList.map { it } + ).setParameter("transactionId", id.toString()).resultList } assertEquals(1, fromDb.size) - return fromDb[0].toReceiverDistributionRecord(MockCryptoService(emptyMap())) + return fromDb[0].toReceiverDistributionRecord(encryptionService) } - private fun newTransactionRecovery(cacheSizeBytesOverride: Long? = null, clock: CordaClock = SimpleClock(Clock.systemUTC()), - cryptoService: CryptoService = MockCryptoService(emptyMap())) { - + private fun newTransactionRecovery(cacheSizeBytesOverride: Long? = null, clock: CordaClock = SimpleClock(Clock.systemUTC())) { val networkMapCache = PersistentNetworkMapCache(TestingNamedCacheFactory(), database, InMemoryIdentityService(trustRoot = DEV_ROOT_CA.certificate)) val alice = createNodeInfo(listOf(ALICE)) val bob = createNodeInfo(listOf(BOB)) @@ -330,8 +334,13 @@ class DBTransactionStorageLedgerRecoveryTests { networkMapCache.addOrUpdateNodes(listOf(alice, bob, charlie)) partyInfoCache = PersistentPartyInfoCache(networkMapCache, TestingNamedCacheFactory(), database) partyInfoCache.start() - transactionRecovery = DBTransactionStorageLedgerRecovery(database, TestingNamedCacheFactory(cacheSizeBytesOverride - ?: 1024), clock, cryptoService, partyInfoCache) + transactionRecovery = DBTransactionStorageLedgerRecovery( + database, + TestingNamedCacheFactory(cacheSizeBytesOverride ?: 1024), + clock, + encryptionService, + partyInfoCache + ) } private var portCounter = 1000 @@ -370,10 +379,13 @@ class DBTransactionStorageLedgerRecoveryTests { private fun notarySig(txId: SecureHash) = DUMMY_NOTARY.keyPair.sign(SignableData(txId, SignatureMetadata(1, Crypto.findSignatureScheme(DUMMY_NOTARY.publicKey).schemeNumberID))) - private fun DistributionList.toWire(cryptoService: CryptoService = MockCryptoService(emptyMap())): ByteArray { - val hashedPeersToStatesToRecord = this.peersToStatesToRecord.map { (peer, statesToRecord) -> - partyInfoCache.getPartyIdByCordaX500Name(peer) to statesToRecord }.toMap() - val hashedDistributionList = HashedDistributionList(this.senderStatesToRecord, hashedPeersToStatesToRecord, now()) - return cryptoService.encrypt(hashedDistributionList.serialize()) + private fun DistributionList.toWire(): ByteArray { + val hashedPeersToStatesToRecord = this.peersToStatesToRecord.mapKeys { (peer) -> partyInfoCache.getPartyIdByCordaX500Name(peer) } + val hashedDistributionList = HashedDistributionList( + this.senderStatesToRecord, + hashedPeersToStatesToRecord, + HashedDistributionList.PublicHeader(now()) + ) + return hashedDistributionList.encrypt(encryptionService) } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockEncryptionService.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockEncryptionService.kt new file mode 100644 index 0000000000..1c3875191c --- /dev/null +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockEncryptionService.kt @@ -0,0 +1,39 @@ +package net.corda.testing.node.internal + +import net.corda.core.internal.copyBytes +import net.corda.node.services.EncryptionService +import net.corda.nodeapi.internal.crypto.AesEncryption +import java.nio.ByteBuffer +import javax.crypto.SecretKey + +class MockEncryptionService(private val aesKey: SecretKey = AesEncryption.randomKey()) : EncryptionService { + override fun encrypt(plaintext: ByteArray, additionalData: ByteArray?): ByteArray { + val ciphertext = AesEncryption.encrypt(aesKey, plaintext, additionalData) + val buffer = ByteBuffer.allocate(Integer.BYTES + (additionalData?.size ?: 0) + ciphertext.size) + if (additionalData != null) { + buffer.putInt(additionalData.size) + buffer.put(additionalData) + } else { + buffer.putInt(0) + } + buffer.put(ciphertext) + return buffer.array() + } + + override fun decrypt(ciphertext: ByteArray): EncryptionService.PlaintextAndAAD { + val buffer = ByteBuffer.wrap(ciphertext) + val additionalData = buffer.getAdditionaData() + val plaintext = AesEncryption.decrypt(aesKey, buffer.copyBytes(), additionalData) + // Only now is the additional data authenticated + return EncryptionService.PlaintextAndAAD(plaintext, additionalData) + } + + override fun extractUnauthenticatedAdditionalData(ciphertext: ByteArray): ByteArray? { + return ByteBuffer.wrap(ciphertext).getAdditionaData() + } + + private fun ByteBuffer.getAdditionaData(): ByteArray? { + val additionalDataSize = getInt() + return if (additionalDataSize > 0) ByteArray(additionalDataSize).also { get(it) } else null + } +} diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt index f850aab58b..9f23bf6beb 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt @@ -61,9 +61,13 @@ open class MockTransactionStorage : WritableTransactionStorage, SingletonSeriali return txns.putIfAbsent(transaction.id, TxHolder(transaction, status = TransactionStatus.IN_FLIGHT)) == null } - override fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray? { return null } + override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray? { return null } - override fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) { } + override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, + sender: CordaX500Name, + receiver: CordaX500Name, + receiverStatesToRecord: StatesToRecord, + encryptedDistributionList: ByteArray) { } override fun removeUnnotarisedTransaction(id: SecureHash): Boolean { return txns.remove(id) != null From 492373d1802e6b02a19b7b3171670d4db8a85565 Mon Sep 17 00:00:00 2001 From: Jose Coll Date: Wed, 16 Aug 2023 17:02:58 +0100 Subject: [PATCH 02/10] Introduction of Sender and Receiver Distribution Lists to support receiver self-recovery mode. --- .../net/corda/core/flows/FlowTransaction.kt | 22 ++++++--- .../corda/core/flows/SendTransactionFlow.kt | 7 +-- .../DBTransactionStorageLedgerRecovery.kt | 10 ++-- ...DBTransactionStorageLedgerRecoveryTests.kt | 48 +++++++++---------- 4 files changed, 50 insertions(+), 37 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowTransaction.kt b/core/src/main/kotlin/net/corda/core/flows/FlowTransaction.kt index 05539f7480..b213c6dbd0 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowTransaction.kt @@ -23,15 +23,25 @@ data class FlowTransactionInfo( @CordaSerializable data class TransactionMetadata( - val initiator: CordaX500Name, - val distributionList: DistributionList + val initiator: CordaX500Name, + val distributionList: DistributionList ) @CordaSerializable -data class DistributionList( - val senderStatesToRecord: StatesToRecord, - val peersToStatesToRecord: Map -) +sealed class DistributionList { + + @CordaSerializable + data class SenderDistributionList( + val senderStatesToRecord: StatesToRecord, + val peersToStatesToRecord: Map + ) : DistributionList() + + @CordaSerializable + data class ReceiverDistributionList( + val opaqueData: ByteArray, // decipherable only by sender + val receiverStatesToRecord: StatesToRecord // inferred or actual + ) : DistributionList() +} @CordaSerializable enum class TransactionStatus { diff --git a/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt b/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt index 93ed7d0c97..07247b3b46 100644 --- a/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt @@ -23,6 +23,7 @@ import kotlin.collections.map import kotlin.collections.mutableSetOf import kotlin.collections.plus import kotlin.collections.toSet +import net.corda.core.flows.DistributionList.SenderDistributionList /** * In the words of Matt working code is more important then pretty code. This class that contains code that may @@ -98,9 +99,9 @@ open class SendTransactionFlow(val stx: SignedTransaction, fun makeMetaData(stx: SignedTransaction, recordMetaDataEvenIfNotFullySigned: Boolean, senderStatesToRecord: StatesToRecord, participantSessions: Set, observerSessions: Set): TransactionMetadata? { return if (recordMetaDataEvenIfNotFullySigned || isFullySigned(stx)) TransactionMetadata(DUMMY_PARTICIPANT_NAME, - DistributionList(senderStatesToRecord, - (participantSessions.map { it.counterparty.name to StatesToRecord.ONLY_RELEVANT}).toMap() + - (observerSessions.map { it.counterparty.name to StatesToRecord.ALL_VISIBLE}).toMap())) + SenderDistributionList(senderStatesToRecord, + (participantSessions.map { it.counterparty.name to StatesToRecord.ONLY_RELEVANT }).toMap() + + (observerSessions.map { it.counterparty.name to StatesToRecord.ALL_VISIBLE }).toMap())) else null } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt index 6c101a8401..4492c27724 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt @@ -1,6 +1,7 @@ package net.corda.node.services.persistence import net.corda.core.crypto.SecureHash +import net.corda.core.flows.DistributionList.SenderDistributionList import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.TransactionMetadata import net.corda.core.identity.CordaX500Name @@ -139,21 +140,22 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray { return database.transaction { val senderRecordingTimestamp = clock.instant() - for (peer in metadata.distributionList.peersToStatesToRecord.keys) { + val distributionList = metadata.distributionList as? SenderDistributionList ?: throw IllegalStateException("Expecting SenderDistributionList") + for (peer in distributionList.peersToStatesToRecord.keys) { val senderDistributionRecord = DBSenderDistributionRecord( PersistentKey(Key(senderRecordingTimestamp)), txId.toString(), partyInfoCache.getPartyIdByCordaX500Name(peer), - metadata.distributionList.senderStatesToRecord + distributionList.senderStatesToRecord ) session.save(senderDistributionRecord) } - val hashedPeersToStatesToRecord = metadata.distributionList.peersToStatesToRecord.mapKeys { (peer) -> + val hashedPeersToStatesToRecord = distributionList.peersToStatesToRecord.mapKeys { (peer) -> partyInfoCache.getPartyIdByCordaX500Name(peer) } val hashedDistributionList = HashedDistributionList( - metadata.distributionList.senderStatesToRecord, + distributionList.senderStatesToRecord, hashedPeersToStatesToRecord, HashedDistributionList.PublicHeader(senderRecordingTimestamp) ) diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt index e38079536d..9a0a8b194d 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt @@ -6,7 +6,7 @@ import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SignableData import net.corda.core.crypto.SignatureMetadata import net.corda.core.crypto.sign -import net.corda.core.flows.DistributionList +import net.corda.core.flows.DistributionList.SenderDistributionList import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.TransactionMetadata import net.corda.core.node.NodeInfo @@ -86,7 +86,7 @@ class DBTransactionStorageLedgerRecoveryTests { val beforeFirstTxn = now() val txn = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) val timeWindow = RecoveryTimeWindow(fromTime = beforeFirstTxn, untilTime = beforeFirstTxn.plus(1, ChronoUnit.MINUTES)) val results = transactionRecovery.querySenderDistributionRecords(timeWindow) @@ -95,7 +95,7 @@ class DBTransactionStorageLedgerRecoveryTests { val afterFirstTxn = now() val txn2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn2) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) assertEquals(2, transactionRecovery.querySenderDistributionRecords(timeWindow).size) assertEquals(1, transactionRecovery.querySenderDistributionRecords(RecoveryTimeWindow(fromTime = afterFirstTxn)).size) } @@ -104,10 +104,10 @@ class DBTransactionStorageLedgerRecoveryTests { fun `query local ledger for transactions within timeWindow and excluding remoteTransactionIds`() { val transaction1 = newTransaction() transactionRecovery.addUnnotarisedTransaction(transaction1) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) val transaction2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(transaction2) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val results = transactionRecovery.querySenderDistributionRecords(timeWindow, excludingTxnIds = setOf(transaction1.id)) assertEquals(1, results.size) @@ -118,12 +118,12 @@ class DBTransactionStorageLedgerRecoveryTests { val transaction1 = newTransaction() // sender txn transactionRecovery.addUnnotarisedTransaction(transaction1) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) val transaction2 = newTransaction() // receiver txn transactionRecovery.addUnnotarisedTransaction(transaction2) transactionRecovery.addReceiverTransactionRecoveryMetadata(transaction2.id, BOB_NAME, ALICE_NAME, ALL_VISIBLE, - DistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)).toWire()) + SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)).toWire()) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.SENDER).let { assertEquals(1, it.size) @@ -143,19 +143,19 @@ class DBTransactionStorageLedgerRecoveryTests { fun `query for sender distribution records by peers`() { val txn1 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn1) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn1.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn1.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) val txn2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn2) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) val txn3 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn3) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn3.id, TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT, CHARLIE_NAME to ALL_VISIBLE)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn3.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT, CHARLIE_NAME to ALL_VISIBLE)))) val txn4 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn4) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn4.id, TransactionMetadata(BOB_NAME, DistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn4.id, TransactionMetadata(BOB_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ONLY_RELEVANT)))) val txn5 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn5) - transactionRecovery.addSenderTransactionRecoveryMetadata(txn5.id, TransactionMetadata(CHARLIE_NAME, DistributionList(ONLY_RELEVANT, emptyMap()))) + transactionRecovery.addSenderTransactionRecoveryMetadata(txn5.id, TransactionMetadata(CHARLIE_NAME, SenderDistributionList(ONLY_RELEVANT, emptyMap()))) assertEquals(5, readSenderDistributionRecordFromDB().size) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) @@ -173,23 +173,23 @@ class DBTransactionStorageLedgerRecoveryTests { val txn1 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn1) transactionRecovery.addReceiverTransactionRecoveryMetadata(txn1.id, ALICE_NAME, BOB_NAME, ALL_VISIBLE, - DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE, CHARLIE_NAME to ALL_VISIBLE)).toWire()) + SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE, CHARLIE_NAME to ALL_VISIBLE)).toWire()) val txn2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn2) transactionRecovery.addReceiverTransactionRecoveryMetadata(txn2.id, ALICE_NAME, BOB_NAME, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)).toWire()) + SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)).toWire()) val txn3 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn3) transactionRecovery.addReceiverTransactionRecoveryMetadata(txn3.id, ALICE_NAME, CHARLIE_NAME, NONE, - DistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to NONE)).toWire()) + SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to NONE)).toWire()) val txn4 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn4) transactionRecovery.addReceiverTransactionRecoveryMetadata(txn4.id, BOB_NAME, ALICE_NAME, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)).toWire()) + SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)).toWire()) val txn5 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn5) transactionRecovery.addReceiverTransactionRecoveryMetadata(txn5.id, CHARLIE_NAME, BOB_NAME, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)).toWire()) + SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)).toWire()) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(ALICE_NAME)).let { @@ -207,7 +207,7 @@ class DBTransactionStorageLedgerRecoveryTests { fun `transaction without peers does not store recovery metadata in database`() { val senderTransaction = newTransaction() transactionRecovery.addUnnotarisedTransaction(senderTransaction) - transactionRecovery.addSenderTransactionRecoveryMetadata(senderTransaction.id, TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, emptyMap()))) + transactionRecovery.addSenderTransactionRecoveryMetadata(senderTransaction.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, emptyMap()))) assertEquals(IN_FLIGHT, readTransactionFromDB(senderTransaction.id).status) assertEquals(0, readSenderDistributionRecordFromDB(senderTransaction.id).size) } @@ -217,7 +217,7 @@ class DBTransactionStorageLedgerRecoveryTests { val senderTransaction = newTransaction() transactionRecovery.addUnnotarisedTransaction(senderTransaction) transactionRecovery.addSenderTransactionRecoveryMetadata(senderTransaction.id, - TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) + TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ALL_VISIBLE)))) assertEquals(IN_FLIGHT, readTransactionFromDB(senderTransaction.id).status) readSenderDistributionRecordFromDB(senderTransaction.id).let { assertEquals(1, it.size) @@ -228,7 +228,7 @@ class DBTransactionStorageLedgerRecoveryTests { val receiverTransaction = newTransaction() transactionRecovery.addUnnotarisedTransaction(receiverTransaction) transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, ALICE_NAME, BOB_NAME, ALL_VISIBLE, - DistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE)).toWire()) + SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE)).toWire()) assertEquals(IN_FLIGHT, readTransactionFromDB(receiverTransaction.id).status) readReceiverDistributionRecordFromDB(receiverTransaction.id).let { assertEquals(ALL_VISIBLE, it.statesToRecord) @@ -243,7 +243,7 @@ class DBTransactionStorageLedgerRecoveryTests { val transaction = newTransaction(notarySig = false) transactionRecovery.finalizeTransaction(transaction) transactionRecovery.addSenderTransactionRecoveryMetadata(transaction.id, - TransactionMetadata(ALICE_NAME, DistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ALL_VISIBLE)))) + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to ALL_VISIBLE)))) assertEquals(VERIFIED, readTransactionFromDB(transaction.id).status) readSenderDistributionRecordFromDB(transaction.id).apply { assertEquals(1, this.size) @@ -256,7 +256,7 @@ class DBTransactionStorageLedgerRecoveryTests { val senderTransaction = newTransaction(notarySig = false) transactionRecovery.addUnnotarisedTransaction(senderTransaction) transactionRecovery.addReceiverTransactionRecoveryMetadata(senderTransaction.id, ALICE.name, BOB.name, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT, CHARLIE_NAME to ONLY_RELEVANT)).toWire()) + SenderDistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT, CHARLIE_NAME to ONLY_RELEVANT)).toWire()) assertNull(transactionRecovery.getTransaction(senderTransaction.id)) assertEquals(IN_FLIGHT, readTransactionFromDB(senderTransaction.id).status) @@ -268,7 +268,7 @@ class DBTransactionStorageLedgerRecoveryTests { val receiverTransaction = newTransaction(notarySig = false) transactionRecovery.addUnnotarisedTransaction(receiverTransaction) transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, ALICE.name, BOB.name, ONLY_RELEVANT, - DistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT)).toWire()) + SenderDistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT)).toWire()) assertNull(transactionRecovery.getTransaction(receiverTransaction.id)) assertEquals(IN_FLIGHT, readTransactionFromDB(receiverTransaction.id).status) @@ -379,7 +379,7 @@ class DBTransactionStorageLedgerRecoveryTests { private fun notarySig(txId: SecureHash) = DUMMY_NOTARY.keyPair.sign(SignableData(txId, SignatureMetadata(1, Crypto.findSignatureScheme(DUMMY_NOTARY.publicKey).schemeNumberID))) - private fun DistributionList.toWire(): ByteArray { + private fun SenderDistributionList.toWire(): ByteArray { val hashedPeersToStatesToRecord = this.peersToStatesToRecord.mapKeys { (peer) -> partyInfoCache.getPartyIdByCordaX500Name(peer) } val hashedDistributionList = HashedDistributionList( this.senderStatesToRecord, From 9b7affa6b303a71376ff1a416df9e9f80d8d8fa2 Mon Sep 17 00:00:00 2001 From: Jose Coll Date: Wed, 16 Aug 2023 17:40:33 +0100 Subject: [PATCH 03/10] Fix compilation errors following merge. --- .../coretests/flows/FinalityFlowTests.kt | 3 -- .../DBTransactionStorageLedgerRecovery.kt | 51 +------------------ 2 files changed, 2 insertions(+), 52 deletions(-) diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt index 277061deae..9d82164ccc 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt @@ -359,7 +359,6 @@ class FinalityFlowTests : WithFinality { } getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { assertEquals(StatesToRecord.ALL_VISIBLE, this?.statesToRecord) - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), this?.peersToStatesToRecord) } @@ -392,7 +391,6 @@ class FinalityFlowTests : WithFinality { } getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) // note: Charlie assertion here is using the hinted StatesToRecord value passed to it from Alice assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT, @@ -435,7 +433,6 @@ class FinalityFlowTests : WithFinality { } getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT), this?.peersToStatesToRecord) } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt index 86cfdb66f5..ad234a6500 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt @@ -6,7 +6,6 @@ import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.TransactionMetadata import net.corda.core.identity.CordaX500Name import net.corda.core.internal.NamedCacheFactory -import net.corda.core.internal.VisibleForTesting import net.corda.core.node.StatesToRecord import net.corda.core.node.services.vault.Sort import net.corda.core.serialization.CordaSerializable @@ -164,31 +163,6 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } } - fun createReceiverTransactionRecoverMetadata(txId: SecureHash, - senderPartyId: Long, - senderStatesToRecord: StatesToRecord, - senderRecords: List): List { - val senderRecordsByTimestampKey = senderRecords.groupBy { TimestampKey(it.compositeKey.timestamp, it.compositeKey.timestampDiscriminator) } - return senderRecordsByTimestampKey.map { - val hashedDistributionList = HashedDistributionList( - senderStatesToRecord = senderStatesToRecord, - peerHashToStatesToRecord = senderRecords.map { it.compositeKey.peerPartyId to it.statesToRecord }.toMap(), - senderRecordedTimestamp = it.key.timestamp - ) - DBReceiverDistributionRecord( - compositeKey = PersistentKey(Key(TimestampKey(it.key.timestamp, it.key.timestampDiscriminator), senderPartyId)), - txId = txId.toString(), - distributionList = cryptoService.encrypt(hashedDistributionList.serialize()) - ) - } - } - - fun addSenderTransactionRecoveryMetadata(record: DBSenderDistributionRecord) { - return database.transaction { - session.save(record) - } - } - override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, @@ -206,12 +180,6 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } } - fun addReceiverTransactionRecoveryMetadata(record: DBReceiverDistributionRecord) { - return database.transaction { - session.save(record) - } - } - override fun removeUnnotarisedTransaction(id: SecureHash): Boolean { return database.transaction { super.removeUnnotarisedTransaction(id) @@ -281,20 +249,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } criteriaQuery.orderBy(orderCriteria) } - session.createQuery(criteriaQuery).stream().use { results -> - results.map { it.toSenderDistributionRecord() }.toList() - } - } - } - - fun querySenderDistributionRecordsByTxId(txId: SecureHash): List { - return database.transaction { - val criteriaBuilder = session.criteriaBuilder - val criteriaQuery = criteriaBuilder.createQuery(DBSenderDistributionRecord::class.java) - val txnMetadata = criteriaQuery.from(DBSenderDistributionRecord::class.java) - criteriaQuery.where(criteriaBuilder.equal(txnMetadata.get(DBSenderDistributionRecord::txId.name), txId.toString())) - val results = session.createQuery(criteriaQuery).stream() - results.toList() + session.createQuery(criteriaQuery).stream().toList() } } @@ -331,9 +286,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } criteriaQuery.orderBy(orderCriteria) } - session.createQuery(criteriaQuery).stream().use { results -> - results.map { it.toReceiverDistributionRecord(encryptionService) }.toList() - } + session.createQuery(criteriaQuery).stream().toList() } } } From f565232f36b2f37755890ee4557470c048c23b93 Mon Sep 17 00:00:00 2001 From: Jose Coll Date: Wed, 16 Aug 2023 18:05:18 +0100 Subject: [PATCH 04/10] Fix compilation errors following merge. --- .../DBTransactionStorageLedgerRecoveryTests.kt | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt index e0f628acc7..8f759838cb 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt @@ -118,10 +118,10 @@ class DBTransactionStorageLedgerRecoveryTests { fun `query local ledger for transactions within timeWindow and for given peers`() { val transaction1 = newTransaction() transactionRecovery.addUnnotarisedTransaction(transaction1) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction1.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(BOB_NAME to ONLY_RELEVANT)))) val transaction2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(transaction2) - transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, TransactionMetadata(ALICE_NAME, DistributionList(ALL_VISIBLE, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, TransactionMetadata(ALICE_NAME, SenderDistributionList(ALL_VISIBLE, mapOf(CHARLIE_NAME to ONLY_RELEVANT)))) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) val results = transactionRecovery.querySenderDistributionRecords(timeWindow, peers = setOf(CHARLIE_NAME)) assertEquals(1, results.size) @@ -148,7 +148,7 @@ class DBTransactionStorageLedgerRecoveryTests { transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.RECEIVER).let { assertEquals(1, it.size) assertEquals(BOB_NAME.hashCode().toLong(), it.receiverRecords[0].compositeKey.peerPartyId) - assertEquals(ALL_VISIBLE, (transactionRecovery.decrypt(it.receiverRecords[0].distributionList).peerHashToStatesToRecord.map { it.value }[0])) + assertEquals(ALL_VISIBLE, (HashedDistributionList.decrypt(it.receiverRecords[0].distributionList, encryptionService)).peerHashToStatesToRecord.map { it.value }[0]) } val resultsAll = transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.ALL) assertEquals(2, resultsAll.size) @@ -209,9 +209,9 @@ class DBTransactionStorageLedgerRecoveryTests { val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(ALICE_NAME)).let { assertEquals(3, it.size) - assertEquals(transactionRecovery.decrypt(it[0].distributionList).peerHashToStatesToRecord.map { it.value }[0], ALL_VISIBLE) - assertEquals(transactionRecovery.decrypt(it[1].distributionList).peerHashToStatesToRecord.map { it.value }[0], ONLY_RELEVANT) - assertEquals(transactionRecovery.decrypt(it[2].distributionList).peerHashToStatesToRecord.map { it.value }[0], NONE) + assertEquals(HashedDistributionList.decrypt(it[0].distributionList, encryptionService).peerHashToStatesToRecord.map { it.value }[0], ALL_VISIBLE) + assertEquals(HashedDistributionList.decrypt(it[1].distributionList, encryptionService).peerHashToStatesToRecord.map { it.value }[0], ONLY_RELEVANT) + assertEquals(HashedDistributionList.decrypt(it[2].distributionList, encryptionService).peerHashToStatesToRecord.map { it.value }[0], NONE) } assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(BOB_NAME)).size) assertEquals(1, transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(CHARLIE_NAME)).size) @@ -405,6 +405,3 @@ class DBTransactionStorageLedgerRecoveryTests { } } -internal fun DBTransactionStorageLedgerRecovery.decrypt(distributionList: ByteArray): HashedDistributionList { - return HashedDistributionList.deserialize(this.cryptoService.decrypt(distributionList)) -} From 06e43eb9e2cb2e3cbed26bc3e21d3b5c8d2618bc Mon Sep 17 00:00:00 2001 From: Jose Coll Date: Thu, 17 Aug 2023 08:47:58 +0100 Subject: [PATCH 05/10] Fixes following merge. --- .ci/api-current.txt | 26 +++++++++++++++++-- .../DBTransactionStorageLedgerRecovery.kt | 8 +++--- .../migration/node-core.changelog-v25.xml | 3 +++ 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 200b056099..1233ed1555 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -2542,14 +2542,36 @@ public class net.corda.core.flows.DataVendingFlow extends net.corda.core.flows.F public interface net.corda.core.flows.Destination ## @CordaSerializable -public final class net.corda.core.flows.DistributionList extends java.lang.Object +public abstract class net.corda.core.flows.DistributionList extends java.lang.Object + public (kotlin.jvm.internal.DefaultConstructorMarker) +## +@CordaSerializable +public static final class net.corda.core.flows.DistributionList$ReceiverDistributionList extends net.corda.core.flows.DistributionList + public (byte[], net.corda.core.node.StatesToRecord) + @NotNull + public final byte[] component1() + @NotNull + public final net.corda.core.node.StatesToRecord component2() + @NotNull + public final net.corda.core.flows.DistributionList$ReceiverDistributionList copy(byte[], net.corda.core.node.StatesToRecord) + public boolean equals(Object) + @NotNull + public final byte[] getOpaqueData() + @NotNull + public final net.corda.core.node.StatesToRecord getReceiverStatesToRecord() + public int hashCode() + @NotNull + public String toString() +## +@CordaSerializable +public static final class net.corda.core.flows.DistributionList$SenderDistributionList extends net.corda.core.flows.DistributionList public (net.corda.core.node.StatesToRecord, java.util.Map) @NotNull public final net.corda.core.node.StatesToRecord component1() @NotNull public final java.util.Map component2() @NotNull - public final net.corda.core.flows.DistributionList copy(net.corda.core.node.StatesToRecord, java.util.Map) + public final net.corda.core.flows.DistributionList$SenderDistributionList copy(net.corda.core.node.StatesToRecord, java.util.Map) public boolean equals(Object) @NotNull public final java.util.Map getPeersToStatesToRecord() diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt index ad234a6500..8f92995e56 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt @@ -142,15 +142,13 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, val senderRecordingTimestamp = clock.instant() val timeDiscriminator = Key.nextDiscriminatorNumber.andIncrement val distributionList = metadata.distributionList as? SenderDistributionList ?: throw IllegalStateException("Expecting SenderDistributionList") - for (peer in distributionList.peersToStatesToRecord.keys) { + distributionList.peersToStatesToRecord.map { (peerCordaX500Name, peerStatesToRecord) -> val senderDistributionRecord = DBSenderDistributionRecord( - PersistentKey(Key(TimestampKey(senderRecordingTimestamp, timeDiscriminator), partyInfoCache.getPartyIdByCordaX500Name(peer))), + PersistentKey(Key(TimestampKey(senderRecordingTimestamp, timeDiscriminator), partyInfoCache.getPartyIdByCordaX500Name(peerCordaX500Name))), txId.toString(), - distributionList.senderStatesToRecord - ) + peerStatesToRecord) session.save(senderDistributionRecord) } - val hashedPeersToStatesToRecord = distributionList.peersToStatesToRecord.mapKeys { (peer) -> partyInfoCache.getPartyIdByCordaX500Name(peer) } diff --git a/node/src/main/resources/migration/node-core.changelog-v25.xml b/node/src/main/resources/migration/node-core.changelog-v25.xml index a199a65df8..9ea40bada9 100644 --- a/node/src/main/resources/migration/node-core.changelog-v25.xml +++ b/node/src/main/resources/migration/node-core.changelog-v25.xml @@ -52,6 +52,9 @@ + + + From 4a6e99556bbe08061330df897321391ddde07235 Mon Sep 17 00:00:00 2001 From: Jose Coll Date: Fri, 18 Aug 2023 17:22:42 +0100 Subject: [PATCH 06/10] Incorporating PR review feedback. --- .../coretests/flows/FinalityFlowTests.kt | 2 +- .../core/flows/ReceiveTransactionFlow.kt | 3 +- .../core/internal/ServiceHubCoreInternal.kt | 8 +-- .../node/services/api/ServiceHubInternal.kt | 12 ++-- .../persistence/DBTransactionStorage.kt | 5 +- .../DBTransactionStorageLedgerRecovery.kt | 37 +++++++----- .../node/messaging/TwoPartyTradeFlowTests.kt | 6 +- ...DBTransactionStorageLedgerRecoveryTests.kt | 59 ++++++++++++------- .../test/flows/CashIssueWithObserversFlow.kt | 4 +- .../node/internal/MockTransactionStorage.kt | 5 +- .../kotlin/net/corda/testing/dsl/TestDSL.kt | 2 +- 11 files changed, 77 insertions(+), 66 deletions(-) diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt index 9d82164ccc..f0e8ad93b9 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt @@ -358,7 +358,7 @@ class FinalityFlowTests : WithFinality { assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { - assertEquals(StatesToRecord.ALL_VISIBLE, this?.statesToRecord) + assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), this?.peersToStatesToRecord) } diff --git a/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt b/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt index a4f2defa3a..6c1431748d 100644 --- a/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt @@ -86,7 +86,8 @@ open class ReceiveTransactionFlow constructor(private val otherSideSession: Flow open fun resolvePayload(payload: Any): SignedTransaction { return if (payload is SignedTransactionWithDistributionList) { if (checkSufficientSignatures || deferredAck) { - (serviceHub as ServiceHubCoreInternal).recordReceiverTransactionRecoveryMetadata(payload.stx.id, otherSideSession.counterparty.name, ourIdentity.name, statesToRecord, payload.distributionList) + (serviceHub as ServiceHubCoreInternal).recordReceiverTransactionRecoveryMetadata(payload.stx.id, otherSideSession.counterparty.name, + TransactionMetadata(otherSideSession.counterparty.name, DistributionList.ReceiverDistributionList(payload.distributionList, statesToRecord))) payload.stx } else payload.stx } else payload as SignedTransaction diff --git a/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt b/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt index 27f05c9f2f..d752eb3b15 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ServiceHubCoreInternal.kt @@ -76,15 +76,11 @@ interface ServiceHubCoreInternal : ServiceHub { * * @param txnId The SecureHash of a transaction. * @param sender The sender of the transaction. - * @param receiver The receiver of the transaction. - * @param receiverStatesToRecord The StatesToRecord value of the receiver. - * @param encryptedDistributionList encrypted distribution list (hashed peers -> StatesToRecord values) + * @param txnMetadata The recovery metadata associated with a transaction. */ fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) + txnMetadata: TransactionMetadata) } interface TransactionsResolver { diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index 962f7a0664..9ed76b15c8 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -198,8 +198,8 @@ interface ServiceHubInternal : ServiceHubCoreInternal { override fun recordSenderTransactionRecoveryMetadata(txnId: SecureHash, txnMetadata: TransactionMetadata) = validatedTransactions.addSenderTransactionRecoveryMetadata(txnId, txnMetadata) - override fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) = - validatedTransactions.addReceiverTransactionRecoveryMetadata(txnId, sender, receiver, receiverStatesToRecord, encryptedDistributionList) + override fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, txnMetadata: TransactionMetadata) = + validatedTransactions.addReceiverTransactionRecoveryMetadata(txnId, sender, txnMetadata) @Suppress("NestedBlockDepth") @VisibleForTesting @@ -383,15 +383,11 @@ interface WritableTransactionStorage : TransactionStorage { * * @param txId The SecureHash of a transaction. * @param sender The sender of the transaction. - * @param receiver The receiver of the transaction. - * @param receiverStatesToRecord The StatesToRecord value of the receiver. - * @param encryptedDistributionList encrypted distribution list (hashed peers -> StatesToRecord values) + * @param metadata The recovery metadata associated with a transaction. */ fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) + metadata: TransactionMetadata) /** * Removes an un-notarised transaction (with a status of *MISSING_TRANSACTION_SIG*) from the data store. diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index 1973f9e7c1..c43834993c 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -219,9 +219,8 @@ open class DBTransactionStorage(private val database: CordaPersistence, cacheFac override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) { } + metadata: TransactionMetadata + ) { } override fun finalizeTransaction(transaction: SignedTransaction) = addTransaction(transaction) { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt index 8f92995e56..527c0e4e9a 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt @@ -1,11 +1,13 @@ package net.corda.node.services.persistence import net.corda.core.crypto.SecureHash +import net.corda.core.flows.DistributionList.ReceiverDistributionList import net.corda.core.flows.DistributionList.SenderDistributionList import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.TransactionMetadata import net.corda.core.identity.CordaX500Name import net.corda.core.internal.NamedCacheFactory +import net.corda.core.internal.VisibleForTesting import net.corda.core.node.StatesToRecord import net.corda.core.node.services.vault.Sort import net.corda.core.serialization.CordaSerializable @@ -26,7 +28,6 @@ import javax.persistence.Id import javax.persistence.Lob import javax.persistence.Table import javax.persistence.criteria.Predicate -import kotlin.streams.toList class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, cacheFactory: NamedCacheFactory, @@ -98,7 +99,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, distributionList = encryptedDistributionList, receiverStatesToRecord = receiverStatesToRecord ) - + @VisibleForTesting fun toReceiverDistributionRecord(encryptionService: EncryptionService): ReceiverDistributionRecord { val hashedDL = HashedDistributionList.decrypt(this.distributionList, encryptionService) return ReceiverDistributionRecord( @@ -163,18 +164,22 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) { - val publicHeader = HashedDistributionList.PublicHeader.unauthenticatedDeserialise(encryptedDistributionList, encryptionService) - database.transaction { - val receiverDistributionRecord = DBReceiverDistributionRecord( - Key(partyInfoCache.getPartyIdByCordaX500Name(sender), publicHeader.senderRecordedTimestamp), - txId, - encryptedDistributionList, - receiverStatesToRecord - ) - session.save(receiverDistributionRecord) + metadata: TransactionMetadata) { + when (metadata.distributionList) { + is ReceiverDistributionList -> { + val distributionList = metadata.distributionList as ReceiverDistributionList + val publicHeader = HashedDistributionList.PublicHeader.unauthenticatedDeserialise(distributionList.opaqueData, encryptionService) + database.transaction { + val receiverDistributionRecord = DBReceiverDistributionRecord( + Key(partyInfoCache.getPartyIdByCordaX500Name(sender), publicHeader.senderRecordedTimestamp), + txId, + distributionList.opaqueData, + distributionList.receiverStatesToRecord + ) + session.save(receiverDistributionRecord) + } + } + else -> throw IllegalStateException("Expecting ReceiverDistributionList") } } @@ -247,7 +252,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } criteriaQuery.orderBy(orderCriteria) } - session.createQuery(criteriaQuery).stream().toList() + session.createQuery(criteriaQuery).resultList } } @@ -284,7 +289,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, } criteriaQuery.orderBy(orderCriteria) } - session.createQuery(criteriaQuery).stream().toList() + session.createQuery(criteriaQuery).resultList } } } diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index d0f0de3d60..68cfd544d5 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -818,11 +818,9 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) { + metadata: TransactionMetadata) { database.transaction { - delegate.addReceiverTransactionRecoveryMetadata(txId, sender, receiver, receiverStatesToRecord, encryptedDistributionList) + delegate.addReceiverTransactionRecoveryMetadata(txId, sender, metadata) } } diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt index 8f759838cb..3e32fef074 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt @@ -6,6 +6,7 @@ import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SignableData import net.corda.core.crypto.SignatureMetadata import net.corda.core.crypto.sign +import net.corda.core.flows.DistributionList.ReceiverDistributionList import net.corda.core.flows.DistributionList.SenderDistributionList import net.corda.core.flows.RecoveryTimeWindow import net.corda.core.flows.TransactionMetadata @@ -137,11 +138,13 @@ class DBTransactionStorageLedgerRecoveryTests { val transaction2 = newTransaction() // receiver txn transactionRecovery.addUnnotarisedTransaction(transaction2) - transactionRecovery.addReceiverTransactionRecoveryMetadata(transaction2.id, BOB_NAME, ALICE_NAME, ALL_VISIBLE, - SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)).toWire()) + val encryptedDL = transactionRecovery.addSenderTransactionRecoveryMetadata(transaction2.id, + TransactionMetadata(BOB_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(transaction2.id, BOB_NAME, + TransactionMetadata(BOB_NAME, ReceiverDistributionList(encryptedDL, ALL_VISIBLE))) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.SENDER).let { - assertEquals(1, it.size) + assertEquals(2, it.size) assertEquals(BOB_NAME.hashCode().toLong(), it.senderRecords[0].compositeKey.peerPartyId) assertEquals(ALL_VISIBLE, it.senderRecords[0].statesToRecord) } @@ -151,7 +154,7 @@ class DBTransactionStorageLedgerRecoveryTests { assertEquals(ALL_VISIBLE, (HashedDistributionList.decrypt(it.receiverRecords[0].distributionList, encryptionService)).peerHashToStatesToRecord.map { it.value }[0]) } val resultsAll = transactionRecovery.queryDistributionRecords(timeWindow, recordType = DistributionRecordType.ALL) - assertEquals(2, resultsAll.size) + assertEquals(3, resultsAll.size) } @Test(timeout = 300_000) @@ -187,24 +190,34 @@ class DBTransactionStorageLedgerRecoveryTests { fun `query for receiver distribution records by initiator`() { val txn1 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn1) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn1.id, ALICE_NAME, BOB_NAME, ALL_VISIBLE, - SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE, CHARLIE_NAME to ALL_VISIBLE)).toWire()) + val encryptedDL1 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn1.id, + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE, CHARLIE_NAME to ALL_VISIBLE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn1.id, ALICE_NAME, + TransactionMetadata(ALICE_NAME, ReceiverDistributionList(encryptedDL1, ALL_VISIBLE))) val txn2 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn2) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn2.id, ALICE_NAME, BOB_NAME, ONLY_RELEVANT, - SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)).toWire()) + val encryptedDL2 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn2.id, + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn2.id, ALICE_NAME, + TransactionMetadata(ALICE_NAME, ReceiverDistributionList(encryptedDL2, ONLY_RELEVANT))) val txn3 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn3) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn3.id, ALICE_NAME, CHARLIE_NAME, NONE, - SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to NONE)).toWire()) + val encryptedDL3 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn3.id, + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(CHARLIE_NAME to NONE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn3.id, ALICE_NAME, + TransactionMetadata(ALICE_NAME, ReceiverDistributionList(encryptedDL3, NONE))) val txn4 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn4) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn4.id, BOB_NAME, ALICE_NAME, ONLY_RELEVANT, - SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)).toWire()) + val encryptedDL4 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn4.id, + TransactionMetadata(BOB_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(ALICE_NAME to ALL_VISIBLE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn4.id, BOB_NAME, + TransactionMetadata(BOB_NAME, ReceiverDistributionList(encryptedDL4, ALL_VISIBLE))) val txn5 = newTransaction() transactionRecovery.addUnnotarisedTransaction(txn5) - transactionRecovery.addReceiverTransactionRecoveryMetadata(txn5.id, CHARLIE_NAME, BOB_NAME, ONLY_RELEVANT, - SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)).toWire()) + val encryptedDL5 = transactionRecovery.addSenderTransactionRecoveryMetadata(txn5.id, + TransactionMetadata(CHARLIE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ONLY_RELEVANT)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(txn5.id, CHARLIE_NAME, + TransactionMetadata(CHARLIE_NAME, ReceiverDistributionList(encryptedDL5, ONLY_RELEVANT))) val timeWindow = RecoveryTimeWindow(fromTime = now().minus(1, ChronoUnit.DAYS)) transactionRecovery.queryReceiverDistributionRecords(timeWindow, initiators = setOf(ALICE_NAME)).let { @@ -242,8 +255,10 @@ class DBTransactionStorageLedgerRecoveryTests { val receiverTransaction = newTransaction() transactionRecovery.addUnnotarisedTransaction(receiverTransaction) - transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, ALICE_NAME, BOB_NAME, ALL_VISIBLE, - SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE)).toWire()) + val encryptedDL = transactionRecovery.addSenderTransactionRecoveryMetadata(receiverTransaction.id, + TransactionMetadata(ALICE_NAME, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB_NAME to ALL_VISIBLE)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, ALICE_NAME, + TransactionMetadata(ALICE_NAME, ReceiverDistributionList(encryptedDL, ALL_VISIBLE))) assertEquals(IN_FLIGHT, readTransactionFromDB(receiverTransaction.id).status) readReceiverDistributionRecordFromDB(receiverTransaction.id).let { assertEquals(ONLY_RELEVANT, it.statesToRecord) @@ -270,8 +285,10 @@ class DBTransactionStorageLedgerRecoveryTests { fun `remove un-notarised transaction and associated recovery metadata`() { val senderTransaction = newTransaction(notarySig = false) transactionRecovery.addUnnotarisedTransaction(senderTransaction) - transactionRecovery.addReceiverTransactionRecoveryMetadata(senderTransaction.id, ALICE.name, BOB.name, ONLY_RELEVANT, - SenderDistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT, CHARLIE_NAME to ONLY_RELEVANT)).toWire()) + val encryptedDL1 = transactionRecovery.addSenderTransactionRecoveryMetadata(senderTransaction.id, + TransactionMetadata(ALICE.name, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT, CHARLIE_NAME to ONLY_RELEVANT)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(senderTransaction.id, BOB.name, + TransactionMetadata(ALICE.name, ReceiverDistributionList(encryptedDL1, ONLY_RELEVANT))) assertNull(transactionRecovery.getTransaction(senderTransaction.id)) assertEquals(IN_FLIGHT, readTransactionFromDB(senderTransaction.id).status) @@ -282,8 +299,10 @@ class DBTransactionStorageLedgerRecoveryTests { val receiverTransaction = newTransaction(notarySig = false) transactionRecovery.addUnnotarisedTransaction(receiverTransaction) - transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, ALICE.name, BOB.name, ONLY_RELEVANT, - SenderDistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT)).toWire()) + val encryptedDL2 = transactionRecovery.addSenderTransactionRecoveryMetadata(receiverTransaction.id, + TransactionMetadata(ALICE.name, SenderDistributionList(ONLY_RELEVANT, mapOf(BOB.name to ONLY_RELEVANT)))) + transactionRecovery.addReceiverTransactionRecoveryMetadata(receiverTransaction.id, BOB.name, + TransactionMetadata(ALICE.name, ReceiverDistributionList(encryptedDL2, ONLY_RELEVANT))) assertNull(transactionRecovery.getTransaction(receiverTransaction.id)) assertEquals(IN_FLIGHT, readTransactionFromDB(receiverTransaction.id).status) diff --git a/testing/cordapps/cashobservers/src/main/kotlin/net/corda/finance/test/flows/CashIssueWithObserversFlow.kt b/testing/cordapps/cashobservers/src/main/kotlin/net/corda/finance/test/flows/CashIssueWithObserversFlow.kt index c860617078..282b49c3cb 100644 --- a/testing/cordapps/cashobservers/src/main/kotlin/net/corda/finance/test/flows/CashIssueWithObserversFlow.kt +++ b/testing/cordapps/cashobservers/src/main/kotlin/net/corda/finance/test/flows/CashIssueWithObserversFlow.kt @@ -42,9 +42,9 @@ class CashIssueWithObserversFlow(private val amount: Amount, } @Suspendable - private fun finalise(tx: SignedTransaction, sessions: Collection, message: String): SignedTransaction { + private fun finalise(tx: SignedTransaction, observerSessions: Collection, message: String): SignedTransaction { try { - return subFlow(FinalityFlow(tx, sessions)) + return subFlow(FinalityFlow(tx, sessions = emptySet(), observerSessions = observerSessions)) } catch (e: NotaryException) { throw CashException(message, e) } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt index 9f23bf6beb..0a52c09f56 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockTransactionStorage.kt @@ -12,7 +12,6 @@ import net.corda.node.services.api.WritableTransactionStorage import net.corda.core.flows.TransactionMetadata import net.corda.core.flows.TransactionStatus import net.corda.core.identity.CordaX500Name -import net.corda.core.node.StatesToRecord import net.corda.testing.node.MockServices import rx.Observable import rx.subjects.PublishSubject @@ -65,9 +64,7 @@ open class MockTransactionStorage : WritableTransactionStorage, SingletonSeriali override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash, sender: CordaX500Name, - receiver: CordaX500Name, - receiverStatesToRecord: StatesToRecord, - encryptedDistributionList: ByteArray) { } + metadata: TransactionMetadata) { } override fun removeUnnotarisedTransaction(id: SecureHash): Boolean { return txns.remove(id) != null diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/dsl/TestDSL.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/dsl/TestDSL.kt index fd55d3645d..b460d02f30 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/dsl/TestDSL.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/dsl/TestDSL.kt @@ -150,7 +150,7 @@ data class TestTransactionDSLInterpreter private constructor( override fun recordSenderTransactionRecoveryMetadata(txnId: SecureHash, txnMetadata: TransactionMetadata): ByteArray? { return null } - override fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) {} + override fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, txnMetadata: TransactionMetadata) {} } private fun copy(): TestTransactionDSLInterpreter = From 6a7e9000a4fac043f54bc2e2a8195bb4bc16368f Mon Sep 17 00:00:00 2001 From: Jose Coll Date: Fri, 18 Aug 2023 17:26:22 +0100 Subject: [PATCH 07/10] Detekt --- .../net/corda/node/services/persistence/DBTransactionStorage.kt | 1 - .../kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt | 1 - 2 files changed, 2 deletions(-) diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index c43834993c..6905d6f7c1 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -11,7 +11,6 @@ import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.concurrent.doneFuture import net.corda.core.messaging.DataFeed -import net.corda.core.node.StatesToRecord import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 68cfd544d5..b53d9daeee 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -31,7 +31,6 @@ import net.corda.core.internal.concurrent.map import net.corda.core.internal.rootCause import net.corda.core.messaging.DataFeed import net.corda.core.messaging.StateMachineTransactionMapping -import net.corda.core.node.StatesToRecord import net.corda.core.node.services.Vault import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SingletonSerializeAsToken From 3f067d90020d6a49b355534afbac64be49a28cfd Mon Sep 17 00:00:00 2001 From: Jose Coll Date: Mon, 21 Aug 2023 12:28:23 +0100 Subject: [PATCH 08/10] Check sender and receiver timestamps are same. --- .../coretests/flows/FinalityFlowTests.kt | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt index f0e8ad93b9..013fa07d53 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt @@ -352,16 +352,17 @@ class FinalityFlowTests : WithFinality { assertThat(aliceNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull assertThat(bobNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull - getSenderRecoveryData(stx.id, aliceNode.database).apply { + val sdrs = getSenderRecoveryData(stx.id, aliceNode.database).apply { assertEquals(1, this.size) assertEquals(StatesToRecord.ALL_VISIBLE, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } - getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { + val rdr = getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), this?.peersToStatesToRecord) } + validateSenderAndReceiverTimestamps(sdrs, rdr!!) } @Test(timeout=300_000) @@ -382,20 +383,21 @@ class FinalityFlowTests : WithFinality { assertThat(bobNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull assertThat(charlieNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull - getSenderRecoveryData(stx.id, aliceNode.database).apply { + val sdrs = getSenderRecoveryData(stx.id, aliceNode.database).apply { assertEquals(2, this.size) assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) assertEquals(StatesToRecord.ALL_VISIBLE, this[1].statesToRecord) assertEquals(CHARLIE_NAME.hashCode().toLong(), this[1].peerPartyId) } - getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { + val rdr = getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) // note: Charlie assertion here is using the hinted StatesToRecord value passed to it from Alice assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT, CHARLIE_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), this?.peersToStatesToRecord) } + validateSenderAndReceiverTimestamps(sdrs, rdr!!) // exercise the new FinalityFlow observerSessions constructor parameter val stx3 = aliceNode.startFlowAndRunNetwork(CashPaymentWithObserversFlow( @@ -408,9 +410,24 @@ class FinalityFlowTests : WithFinality { assertThat(bobNode.services.validatedTransactions.getTransaction(stx3.id)).isNotNull assertThat(charlieNode.services.validatedTransactions.getTransaction(stx3.id)).isNotNull - assertEquals(2, getSenderRecoveryData(stx3.id, aliceNode.database).size) - assertThat(getReceiverRecoveryData(stx3.id, bobNode, aliceNode)).isNotNull - assertThat(getReceiverRecoveryData(stx3.id, charlieNode, aliceNode)).isNotNull + val senderDistributionRecords = getSenderRecoveryData(stx3.id, aliceNode.database).apply { + assertEquals(2, this.size) + assertEquals(this[0].timestamp, this[1].timestamp) + } + getReceiverRecoveryData(stx3.id, bobNode, aliceNode).apply { + assertThat(getReceiverRecoveryData(stx3.id, bobNode, aliceNode)).isNotNull + assertEquals(senderDistributionRecords[0].timestamp, this!!.timestamp) + } + getReceiverRecoveryData(stx3.id, charlieNode, aliceNode).apply { + assertThat(getReceiverRecoveryData(stx3.id, charlieNode, aliceNode)).isNotNull + assertEquals(senderDistributionRecords[0].timestamp, this!!.timestamp) + } + } + + private fun validateSenderAndReceiverTimestamps(sdrs: List, rdr: ReceiverDistributionRecord) { + sdrs.map { + assertEquals(it.timestamp, rdr.timestamp) + } } @Test(timeout=300_000) @@ -426,16 +443,17 @@ class FinalityFlowTests : WithFinality { assertThat(aliceNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull assertThat(bobNode.services.validatedTransactions.getTransaction(stx.id)).isNotNull - getSenderRecoveryData(stx.id, aliceNode.database).apply { + val sdr = getSenderRecoveryData(stx.id, aliceNode.database).apply { assertEquals(1, this.size) assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } - getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { + val rdr = getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT), this?.peersToStatesToRecord) } + validateSenderAndReceiverTimestamps(sdr, rdr!!) } private fun getSenderRecoveryData(id: SecureHash, database: CordaPersistence): List { From 4fef01a5b0f11770afa86efbf04e3e4a2e945fed Mon Sep 17 00:00:00 2001 From: Jose Coll Date: Tue, 22 Aug 2023 16:03:13 +0100 Subject: [PATCH 09/10] Clean-up. --- .../coretests/flows/FinalityFlowTests.kt | 58 +++++++++---------- .../DBTransactionStorageLedgerRecovery.kt | 3 +- ...DBTransactionStorageLedgerRecoveryTests.kt | 20 ++----- 3 files changed, 33 insertions(+), 48 deletions(-) diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt index 644ea2013a..da7a128e6f 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt @@ -51,8 +51,6 @@ import net.corda.finance.test.flows.CashIssueWithObserversFlow import net.corda.finance.test.flows.CashPaymentWithObserversFlow import net.corda.node.services.persistence.DBTransactionStorage import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery -import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord -import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord import net.corda.node.services.persistence.HashedDistributionList import net.corda.node.services.persistence.ReceiverDistributionRecord import net.corda.node.services.persistence.SenderDistributionRecord @@ -76,8 +74,8 @@ import net.corda.testing.node.internal.enclosedCordapp import net.corda.testing.node.internal.findCordapp import org.assertj.core.api.Assertions.assertThat import org.junit.After +import org.junit.Assert.assertNotNull import org.junit.Test -import org.junit.jupiter.api.assertThrows import java.sql.SQLException import java.util.Random import kotlin.test.assertEquals @@ -358,10 +356,12 @@ class FinalityFlowTests : WithFinality { assertEquals(StatesToRecord.ALL_VISIBLE, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } - val rdr = getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) - assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) - assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), this?.peersToStatesToRecord) + val rdr = getReceiverRecoveryData(stx.id, bobNode).apply { + assertNotNull(this) + val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) + assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) + assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this.initiatorPartyId) + assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE), hashedDL.peerHashToStatesToRecord) } validateSenderAndReceiverTimestamps(sdrs, rdr!!) } @@ -391,14 +391,16 @@ class FinalityFlowTests : WithFinality { assertEquals(StatesToRecord.ALL_VISIBLE, this[1].statesToRecord) assertEquals(CHARLIE_NAME.hashCode().toLong(), this[1].peerPartyId) } - val rdr = getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) - assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) + val rdr = getReceiverRecoveryData(stx.id, bobNode).apply { + assertNotNull(this) + val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) + assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) + assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this.initiatorPartyId) // note: Charlie assertion here is using the hinted StatesToRecord value passed to it from Alice assertEquals(mapOf( BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT, CHARLIE_NAME.hashCode().toLong() to StatesToRecord.ALL_VISIBLE - ), distList.peerHashToStatesToRecord) + ), hashedDL.peerHashToStatesToRecord) } validateSenderAndReceiverTimestamps(sdrs, rdr!!) @@ -417,12 +419,12 @@ class FinalityFlowTests : WithFinality { assertEquals(2, this.size) assertEquals(this[0].timestamp, this[1].timestamp) } - getReceiverRecoveryData(stx3.id, bobNode, aliceNode).apply { - assertThat(getReceiverRecoveryData(stx3.id, bobNode, aliceNode)).isNotNull + getReceiverRecoveryData(stx3.id, bobNode).apply { + assertThat(this).isNotNull assertEquals(senderDistributionRecords[0].timestamp, this!!.timestamp) } - getReceiverRecoveryData(stx3.id, charlieNode, aliceNode).apply { - assertThat(getReceiverRecoveryData(stx3.id, charlieNode, aliceNode)).isNotNull + getReceiverRecoveryData(stx3.id, charlieNode).apply { + assertThat(this).isNotNull assertEquals(senderDistributionRecords[0].timestamp, this!!.timestamp) } } @@ -451,10 +453,12 @@ class FinalityFlowTests : WithFinality { assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord) assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId) } - val rdr = getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply { - assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord) - assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId) - assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT), this?.peersToStatesToRecord) + val rdr = getReceiverRecoveryData(stx.id, bobNode).apply { + assertNotNull(this) + val hashedDL = HashedDistributionList.decrypt(this!!.encryptedDistributionList.bytes, aliceNode.internals.encryptionService) + assertEquals(StatesToRecord.ONLY_RELEVANT, hashedDL.senderStatesToRecord) + assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this.initiatorPartyId) + assertEquals(mapOf(BOB_NAME.hashCode().toLong() to StatesToRecord.ONLY_RELEVANT), hashedDL.peerHashToStatesToRecord) } validateSenderAndReceiverTimestamps(sdr, rdr!!) } @@ -466,24 +470,16 @@ class FinalityFlowTests : WithFinality { DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java ).setParameter("transactionId", id.toString()).resultList } - return fromDb.map { it.toSenderDistributionRecord() }.also { println("SenderDistributionRecord\n$it") } + return fromDb.map { it.toSenderDistributionRecord() } } - private fun getReceiverRecoveryData(txId: SecureHash, receiver: TestStartedNode, sender: TestStartedNode): ReceiverDistributionRecord? { - val fromDb = receiver.database.transaction { + private fun getReceiverRecoveryData(txId: SecureHash, receiver: TestStartedNode): ReceiverDistributionRecord? { + return receiver.database.transaction { session.createQuery( "from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where txId = :transactionId", DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java ).setParameter("transactionId", txId.toString()).resultList - }.singleOrNull() - - // The receiver should not be able to decrypt the distribution list - assertThrows { - fromDb?.toReceiverDistributionRecord(receiver.internals.encryptionService) - } - - // Only the sender can - return fromDb?.toReceiverDistributionRecord(sender.internals.encryptionService) + }.singleOrNull()?.toReceiverDistributionRecord() } @StartableByRPC diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt index db1c176f0a..69df1fe9b3 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecovery.kt @@ -101,8 +101,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, receiverStatesToRecord = receiverStatesToRecord ) @VisibleForTesting - fun toReceiverDistributionRecord(encryptionService: EncryptionService): ReceiverDistributionRecord { - val hashedDL = HashedDistributionList.decrypt(this.distributionList, encryptionService) + fun toReceiverDistributionRecord(): ReceiverDistributionRecord { return ReceiverDistributionRecord( SecureHash.parse(this.txId), this.compositeKey.peerPartyId, diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt index 1ffd27e736..3c81be2ab8 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageLedgerRecoveryTests.kt @@ -331,7 +331,7 @@ class DBTransactionStorageLedgerRecoveryTests { session.createQuery( "from ${DBTransactionStorage.DBTransaction::class.java.name} where txId = :transactionId", DBTransactionStorage.DBTransaction::class.java - ).setParameter("transactionId", id.toString()).resultList + ).setParameter("transactionId", txId.toString()).resultList } assertEquals(1, fromDb.size) return fromDb[0] @@ -355,12 +355,12 @@ class DBTransactionStorageLedgerRecoveryTests { private fun readReceiverDistributionRecordFromDB(txId: SecureHash): ReceiverDistributionRecord { val fromDb = database.transaction { session.createQuery( - "from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where txId = :transactionId", - DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java - ).setParameter("transactionId", id.toString()).resultList + "from ${DBReceiverDistributionRecord::class.java.name} where txId = :transactionId", + DBReceiverDistributionRecord::class.java + ).setParameter("transactionId", txId.toString()).resultList } assertEquals(1, fromDb.size) - return fromDb[0].toReceiverDistributionRecord(encryptionService) + return fromDb[0].toReceiverDistributionRecord() } private fun newTransactionRecovery(cacheSizeBytesOverride: Long? = null, clock: CordaClock = SimpleClock(Clock.systemUTC())) { @@ -415,15 +415,5 @@ class DBTransactionStorageLedgerRecoveryTests { private fun notarySig(txId: SecureHash) = DUMMY_NOTARY.keyPair.sign(SignableData(txId, SignatureMetadata(1, Crypto.findSignatureScheme(DUMMY_NOTARY.publicKey).schemeNumberID))) - - private fun SenderDistributionList.toWire(): ByteArray { - val hashedPeersToStatesToRecord = this.peersToStatesToRecord.mapKeys { (peer) -> partyInfoCache.getPartyIdByCordaX500Name(peer) } - val hashedDistributionList = HashedDistributionList( - this.senderStatesToRecord, - hashedPeersToStatesToRecord, - HashedDistributionList.PublicHeader(now()) - ) - return hashedDistributionList.encrypt(encryptionService) - } } From 1aaff8e6ae363c716270ce877efe164e34541e1e Mon Sep 17 00:00:00 2001 From: Jose Coll Date: Tue, 22 Aug 2023 16:08:00 +0100 Subject: [PATCH 10/10] Clean-up. --- .../net/corda/coretests/flows/FinalityFlowTests.kt | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt index da7a128e6f..8e01c60505 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FinalityFlowTests.kt @@ -50,7 +50,8 @@ import net.corda.finance.issuedBy import net.corda.finance.test.flows.CashIssueWithObserversFlow import net.corda.finance.test.flows.CashPaymentWithObserversFlow import net.corda.node.services.persistence.DBTransactionStorage -import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery +import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord +import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord import net.corda.node.services.persistence.HashedDistributionList import net.corda.node.services.persistence.ReceiverDistributionRecord import net.corda.node.services.persistence.SenderDistributionRecord @@ -466,8 +467,8 @@ class FinalityFlowTests : WithFinality { private fun getSenderRecoveryData(id: SecureHash, database: CordaPersistence): List { val fromDb = database.transaction { session.createQuery( - "from ${DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java.name} where txId = :transactionId", - DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java + "from ${DBSenderDistributionRecord::class.java.name} where txId = :transactionId", + DBSenderDistributionRecord::class.java ).setParameter("transactionId", id.toString()).resultList } return fromDb.map { it.toSenderDistributionRecord() } @@ -476,8 +477,8 @@ class FinalityFlowTests : WithFinality { private fun getReceiverRecoveryData(txId: SecureHash, receiver: TestStartedNode): ReceiverDistributionRecord? { return receiver.database.transaction { session.createQuery( - "from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where txId = :transactionId", - DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java + "from ${DBReceiverDistributionRecord::class.java.name} where txId = :transactionId", + DBReceiverDistributionRecord::class.java ).setParameter("transactionId", txId.toString()).resultList }.singleOrNull()?.toReceiverDistributionRecord() }