ENT-9876: Encrypting the ledger recovery participant distribution list

This commit is contained in:
Shams Asari 2023-07-25 11:58:32 +01:00
parent 6ec8855c6e
commit de67ab7377
19 changed files with 785 additions and 143 deletions

View File

@ -66,7 +66,6 @@ import net.corda.testing.node.internal.FINANCE_WORKFLOWS_CORDAPP
import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.InternalMockNodeParameters
import net.corda.testing.node.internal.MOCK_VERSION_INFO
import net.corda.testing.node.internal.MockCryptoService
import net.corda.testing.node.internal.TestCordappInternal
import net.corda.testing.node.internal.TestStartedNode
import net.corda.testing.node.internal.cordappWithPackages
@ -75,6 +74,7 @@ import net.corda.testing.node.internal.findCordapp
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Test
import org.junit.jupiter.api.assertThrows
import java.sql.SQLException
import java.util.Random
import kotlin.test.assertEquals
@ -239,9 +239,9 @@ class FinalityFlowTests : WithFinality {
private fun assertTxnRemovedFromDatabase(node: TestStartedNode, stxId: SecureHash) {
val fromDb = node.database.transaction {
session.createQuery(
"from ${DBTransactionStorage.DBTransaction::class.java.name} where tx_id = :transactionId",
"from ${DBTransactionStorage.DBTransaction::class.java.name} where txId = :transactionId",
DBTransactionStorage.DBTransaction::class.java
).setParameter("transactionId", stxId.toString()).resultList.map { it }
).setParameter("transactionId", stxId.toString()).resultList
}
assertEquals(0, fromDb.size)
}
@ -357,7 +357,7 @@ class FinalityFlowTests : WithFinality {
assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord)
assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId)
}
getReceiverRecoveryData(stx.id, bobNode.database).apply {
getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply {
assertEquals(StatesToRecord.ALL_VISIBLE, this?.statesToRecord)
assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord)
assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId)
@ -390,7 +390,7 @@ class FinalityFlowTests : WithFinality {
assertEquals(StatesToRecord.ONLY_RELEVANT, this[1].statesToRecord)
assertEquals(CHARLIE_NAME.hashCode().toLong(), this[1].peerPartyId)
}
getReceiverRecoveryData(stx.id, bobNode.database).apply {
getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply {
assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord)
assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord)
assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId)
@ -411,8 +411,8 @@ class FinalityFlowTests : WithFinality {
assertThat(charlieNode.services.validatedTransactions.getTransaction(stx3.id)).isNotNull
assertEquals(2, getSenderRecoveryData(stx3.id, aliceNode.database).size)
assertThat(getReceiverRecoveryData(stx3.id, bobNode.database)).isNotNull
assertThat(getReceiverRecoveryData(stx3.id, charlieNode.database)).isNotNull
assertThat(getReceiverRecoveryData(stx3.id, bobNode, aliceNode)).isNotNull
assertThat(getReceiverRecoveryData(stx3.id, charlieNode, aliceNode)).isNotNull
}
@Test(timeout=300_000)
@ -433,7 +433,7 @@ class FinalityFlowTests : WithFinality {
assertEquals(StatesToRecord.ONLY_RELEVANT, this[0].statesToRecord)
assertEquals(BOB_NAME.hashCode().toLong(), this[0].peerPartyId)
}
getReceiverRecoveryData(stx.id, bobNode.database).apply {
getReceiverRecoveryData(stx.id, bobNode, aliceNode).apply {
assertEquals(StatesToRecord.ONLY_RELEVANT, this?.statesToRecord)
assertEquals(StatesToRecord.ONLY_RELEVANT, this?.senderStatesToRecord)
assertEquals(aliceNode.info.singleIdentity().name.hashCode().toLong(), this?.initiatorPartyId)
@ -444,21 +444,28 @@ class FinalityFlowTests : WithFinality {
private fun getSenderRecoveryData(id: SecureHash, database: CordaPersistence): List<SenderDistributionRecord> {
val fromDb = database.transaction {
session.createQuery(
"from ${DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java.name} where tx_id = :transactionId",
"from ${DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java.name} where txId = :transactionId",
DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java
).setParameter("transactionId", id.toString()).resultList.map { it }
).setParameter("transactionId", id.toString()).resultList
}
return fromDb.map { it.toSenderDistributionRecord() }.also { println("SenderDistributionRecord\n$it") }
}
private fun getReceiverRecoveryData(id: SecureHash, database: CordaPersistence): ReceiverDistributionRecord? {
val fromDb = database.transaction {
private fun getReceiverRecoveryData(txId: SecureHash, receiver: TestStartedNode, sender: TestStartedNode): ReceiverDistributionRecord? {
val fromDb = receiver.database.transaction {
session.createQuery(
"from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where tx_id = :transactionId",
"from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where txId = :transactionId",
DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java
).setParameter("transactionId", id.toString()).resultList.map { it }
).setParameter("transactionId", txId.toString()).resultList
}.singleOrNull()
// The receiver should not be able to decrypt the distribution list
assertThrows<Exception> {
fromDb?.toReceiverDistributionRecord(receiver.internals.encryptionService)
}
return fromDb.singleOrNull()?.toReceiverDistributionRecord(MockCryptoService(emptyMap())).also { println("ReceiverDistributionRecord\n$it") }
// Only the sender can
return fromDb?.toReceiverDistributionRecord(sender.internals.encryptionService)
}
@StartableByRPC

View File

@ -82,7 +82,11 @@ interface ServiceHubCoreInternal : ServiceHub {
* @param receiverStatesToRecord The StatesToRecord value of the receiver.
* @param encryptedDistributionList encrypted distribution list (hashed peers -> StatesToRecord values)
*/
fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray)
fun recordReceiverTransactionRecoveryMetadata(txnId: SecureHash,
sender: CordaX500Name,
receiver: CordaX500Name,
receiverStatesToRecord: StatesToRecord,
encryptedDistributionList: ByteArray)
}
interface TransactionsResolver {

View File

@ -0,0 +1,65 @@
package net.corda.nodeapi.internal.crypto
import net.corda.core.crypto.secureRandomBytes
import java.nio.ByteBuffer
import javax.crypto.Cipher
import javax.crypto.SecretKey
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.SecretKeySpec
object AesEncryption {
const val KEY_SIZE_BYTES = 16
internal const val IV_SIZE_BYTES = 12
private const val TAG_SIZE_BYTES = 16
private const val TAG_SIZE_BITS = TAG_SIZE_BYTES * 8
/**
* Generates a random 128-bit AES key.
*/
fun randomKey(): SecretKey {
return SecretKeySpec(secureRandomBytes(KEY_SIZE_BYTES), "AES")
}
/**
* Encrypt the given [plaintext] with AES using the given [aesKey].
*
* An optional public [additionalData] bytes can also be provided which will be authenticated alongside the ciphertext but not encrypted.
* This may be metadata for example. The same authenticated data bytes must be provided to [decrypt] to be able to decrypt the
* ciphertext. Typically these bytes are serialised alongside the ciphertext. Since it's authenticated in the ciphertext, it cannot be
* modified undetected.
*/
fun encrypt(aesKey: SecretKey, plaintext: ByteArray, additionalData: ByteArray? = null): ByteArray {
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
val iv = secureRandomBytes(IV_SIZE_BYTES) // Never use the same IV with the same key!
cipher.init(Cipher.ENCRYPT_MODE, aesKey, GCMParameterSpec(TAG_SIZE_BITS, iv))
val buffer = ByteBuffer.allocate(IV_SIZE_BYTES + plaintext.size + TAG_SIZE_BYTES)
buffer.put(iv)
if (additionalData != null) {
cipher.updateAAD(additionalData)
}
cipher.doFinal(ByteBuffer.wrap(plaintext), buffer)
return buffer.array()
}
fun encrypt(aesKey: ByteArray, plaintext: ByteArray, additionalData: ByteArray? = null): ByteArray {
return encrypt(SecretKeySpec(aesKey, "AES"), plaintext, additionalData)
}
/**
* Decrypt ciphertext that was encrypted with the same key using [encrypt].
*
* If additional data was used for the encryption then it must also be provided. If doesn't match then the decryption will fail.
*/
fun decrypt(aesKey: SecretKey, ciphertext: ByteArray, additionalData: ByteArray? = null): ByteArray {
val cipher = Cipher.getInstance("AES/GCM/NoPadding")
cipher.init(Cipher.DECRYPT_MODE, aesKey, GCMParameterSpec(TAG_SIZE_BITS, ciphertext, 0, IV_SIZE_BYTES))
if (additionalData != null) {
cipher.updateAAD(additionalData)
}
return cipher.doFinal(ciphertext, IV_SIZE_BYTES, ciphertext.size - IV_SIZE_BYTES)
}
fun decrypt(aesKey: ByteArray, ciphertext: ByteArray, additionalData: ByteArray? = null): ByteArray {
return decrypt(SecretKeySpec(aesKey, "AES"), ciphertext, additionalData)
}
}

View File

@ -0,0 +1,73 @@
package net.corda.nodeapi.internal.crypto
import net.corda.core.crypto.secureRandomBytes
import net.corda.nodeapi.internal.crypto.AesEncryption.IV_SIZE_BYTES
import net.corda.nodeapi.internal.crypto.AesEncryption.KEY_SIZE_BYTES
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.Test
import java.security.GeneralSecurityException
class AesEncryptionTest {
private val aesKey = secureRandomBytes(KEY_SIZE_BYTES)
private val plaintext = secureRandomBytes(257) // Intentionally not a power of 2
@Test(timeout = 300_000)
fun `ciphertext can be decrypted using the same key`() {
val ciphertext = AesEncryption.encrypt(aesKey, plaintext)
assertThat(String(ciphertext)).doesNotContain(String(plaintext))
val decrypted = AesEncryption.decrypt(aesKey, ciphertext)
assertThat(decrypted).isEqualTo(plaintext)
}
@Test(timeout = 300_000)
fun `ciphertext with authenticated data can be decrypted using the same key`() {
val ciphertext = AesEncryption.encrypt(aesKey, plaintext, "Extra public data".toByteArray())
assertThat(String(ciphertext)).doesNotContain(String(plaintext))
val decrypted = AesEncryption.decrypt(aesKey, ciphertext, "Extra public data".toByteArray())
assertThat(decrypted).isEqualTo(plaintext)
}
@Test(timeout = 300_000)
fun `ciphertext cannot be decrypted with different authenticated data`() {
val ciphertext = AesEncryption.encrypt(aesKey, plaintext, "Extra public data".toByteArray())
assertThat(String(ciphertext)).doesNotContain(String(plaintext))
assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy {
AesEncryption.decrypt(aesKey, ciphertext, "Different public data".toByteArray())
}
}
@Test(timeout = 300_000)
fun `ciphertext cannot be decrypted with different key`() {
val ciphertext = AesEncryption.encrypt(aesKey, plaintext)
for (index in aesKey.indices) {
aesKey[index]--
assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy {
AesEncryption.decrypt(aesKey, ciphertext)
}
aesKey[index]++
}
}
@Test(timeout = 300_000)
fun `corrupted ciphertext cannot be decrypted`() {
val ciphertext = AesEncryption.encrypt(aesKey, plaintext)
for (index in ciphertext.indices) {
ciphertext[index]--
assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy {
AesEncryption.decrypt(aesKey, ciphertext)
}
ciphertext[index]++
}
}
@Test(timeout = 300_000)
fun `encrypting same plainttext twice with same key does not produce same ciphertext`() {
val first = AesEncryption.encrypt(aesKey, plaintext)
val second = AesEncryption.encrypt(aesKey, plaintext)
// The IV should be different
assertThat(first.take(IV_SIZE_BYTES)).isNotEqualTo(second.take(IV_SIZE_BYTES))
// Which should cause the encrypted bytes to be different as well
assertThat(first.drop(IV_SIZE_BYTES)).isNotEqualTo(second.drop(IV_SIZE_BYTES))
}
}

View File

@ -128,6 +128,7 @@ import net.corda.node.services.persistence.AbstractPartyToX500NameAsStringConver
import net.corda.node.services.persistence.AttachmentStorageInternal
import net.corda.node.services.persistence.DBCheckpointPerformanceRecorder
import net.corda.node.services.persistence.DBCheckpointStorage
import net.corda.node.services.persistence.AesDbEncryptionService
import net.corda.node.services.persistence.DBTransactionMappingStorage
import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery
import net.corda.node.services.persistence.NodeAttachmentService
@ -286,6 +287,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
val networkMapCache = PersistentNetworkMapCache(cacheFactory, database, identityService).tokenize()
val partyInfoCache = PersistentPartyInfoCache(networkMapCache, cacheFactory, database)
val encryptionService = AesDbEncryptionService(database)
@Suppress("LeakingThis")
val cryptoService = makeCryptoService()
@Suppress("LeakingThis")
@ -658,6 +660,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
verifyCheckpointsCompatible(frozenTokenizableServices)
partyInfoCache.start()
encryptionService.start(nodeInfo.legalIdentities[0])
/* Note the .get() at the end of the distributeEvent call, below.
This will block until all Corda Services have returned from processing the event, allowing a service to prevent the
@ -1080,7 +1083,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
}
protected open fun makeTransactionStorage(transactionCacheSizeBytes: Long): WritableTransactionStorage {
return DBTransactionStorageLedgerRecovery(database, cacheFactory, platformClock, cryptoService, partyInfoCache)
return DBTransactionStorageLedgerRecovery(database, cacheFactory, platformClock, encryptionService, partyInfoCache)
}
protected open fun makeNetworkParametersStorage(): NetworkParametersStorage {

View File

@ -0,0 +1,42 @@
package net.corda.node.services
/**
* A service for encrypting data. This abstraction does not mandate any security properties except the same service instance will be
* able to decrypt ciphertext encrypted by it. Further security properties are defined by the implementations. This includes the encryption
* protocol used.
*/
interface EncryptionService {
/**
* Encrypt the given [plaintext]. The encryption key used is dependent on the implementation. The returned ciphertext can be decrypted
* using [decrypt].
*
* An optional public [additionalData] bytes can also be provided which will be authenticated (thus tamperproof) alongside the
* ciphertext but not encrypted. It will be incorporated into the returned bytes in an implementation dependent fashion.
*/
fun encrypt(plaintext: ByteArray, additionalData: ByteArray? = null): ByteArray
/**
* Decrypt ciphertext that was encrypted using [encrypt] and return the original plaintext plus the additional data authenticated (if
* present). The service will select the correct encryption key to use.
*/
fun decrypt(ciphertext: ByteArray): PlaintextAndAAD
/**
* Extracts the (unauthenticated) additional data, if present, from the given [ciphertext]. This is the public data that would have been
* given at encryption time.
*
* Note, this method does not verify if the data was tampered with, and hence is unauthenticated. To have it authenticated requires
* calling [decrypt]. This is still useful however, as it doesn't require the encryption key, and so a third-party can view the
* additional data without needing access to the key.
*/
fun extractUnauthenticatedAdditionalData(ciphertext: ByteArray): ByteArray?
/**
* Represents the decrypted plaintext and the optional authenticated additional data bytes.
*/
class PlaintextAndAAD(val plaintext: ByteArray, val authenticatedAdditionalData: ByteArray?) {
operator fun component1() = plaintext
operator fun component2() = authenticatedAdditionalData
}
}

View File

@ -372,22 +372,26 @@ interface WritableTransactionStorage : TransactionStorage {
/**
* Records Sender [TransactionMetadata] for a given txnId.
*
* @param id The SecureHash of a transaction.
* @param txId The SecureHash of a transaction.
* @param metadata The recovery metadata associated with a transaction.
* @return encrypted distribution list (hashed peers -> StatesToRecord values).
*/
fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray?
fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray?
/**
* Records Received [TransactionMetadata] for a given txnId.
*
* @param id The SecureHash of a transaction.
* @param txId The SecureHash of a transaction.
* @param sender The sender of the transaction.
* @param receiver The receiver of the transaction.
* @param receiverStatesToRecord The StatesToRecord value of the receiver.
* @param encryptedDistributionList encrypted distribution list (hashed peers -> StatesToRecord values)
*/
fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray)
fun addReceiverTransactionRecoveryMetadata(txId: SecureHash,
sender: CordaX500Name,
receiver: CordaX500Name,
receiverStatesToRecord: StatesToRecord,
encryptedDistributionList: ByteArray)
/**
* Removes an un-notarised transaction (with a status of *MISSING_TRANSACTION_SIG*) from the data store.

View File

@ -0,0 +1,152 @@
package net.corda.node.services.persistence
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.Party
import net.corda.core.internal.copyBytes
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.node.services.EncryptionService
import net.corda.nodeapi.internal.crypto.AesEncryption
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import org.hibernate.annotations.Type
import java.nio.ByteBuffer
import java.security.Key
import java.security.MessageDigest
import java.util.UUID
import javax.crypto.Cipher
import javax.crypto.SecretKey
import javax.crypto.spec.SecretKeySpec
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
import javax.persistence.Table
/**
* [EncryptionService] which uses AES keys stored in the node database. A random key is chosen for encryption, and the resultant ciphertext
* encodes the key used so that it can be decrypted without needing further information.
*
* **Storing encryption keys in a database is not secure, and so only use this service if the data being encrypted is also stored
* unencrypted in the same database.**
*
* To obfuscate the keys, they are stored wrapped using another AES key (called the wrapping key or key-encryption-key) derived from the
* node's legal identity. This is not a security measure; it's only meant to reduce the impact of accidental leakage.
*/
// TODO Add support for key expiry
class AesDbEncryptionService(private val database: CordaPersistence) : EncryptionService, SingletonSerializeAsToken() {
companion object {
private const val INITIAL_KEY_COUNT = 10
private const val UUID_BYTES = 16
}
private val aesKeys = ArrayList<Pair<UUID, SecretKey>>()
fun start(ourIdentity: Party) {
database.transaction {
val criteria = session.criteriaBuilder.createQuery(EncryptionKeyRecord::class.java)
criteria.select(criteria.from(EncryptionKeyRecord::class.java))
val dbKeyRecords = session.createQuery(criteria).resultList
val keyWrapper = Cipher.getInstance("AESWrap")
if (dbKeyRecords.isEmpty()) {
repeat(INITIAL_KEY_COUNT) {
val keyId = UUID.randomUUID()
val aesKey = AesEncryption.randomKey()
aesKeys += Pair(keyId, aesKey)
val wrappedKey = with(keyWrapper) {
init(Cipher.WRAP_MODE, createKEK(ourIdentity, keyId))
wrap(aesKey)
}
session.save(EncryptionKeyRecord(keyId = keyId, keyMaterial = wrappedKey))
}
} else {
for (dbKeyRecord in dbKeyRecords) {
val aesKey = with(keyWrapper) {
init(Cipher.UNWRAP_MODE, createKEK(ourIdentity, dbKeyRecord.keyId))
unwrap(dbKeyRecord.keyMaterial, "AES", Cipher.SECRET_KEY) as SecretKey
}
aesKeys += Pair(dbKeyRecord.keyId, aesKey)
}
}
}
}
override fun encrypt(plaintext: ByteArray, additionalData: ByteArray?): ByteArray {
val (keyId, aesKey) = aesKeys[newSecureRandom().nextInt(aesKeys.size)]
val ciphertext = AesEncryption.encrypt(aesKey, plaintext, additionalData)
val buffer = ByteBuffer.allocate(1 + UUID_BYTES + Integer.BYTES + (additionalData?.size ?: 0) + ciphertext.size)
buffer.put(1) // Version tag
// Prepend the key ID to the returned ciphertext. It's OK that this is not included in the authenticated additional data because
// changing this value will lead to either an non-existent key or an another key which will not be able decrypt the ciphertext.
buffer.putUUID(keyId)
if (additionalData != null) {
buffer.putInt(additionalData.size)
buffer.put(additionalData)
} else {
buffer.putInt(0)
}
buffer.put(ciphertext)
return buffer.array()
}
override fun decrypt(ciphertext: ByteArray): EncryptionService.PlaintextAndAAD {
val buffer = ByteBuffer.wrap(ciphertext)
val version = buffer.get().toInt()
require(version == 1)
val keyId = buffer.getUUID()
val aesKey = requireNotNull(aesKeys.find { it.first == keyId }?.second) { "Unable to decrypt" }
val additionalData = buffer.getAdditionaData()
val plaintext = AesEncryption.decrypt(aesKey, buffer.copyBytes(), additionalData)
// Only now is the additional data authenticated
return EncryptionService.PlaintextAndAAD(plaintext, additionalData)
}
override fun extractUnauthenticatedAdditionalData(ciphertext: ByteArray): ByteArray? {
val buffer = ByteBuffer.wrap(ciphertext)
buffer.position(1 + UUID_BYTES)
return buffer.getAdditionaData()
}
private fun ByteBuffer.getAdditionaData(): ByteArray? {
val additionalDataSize = getInt()
return if (additionalDataSize > 0) ByteArray(additionalDataSize).also { get(it) } else null
}
private fun UUID.toByteArray(): ByteArray {
val buffer = ByteBuffer.allocate(UUID_BYTES)
buffer.putUUID(this)
return buffer.array()
}
/**
* Derive the key-encryption-key (KEK) from the the node's identity and the persisted key's ID.
*/
private fun createKEK(ourIdentity: Party, keyId: UUID): Key {
val digest = MessageDigest.getInstance("SHA-256")
digest.update(ourIdentity.name.x500Principal.encoded)
digest.update(keyId.toByteArray())
return SecretKeySpec(digest.digest(), 0, AesEncryption.KEY_SIZE_BYTES, "AES")
}
@Entity
@Table(name = "${NODE_DATABASE_PREFIX}aes_encryption_keys")
class EncryptionKeyRecord(
@Id
@Type(type = "uuid-char")
@Column(name = "key_id", nullable = false)
val keyId: UUID,
@Column(name = "key_material", nullable = false)
val keyMaterial: ByteArray
)
}
internal fun ByteBuffer.putUUID(uuid: UUID) {
putLong(uuid.mostSignificantBits)
putLong(uuid.leastSignificantBits)
}
internal fun ByteBuffer.getUUID(): UUID {
val mostSigBits = getLong()
val leastSigBits = getLong()
return UUID(mostSigBits, leastSigBits)
}

View File

@ -215,9 +215,13 @@ open class DBTransactionStorage(private val database: CordaPersistence, cacheFac
false
}
override fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray? { return null }
override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray? { return null }
override fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) { }
override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash,
sender: CordaX500Name,
receiver: CordaX500Name,
receiverStatesToRecord: StatesToRecord,
encryptedDistributionList: ByteArray) { }
override fun finalizeTransaction(transaction: SignedTransaction) =
addTransaction(transaction) {

View File

@ -9,15 +9,11 @@ import net.corda.core.node.StatesToRecord
import net.corda.core.node.services.vault.Sort
import net.corda.core.serialization.CordaSerializable
import net.corda.node.CordaClock
import net.corda.node.services.EncryptionService
import net.corda.node.services.network.PersistentPartyInfoCache
import net.corda.nodeapi.internal.cryptoservice.CryptoService
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import org.hibernate.annotations.Immutable
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.DataInputStream
import java.io.DataOutputStream
import java.io.Serializable
import java.time.Instant
import java.util.concurrent.atomic.AtomicLong
@ -31,9 +27,10 @@ import javax.persistence.Table
import javax.persistence.criteria.Predicate
import kotlin.streams.toList
class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence, cacheFactory: NamedCacheFactory,
class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
cacheFactory: NamedCacheFactory,
val clock: CordaClock,
private val cryptoService: CryptoService,
private val encryptionService: EncryptionService,
private val partyInfoCache: PersistentPartyInfoCache) : DBTransactionStorage(database, cacheFactory, clock) {
@Embeddable
@Immutable
@ -63,7 +60,6 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
/** states to record: NONE, ALL_VISIBLE, ONLY_RELEVANT */
@Column(name = "states_to_record", nullable = false)
var statesToRecord: StatesToRecord
) {
fun toSenderDistributionRecord() =
SenderDistributionRecord(
@ -76,7 +72,7 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
@Entity
@Table(name = "${NODE_DATABASE_PREFIX}receiver_distribution_records")
data class DBReceiverDistributionRecord(
class DBReceiverDistributionRecord(
@EmbeddedId
var compositeKey: PersistentKey,
@ -95,17 +91,18 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
/** states to record: NONE, ALL_VISIBLE, ONLY_RELEVANT */
@Column(name = "receiver_states_to_record", nullable = false)
val receiverStatesToRecord: StatesToRecord
) {
) {
constructor(key: Key, txId: SecureHash, initiatorPartyId: Long, encryptedDistributionList: ByteArray, receiverStatesToRecord: StatesToRecord) :
this(PersistentKey(key),
txId = txId.toString(),
senderPartyId = initiatorPartyId,
distributionList = encryptedDistributionList,
receiverStatesToRecord = receiverStatesToRecord
)
this(
PersistentKey(key),
txId = txId.toString(),
senderPartyId = initiatorPartyId,
distributionList = encryptedDistributionList,
receiverStatesToRecord = receiverStatesToRecord
)
fun toReceiverDistributionRecord(cryptoService: CryptoService): ReceiverDistributionRecord {
val hashedDL = HashedDistributionList.deserialize(cryptoService.decrypt(this.distributionList))
fun toReceiverDistributionRecord(encryptionService: EncryptionService): ReceiverDistributionRecord {
val hashedDL = HashedDistributionList.decrypt(this.distributionList, encryptionService)
return ReceiverDistributionRecord(
SecureHash.parse(this.txId),
this.senderPartyId,
@ -139,32 +136,45 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
}
}
override fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray {
override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray {
return database.transaction {
val senderRecordingTimestamp = clock.instant()
metadata.distributionList.peersToStatesToRecord.forEach { (peer, _) ->
val senderDistributionRecord = DBSenderDistributionRecord(PersistentKey(Key(senderRecordingTimestamp)),
id.toString(),
for (peer in metadata.distributionList.peersToStatesToRecord.keys) {
val senderDistributionRecord = DBSenderDistributionRecord(
PersistentKey(Key(senderRecordingTimestamp)),
txId.toString(),
partyInfoCache.getPartyIdByCordaX500Name(peer),
metadata.distributionList.senderStatesToRecord)
metadata.distributionList.senderStatesToRecord
)
session.save(senderDistributionRecord)
}
val hashedPeersToStatesToRecord = metadata.distributionList.peersToStatesToRecord.map { (peer, statesToRecord) ->
partyInfoCache.getPartyIdByCordaX500Name(peer) to statesToRecord }.toMap()
val hashedDistributionList = HashedDistributionList(metadata.distributionList.senderStatesToRecord, hashedPeersToStatesToRecord, senderRecordingTimestamp)
cryptoService.encrypt(hashedDistributionList.serialize())
val hashedPeersToStatesToRecord = metadata.distributionList.peersToStatesToRecord.mapKeys { (peer) ->
partyInfoCache.getPartyIdByCordaX500Name(peer)
}
val hashedDistributionList = HashedDistributionList(
metadata.distributionList.senderStatesToRecord,
hashedPeersToStatesToRecord,
HashedDistributionList.PublicHeader(senderRecordingTimestamp)
)
hashedDistributionList.encrypt(encryptionService)
}
}
override fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) {
val senderRecordedTimestamp = HashedDistributionList.deserialize(cryptoService.decrypt(encryptedDistributionList)).senderRecordedTimestamp
override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash,
sender: CordaX500Name,
receiver: CordaX500Name,
receiverStatesToRecord: StatesToRecord,
encryptedDistributionList: ByteArray) {
val publicHeader = HashedDistributionList.PublicHeader.unauthenticatedDeserialise(encryptedDistributionList, encryptionService)
database.transaction {
val receiverDistributionRecord =
DBReceiverDistributionRecord(Key(senderRecordedTimestamp),
id,
partyInfoCache.getPartyIdByCordaX500Name(sender),
encryptedDistributionList,
receiverStatesToRecord)
val receiverDistributionRecord = DBReceiverDistributionRecord(
Key(publicHeader.senderRecordedTimestamp),
txId,
partyInfoCache.getPartyIdByCordaX500Name(sender),
encryptedDistributionList,
receiverStatesToRecord
)
session.save(receiverDistributionRecord)
}
}
@ -235,8 +245,9 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
}
criteriaQuery.orderBy(orderCriteria)
}
val results = session.createQuery(criteriaQuery).stream()
results.map { it.toSenderDistributionRecord() }.toList()
session.createQuery(criteriaQuery).stream().use { results ->
results.map { it.toSenderDistributionRecord() }.toList()
}
}
}
@ -273,21 +284,13 @@ class DBTransactionStorageLedgerRecovery(private val database: CordaPersistence,
}
criteriaQuery.orderBy(orderCriteria)
}
val results = session.createQuery(criteriaQuery).stream()
results.map { it.toReceiverDistributionRecord(cryptoService) }.toList()
session.createQuery(criteriaQuery).stream().use { results ->
results.map { it.toReceiverDistributionRecord(encryptionService) }.toList()
}
}
}
}
// TO DO: https://r3-cev.atlassian.net/browse/ENT-9876
private fun CryptoService.decrypt(bytes: ByteArray): ByteArray {
return bytes
}
// TO DO: https://r3-cev.atlassian.net/browse/ENT-9876
fun CryptoService.encrypt(bytes: ByteArray): ByteArray {
return bytes
}
@CordaSerializable
open class DistributionRecord(
@ -318,46 +321,3 @@ data class ReceiverDistributionRecord(
enum class DistributionRecordType {
SENDER, RECEIVER, ALL
}
@CordaSerializable
data class HashedDistributionList(
val senderStatesToRecord: StatesToRecord,
val peerHashToStatesToRecord: Map<Long, StatesToRecord>,
val senderRecordedTimestamp: Instant
) {
fun serialize(): ByteArray {
val baos = ByteArrayOutputStream()
val out = DataOutputStream(baos)
out.use {
out.writeByte(SERIALIZER_VERSION_ID)
out.writeByte(senderStatesToRecord.ordinal)
out.writeInt(peerHashToStatesToRecord.size)
for(entry in peerHashToStatesToRecord) {
out.writeLong(entry.key)
out.writeByte(entry.value.ordinal)
}
out.writeLong(senderRecordedTimestamp.toEpochMilli())
out.flush()
return baos.toByteArray()
}
}
companion object {
const val SERIALIZER_VERSION_ID = 1
fun deserialize(bytes: ByteArray): HashedDistributionList {
val input = DataInputStream(ByteArrayInputStream(bytes))
input.use {
assert(input.readByte().toInt() == SERIALIZER_VERSION_ID) { "Serialization version conflict." }
val senderStatesToRecord = StatesToRecord.values()[input.readByte().toInt()]
val numPeerHashToStatesToRecords = input.readInt()
val peerHashToStatesToRecord = mutableMapOf<Long, StatesToRecord>()
repeat (numPeerHashToStatesToRecords) {
peerHashToStatesToRecord[input.readLong()] = StatesToRecord.values()[input.readByte().toInt()]
}
val senderRecordedTimestamp = Instant.ofEpochMilli(input.readLong())
return HashedDistributionList(senderStatesToRecord, peerHashToStatesToRecord, senderRecordedTimestamp)
}
}
}
}

View File

@ -0,0 +1,104 @@
package net.corda.node.services.persistence
import net.corda.core.node.StatesToRecord
import net.corda.core.serialization.CordaSerializable
import net.corda.node.services.EncryptionService
import java.io.ByteArrayOutputStream
import java.io.DataInputStream
import java.io.DataOutputStream
import java.nio.ByteBuffer
import java.time.Instant
@Suppress("TooGenericExceptionCaught")
@CordaSerializable
data class HashedDistributionList(
val senderStatesToRecord: StatesToRecord,
val peerHashToStatesToRecord: Map<Long, StatesToRecord>,
val publicHeader: PublicHeader
) {
/**
* Encrypt this hashed distribution list using the given [EncryptionService]. The [publicHeader] is not encrypted but is instead
* authenticated so that it is tamperproof.
*
* The same [EncryptionService] instance needs to be used with [decrypt] for decryption.
*/
fun encrypt(encryptionService: EncryptionService): ByteArray {
val baos = ByteArrayOutputStream()
val out = DataOutputStream(baos)
out.writeByte(senderStatesToRecord.ordinal)
out.writeInt(peerHashToStatesToRecord.size)
for (entry in peerHashToStatesToRecord) {
out.writeLong(entry.key)
out.writeByte(entry.value.ordinal)
}
return encryptionService.encrypt(baos.toByteArray(), publicHeader.serialise())
}
@CordaSerializable
data class PublicHeader(
val senderRecordedTimestamp: Instant
) {
fun serialise(): ByteArray {
val buffer = ByteBuffer.allocate(1 + java.lang.Long.BYTES)
buffer.put(VERSION_TAG.toByte())
buffer.putLong(senderRecordedTimestamp.toEpochMilli())
return buffer.array()
}
companion object {
/**
* Deserialise a [PublicHeader] from the given [encryptedBytes]. The bytes is expected is to be a valid encrypted blob that can
* be decrypted by [HashedDistributionList.decrypt] using the same [EncryptionService].
*
* Because this method does not actually decrypt the bytes, the header returned is not authenticated and any modifications to it
* will not be detected. That can only be done by the encrypting party with [HashedDistributionList.decrypt].
*/
fun unauthenticatedDeserialise(encryptedBytes: ByteArray, encryptionService: EncryptionService): PublicHeader {
val additionalData = encryptionService.extractUnauthenticatedAdditionalData(encryptedBytes)
requireNotNull(additionalData) { "Missing additional data field" }
return deserialise(additionalData!!)
}
fun deserialise(bytes: ByteArray): PublicHeader {
val buffer = ByteBuffer.wrap(bytes)
try {
val version = buffer.get().toInt()
require(version == VERSION_TAG) { "Unknown distribution list format $version" }
val senderRecordedTimestamp = Instant.ofEpochMilli(buffer.getLong())
return PublicHeader(senderRecordedTimestamp)
} catch (e: Exception) {
throw IllegalArgumentException("Corrupt or not a distribution list header", e)
}
}
}
}
companion object {
// The version tag is serialised in the header, even though it is separate from the encrypted main body of the distribution list.
// This is because the header and the dist list are cryptographically coupled and we want to avoid declaring the version field twice.
private const val VERSION_TAG = 1
private val statesToRecordValues = StatesToRecord.values() // Cache the enum values since .values() returns a new array each time.
/**
* Decrypt a [HashedDistributionList] from the given [encryptedBytes] using the same [EncryptionService] that was used in [encrypt].
*/
fun decrypt(encryptedBytes: ByteArray, encryptionService: EncryptionService): HashedDistributionList {
val (plaintext, authenticatedAdditionalData) = encryptionService.decrypt(encryptedBytes)
requireNotNull(authenticatedAdditionalData) { "Missing authenticated header" }
val publicHeader = PublicHeader.deserialise(authenticatedAdditionalData!!)
val input = DataInputStream(plaintext.inputStream())
try {
val senderStatesToRecord = statesToRecordValues[input.readByte().toInt()]
val numPeerHashToStatesToRecords = input.readInt()
val peerHashToStatesToRecord = mutableMapOf<Long, StatesToRecord>()
repeat(numPeerHashToStatesToRecords) {
peerHashToStatesToRecord[input.readLong()] = statesToRecordValues[input.readByte().toInt()]
}
return HashedDistributionList(senderStatesToRecord, peerHashToStatesToRecord, publicHeader)
} catch (e: Exception) {
throw IllegalArgumentException("Corrupt or not a distribution list", e)
}
}
}
}

View File

@ -16,6 +16,7 @@ import net.corda.node.services.keys.BasicHSMKeyManagementService
import net.corda.node.services.messaging.P2PMessageDeduplicator
import net.corda.node.services.network.PersistentNetworkMapCache
import net.corda.node.services.persistence.DBCheckpointStorage
import net.corda.node.services.persistence.AesDbEncryptionService
import net.corda.node.services.persistence.DBTransactionStorageLedgerRecovery
import net.corda.node.services.persistence.DBTransactionStorage
import net.corda.node.services.persistence.NodeAttachmentService
@ -30,7 +31,7 @@ import net.corda.node.services.vault.VaultSchemaV1
* TODO: support plugins for schema version upgrading or custom mapping not supported by original [QueryableState].
* TODO: create whitelisted tables when a CorDapp is first installed
*/
class NodeSchemaService(private val extraSchemas: Set<MappedSchema> = emptySet()) : SchemaService, SingletonSerializeAsToken() {
class NodeSchemaService(extraSchemas: Set<MappedSchema> = emptySet()) : SchemaService, SingletonSerializeAsToken() {
// Core Entities used by a Node
object NodeCore
@ -55,7 +56,8 @@ class NodeSchemaService(private val extraSchemas: Set<MappedSchema> = emptySet()
PersistentNetworkMapCache.PersistentPartyToPublicKeyHash::class.java,
DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java,
DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java,
DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java
DBTransactionStorageLedgerRecovery.DBRecoveryPartyInfo::class.java,
AesDbEncryptionService.EncryptionKeyRecord::class.java
)) {
override val migrationResource = "node-core.changelog-master"
}

View File

@ -31,6 +31,7 @@
<include file="migration/node-core.changelog-v23.xml"/>
<include file="migration/node-core.changelog-v24.xml"/>
<include file="migration/node-core.changelog-v25.xml"/>
<include file="migration/node-core.changelog-v26.xml"/>
<!-- This must run after node-core.changelog-init.xml, to prevent database columns being created twice. -->
<include file="migration/vault-schema.changelog-v9.xml"/>

View File

@ -0,0 +1,28 @@
<?xml version="1.1" encoding="UTF-8" standalone="no"?>
<databaseChangeLog xmlns="http://www.liquibase.org/xml/ns/dbchangelog"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.5.xsd"
logicalFilePath="migration/node-services.changelog-init.xml">
<changeSet author="R3.Corda" id="create_aes_encryption_keys_table">
<createTable tableName="node_aes_encryption_keys">
<column name="key_id" type="VARCHAR(36)">
<constraints nullable="false"/>
</column>
<column name="key_material" type="VARBINARY(512)">
<constraints nullable="false"/>
</column>
</createTable>
</changeSet>
<changeSet author="R3.Corda" id="node_aes_encryption_keys_pkey">
<addPrimaryKey constraintName="node_aes_encryption_keys_pkey" tableName="node_aes_encryption_keys" columnNames="key_id"/>
</changeSet>
<changeSet author="R3.Corda" id="node_aes_encryption_keys_idx">
<createIndex indexName="node_aes_encryption_keys_idx" tableName="node_aes_encryption_keys">
<column name="key_id"/>
</createIndex>
</changeSet>
</databaseChangeLog>

View File

@ -810,15 +810,19 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
return true
}
override fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray? {
override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray? {
return database.transaction {
delegate.addSenderTransactionRecoveryMetadata(id, metadata)
delegate.addSenderTransactionRecoveryMetadata(txId, metadata)
}
}
override fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) {
override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash,
sender: CordaX500Name,
receiver: CordaX500Name,
receiverStatesToRecord: StatesToRecord,
encryptedDistributionList: ByteArray) {
database.transaction {
delegate.addReceiverTransactionRecoveryMetadata(id, sender, receiver, receiverStatesToRecord, encryptedDistributionList)
delegate.addReceiverTransactionRecoveryMetadata(txId, sender, receiver, receiverStatesToRecord, encryptedDistributionList)
}
}

View File

@ -0,0 +1,134 @@
package net.corda.node.services.persistence
import net.corda.node.services.persistence.AesDbEncryptionService.EncryptionKeyRecord
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.core.TestIdentity
import net.corda.testing.internal.configureDatabase
import net.corda.testing.node.MockServices
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.nio.ByteBuffer
import java.security.GeneralSecurityException
import java.util.UUID
class AesDbEncryptionServiceTest {
private val identity = TestIdentity.fresh("me").party
private lateinit var database: CordaPersistence
private lateinit var encryptionService: AesDbEncryptionService
@Before
fun setUp() {
val dataSourceProps = MockServices.makeTestDataSourceProperties()
database = configureDatabase(dataSourceProps, DatabaseConfig(), { null }, { null })
encryptionService = AesDbEncryptionService(database)
encryptionService.start(identity)
}
@After
fun cleanUp() {
database.close()
}
@Test(timeout = 300_000)
fun `same instance can decrypt ciphertext`() {
val ciphertext = encryptionService.encrypt("Hello World".toByteArray())
val (plaintext, authenticatedData) = encryptionService.decrypt(ciphertext)
assertThat(String(plaintext)).isEqualTo("Hello World")
assertThat(authenticatedData).isNull()
}
@Test(timeout = 300_000)
fun `encypting twice produces different ciphertext`() {
val plaintext = "Hello".toByteArray()
assertThat(encryptionService.encrypt(plaintext)).isNotEqualTo(encryptionService.encrypt(plaintext))
}
@Test(timeout = 300_000)
fun `ciphertext can be decrypted after restart`() {
val ciphertext = encryptionService.encrypt("Hello World".toByteArray())
encryptionService = AesDbEncryptionService(database)
encryptionService.start(identity)
val plaintext = encryptionService.decrypt(ciphertext).plaintext
assertThat(String(plaintext)).isEqualTo("Hello World")
}
@Test(timeout = 300_000)
fun `encrypting with authenticated data`() {
val ciphertext = encryptionService.encrypt("Hello World".toByteArray(), "Additional data".toByteArray())
val (plaintext, authenticatedData) = encryptionService.decrypt(ciphertext)
assertThat(String(plaintext)).isEqualTo("Hello World")
assertThat(authenticatedData?.let { String(it) }).isEqualTo("Additional data")
}
@Test(timeout = 300_000)
fun extractUnauthenticatedAdditionalData() {
val ciphertext = encryptionService.encrypt("Hello World".toByteArray(), "Additional data".toByteArray())
val additionalData = encryptionService.extractUnauthenticatedAdditionalData(ciphertext)
assertThat(additionalData?.let { String(it) }).isEqualTo("Additional data")
}
@Test(timeout = 300_000)
fun `ciphertext cannot be decrypted if the authenticated data is modified`() {
val ciphertext = ByteBuffer.wrap(encryptionService.encrypt("Hello World".toByteArray(), "1234".toByteArray()))
ciphertext.position(21)
ciphertext.put("4321".toByteArray()) // Use same length for the modified AAD
assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy {
encryptionService.decrypt(ciphertext.array())
}
}
@Test(timeout = 300_000)
fun `ciphertext cannot be decrypted if the key used is deleted`() {
val ciphertext = encryptionService.encrypt("Hello World".toByteArray())
val keyId = ByteBuffer.wrap(ciphertext).getKeyId()
val deletedCount = database.transaction {
session.createQuery("DELETE FROM ${EncryptionKeyRecord::class.java.name} k WHERE k.keyId = :keyId")
.setParameter("keyId", keyId)
.executeUpdate()
}
assertThat(deletedCount).isEqualTo(1)
encryptionService = AesDbEncryptionService(database)
encryptionService.start(identity)
assertThatIllegalArgumentException().isThrownBy {
encryptionService.decrypt(ciphertext)
}
}
@Test(timeout = 300_000)
fun `ciphertext cannot be decrypted if forced to use a different key`() {
val ciphertext = ByteBuffer.wrap(encryptionService.encrypt("Hello World".toByteArray()))
val keyId = ciphertext.getKeyId()
val anotherKeyId = database.transaction {
session.createQuery("SELECT keyId FROM ${EncryptionKeyRecord::class.java.name} k WHERE k.keyId != :keyId", UUID::class.java)
.setParameter("keyId", keyId)
.setMaxResults(1)
.singleResult
}
ciphertext.putKeyId(anotherKeyId)
encryptionService = AesDbEncryptionService(database)
encryptionService.start(identity)
assertThatExceptionOfType(GeneralSecurityException::class.java).isThrownBy {
encryptionService.decrypt(ciphertext.array())
}
}
private fun ByteBuffer.getKeyId(): UUID {
position(1)
return getUUID()
}
private fun ByteBuffer.putKeyId(keyId: UUID) {
position(1)
putUUID(keyId)
}
}

View File

@ -24,7 +24,6 @@ import net.corda.node.services.network.PersistentPartyInfoCache
import net.corda.node.services.persistence.DBTransactionStorage.TransactionStatus.IN_FLIGHT
import net.corda.node.services.persistence.DBTransactionStorage.TransactionStatus.VERIFIED
import net.corda.nodeapi.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.cryptoservice.CryptoService
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.core.ALICE_NAME
@ -38,7 +37,8 @@ import net.corda.testing.internal.TestingNamedCacheFactory
import net.corda.testing.internal.configureDatabase
import net.corda.testing.internal.createWireTransaction
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.node.internal.MockCryptoService
import net.corda.testing.node.internal.MockEncryptionService
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Rule
@ -67,6 +67,8 @@ class DBTransactionStorageLedgerRecoveryTests {
private lateinit var transactionRecovery: DBTransactionStorageLedgerRecovery
private lateinit var partyInfoCache: PersistentPartyInfoCache
private val encryptionService = MockEncryptionService()
@Before
fun setUp() {
val dataSourceProps = makeTestDataSourceProperties()
@ -278,17 +280,21 @@ class DBTransactionStorageLedgerRecoveryTests {
@Test(timeout = 300_000)
fun `test lightweight serialization and deserialization of hashed distribution list payload`() {
val dl = HashedDistributionList(ALL_VISIBLE,
mapOf(BOB.name.hashCode().toLong() to NONE, CHARLIE_NAME.hashCode().toLong() to ONLY_RELEVANT), now())
assertEquals(dl, dl.serialize().let { HashedDistributionList.deserialize(it) })
val hashedDistList = HashedDistributionList(
ALL_VISIBLE,
mapOf(BOB.name.hashCode().toLong() to NONE, CHARLIE_NAME.hashCode().toLong() to ONLY_RELEVANT),
HashedDistributionList.PublicHeader(now())
)
val roundtrip = HashedDistributionList.decrypt(hashedDistList.encrypt(encryptionService), encryptionService)
assertThat(roundtrip).isEqualTo(hashedDistList)
}
private fun readTransactionFromDB(id: SecureHash): DBTransactionStorage.DBTransaction {
val fromDb = database.transaction {
session.createQuery(
"from ${DBTransactionStorage.DBTransaction::class.java.name} where tx_id = :transactionId",
"from ${DBTransactionStorage.DBTransaction::class.java.name} where txId = :transactionId",
DBTransactionStorage.DBTransaction::class.java
).setParameter("transactionId", id.toString()).resultList.map { it }
).setParameter("transactionId", id.toString()).resultList
}
assertEquals(1, fromDb.size)
return fromDb[0]
@ -298,7 +304,7 @@ class DBTransactionStorageLedgerRecoveryTests {
return database.transaction {
if (id != null)
session.createQuery(
"from ${DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java.name} where tx_id = :transactionId",
"from ${DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java.name} where txId = :transactionId",
DBTransactionStorageLedgerRecovery.DBSenderDistributionRecord::class.java
).setParameter("transactionId", id.toString()).resultList.map { it.toSenderDistributionRecord() }
else
@ -312,17 +318,15 @@ class DBTransactionStorageLedgerRecoveryTests {
private fun readReceiverDistributionRecordFromDB(id: SecureHash): ReceiverDistributionRecord {
val fromDb = database.transaction {
session.createQuery(
"from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where tx_id = :transactionId",
"from ${DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java.name} where txId = :transactionId",
DBTransactionStorageLedgerRecovery.DBReceiverDistributionRecord::class.java
).setParameter("transactionId", id.toString()).resultList.map { it }
).setParameter("transactionId", id.toString()).resultList
}
assertEquals(1, fromDb.size)
return fromDb[0].toReceiverDistributionRecord(MockCryptoService(emptyMap()))
return fromDb[0].toReceiverDistributionRecord(encryptionService)
}
private fun newTransactionRecovery(cacheSizeBytesOverride: Long? = null, clock: CordaClock = SimpleClock(Clock.systemUTC()),
cryptoService: CryptoService = MockCryptoService(emptyMap())) {
private fun newTransactionRecovery(cacheSizeBytesOverride: Long? = null, clock: CordaClock = SimpleClock(Clock.systemUTC())) {
val networkMapCache = PersistentNetworkMapCache(TestingNamedCacheFactory(), database, InMemoryIdentityService(trustRoot = DEV_ROOT_CA.certificate))
val alice = createNodeInfo(listOf(ALICE))
val bob = createNodeInfo(listOf(BOB))
@ -330,8 +334,13 @@ class DBTransactionStorageLedgerRecoveryTests {
networkMapCache.addOrUpdateNodes(listOf(alice, bob, charlie))
partyInfoCache = PersistentPartyInfoCache(networkMapCache, TestingNamedCacheFactory(), database)
partyInfoCache.start()
transactionRecovery = DBTransactionStorageLedgerRecovery(database, TestingNamedCacheFactory(cacheSizeBytesOverride
?: 1024), clock, cryptoService, partyInfoCache)
transactionRecovery = DBTransactionStorageLedgerRecovery(
database,
TestingNamedCacheFactory(cacheSizeBytesOverride ?: 1024),
clock,
encryptionService,
partyInfoCache
)
}
private var portCounter = 1000
@ -370,10 +379,13 @@ class DBTransactionStorageLedgerRecoveryTests {
private fun notarySig(txId: SecureHash) =
DUMMY_NOTARY.keyPair.sign(SignableData(txId, SignatureMetadata(1, Crypto.findSignatureScheme(DUMMY_NOTARY.publicKey).schemeNumberID)))
private fun DistributionList.toWire(cryptoService: CryptoService = MockCryptoService(emptyMap())): ByteArray {
val hashedPeersToStatesToRecord = this.peersToStatesToRecord.map { (peer, statesToRecord) ->
partyInfoCache.getPartyIdByCordaX500Name(peer) to statesToRecord }.toMap()
val hashedDistributionList = HashedDistributionList(this.senderStatesToRecord, hashedPeersToStatesToRecord, now())
return cryptoService.encrypt(hashedDistributionList.serialize())
private fun DistributionList.toWire(): ByteArray {
val hashedPeersToStatesToRecord = this.peersToStatesToRecord.mapKeys { (peer) -> partyInfoCache.getPartyIdByCordaX500Name(peer) }
val hashedDistributionList = HashedDistributionList(
this.senderStatesToRecord,
hashedPeersToStatesToRecord,
HashedDistributionList.PublicHeader(now())
)
return hashedDistributionList.encrypt(encryptionService)
}
}

View File

@ -0,0 +1,39 @@
package net.corda.testing.node.internal
import net.corda.core.internal.copyBytes
import net.corda.node.services.EncryptionService
import net.corda.nodeapi.internal.crypto.AesEncryption
import java.nio.ByteBuffer
import javax.crypto.SecretKey
class MockEncryptionService(private val aesKey: SecretKey = AesEncryption.randomKey()) : EncryptionService {
override fun encrypt(plaintext: ByteArray, additionalData: ByteArray?): ByteArray {
val ciphertext = AesEncryption.encrypt(aesKey, plaintext, additionalData)
val buffer = ByteBuffer.allocate(Integer.BYTES + (additionalData?.size ?: 0) + ciphertext.size)
if (additionalData != null) {
buffer.putInt(additionalData.size)
buffer.put(additionalData)
} else {
buffer.putInt(0)
}
buffer.put(ciphertext)
return buffer.array()
}
override fun decrypt(ciphertext: ByteArray): EncryptionService.PlaintextAndAAD {
val buffer = ByteBuffer.wrap(ciphertext)
val additionalData = buffer.getAdditionaData()
val plaintext = AesEncryption.decrypt(aesKey, buffer.copyBytes(), additionalData)
// Only now is the additional data authenticated
return EncryptionService.PlaintextAndAAD(plaintext, additionalData)
}
override fun extractUnauthenticatedAdditionalData(ciphertext: ByteArray): ByteArray? {
return ByteBuffer.wrap(ciphertext).getAdditionaData()
}
private fun ByteBuffer.getAdditionaData(): ByteArray? {
val additionalDataSize = getInt()
return if (additionalDataSize > 0) ByteArray(additionalDataSize).also { get(it) } else null
}
}

View File

@ -61,9 +61,13 @@ open class MockTransactionStorage : WritableTransactionStorage, SingletonSeriali
return txns.putIfAbsent(transaction.id, TxHolder(transaction, status = TransactionStatus.IN_FLIGHT)) == null
}
override fun addSenderTransactionRecoveryMetadata(id: SecureHash, metadata: TransactionMetadata): ByteArray? { return null }
override fun addSenderTransactionRecoveryMetadata(txId: SecureHash, metadata: TransactionMetadata): ByteArray? { return null }
override fun addReceiverTransactionRecoveryMetadata(id: SecureHash, sender: CordaX500Name, receiver: CordaX500Name, receiverStatesToRecord: StatesToRecord, encryptedDistributionList: ByteArray) { }
override fun addReceiverTransactionRecoveryMetadata(txId: SecureHash,
sender: CordaX500Name,
receiver: CordaX500Name,
receiverStatesToRecord: StatesToRecord,
encryptedDistributionList: ByteArray) { }
override fun removeUnnotarisedTransaction(id: SecureHash): Boolean {
return txns.remove(id) != null