Add signing of transaction merkle root hash.

This commit is contained in:
Katarzyna Streich
2016-11-04 17:56:42 +00:00
parent 2db2854a0b
commit 103817ec57
13 changed files with 34 additions and 32 deletions

View File

@ -45,7 +45,7 @@ class TransactionGraphSearch(val transactions: ReadOnlyTransactionStorage,
val unvisitedInputTxs: Map<SecureHash, SignedTransaction> = inputsLeadingToUnvisitedTx.map { it.txhash }.toHashSet().map { transactions.getTransaction(it) }.filterNotNull().associateBy { it.id } val unvisitedInputTxs: Map<SecureHash, SignedTransaction> = inputsLeadingToUnvisitedTx.map { it.txhash }.toHashSet().map { transactions.getTransaction(it) }.filterNotNull().associateBy { it.id }
val unvisitedInputTxsWithInputIndex: Iterable<Pair<SignedTransaction, Int>> = inputsLeadingToUnvisitedTx.filter { it.txhash in unvisitedInputTxs.keys }.map { Pair(unvisitedInputTxs[it.txhash]!!, it.index) } val unvisitedInputTxsWithInputIndex: Iterable<Pair<SignedTransaction, Int>> = inputsLeadingToUnvisitedTx.filter { it.txhash in unvisitedInputTxs.keys }.map { Pair(unvisitedInputTxs[it.txhash]!!, it.index) }
next += (unvisitedInputTxsWithInputIndex.filter { q.followInputsOfType == null || it.first.tx.outputs[it.second].data.javaClass == q.followInputsOfType } next += (unvisitedInputTxsWithInputIndex.filter { q.followInputsOfType == null || it.first.tx.outputs[it.second].data.javaClass == q.followInputsOfType }
.map { it.first }.filter { alreadyVisited.add(it.txBits.hash) }.map { it.tx }) .map { it.first }.filter { alreadyVisited.add(it.id) }.map { it.tx })
} }
return results return results

View File

@ -19,9 +19,12 @@ import java.util.*
* of a WireTransaction, therefore if you are storing data keyed by WT hash be aware that multiple different STs may * of a WireTransaction, therefore if you are storing data keyed by WT hash be aware that multiple different STs may
* map to the same key (and they could be different in important ways, like validity!). The signatures on a * map to the same key (and they could be different in important ways, like validity!). The signatures on a
* SignedTransaction might be invalid or missing: the type does not imply validity. * SignedTransaction might be invalid or missing: the type does not imply validity.
* A transaction ID should be the hash of the [WireTransaction] Merkle tree root. Thus adding or removing a signature does not change it.
*/ */
data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>, data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
val sigs: List<DigitalSignature.WithKey>) : NamedByHash { val sigs: List<DigitalSignature.WithKey>,
override val id: SecureHash
) : NamedByHash {
init { init {
require(sigs.isNotEmpty()) require(sigs.isNotEmpty())
} }
@ -31,9 +34,6 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
/** Lazily calculated access to the deserialised/hashed transaction data. */ /** Lazily calculated access to the deserialised/hashed transaction data. */
val tx: WireTransaction by lazy { WireTransaction.deserialize(txBits) } val tx: WireTransaction by lazy { WireTransaction.deserialize(txBits) }
/** A transaction ID is the hash of the [WireTransaction]. Thus adding or removing a signature does not change it. */
override val id: SecureHash get() = tx.id
class SignaturesMissingException(val missing: Set<PublicKey>, val descriptions: List<String>, override val id: SecureHash) : NamedByHash, SignatureException() { class SignaturesMissingException(val missing: Set<PublicKey>, val descriptions: List<String>, override val id: SecureHash) : NamedByHash, SignatureException() {
override fun toString(): String { override fun toString(): String {
return "Missing signatures for $descriptions on transaction ${id.prefixChars()} for ${missing.toStringsShort()}" return "Missing signatures for $descriptions on transaction ${id.prefixChars()} for ${missing.toStringsShort()}"
@ -64,6 +64,7 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
if (needed.isNotEmpty()) if (needed.isNotEmpty())
throw SignaturesMissingException(needed, getMissingKeyDescriptions(needed), id) throw SignaturesMissingException(needed, getMissingKeyDescriptions(needed), id)
} }
check(tx.id == id)
return tx return tx
} }
@ -77,8 +78,9 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
*/ */
@Throws(SignatureException::class) @Throws(SignatureException::class)
fun checkSignaturesAreValid() { fun checkSignaturesAreValid() {
for (sig in sigs) for (sig in sigs) {
sig.verifyWithECDSA(txBits.bits) sig.verifyWithECDSA(id.bits)
}
} }
private fun getMissingSignatures(): Set<PublicKey> { private fun getMissingSignatures(): Set<PublicKey> {

View File

@ -99,7 +99,7 @@ open class TransactionBuilder(
fun signWith(key: KeyPair): TransactionBuilder { fun signWith(key: KeyPair): TransactionBuilder {
check(currentSigs.none { it.by == key.public }) { "This partial transaction was already signed by ${key.public}" } check(currentSigs.none { it.by == key.public }) { "This partial transaction was already signed by ${key.public}" }
val data = toWireTransaction().serialize() val data = toWireTransaction().id
addSignatureUnchecked(key.signWithECDSA(data.bits)) addSignatureUnchecked(key.signWithECDSA(data.bits))
return this return this
} }
@ -124,7 +124,7 @@ open class TransactionBuilder(
*/ */
fun checkSignature(sig: DigitalSignature.WithKey) { fun checkSignature(sig: DigitalSignature.WithKey) {
require(commands.any { it.signers.contains(sig.by) }) { "Signature key doesn't match any command" } require(commands.any { it.signers.contains(sig.by) }) { "Signature key doesn't match any command" }
sig.verifyWithECDSA(toWireTransaction().serialized) sig.verifyWithECDSA(toWireTransaction().id)
} }
/** Adds the signature directly to the transaction, without checking it for validity. */ /** Adds the signature directly to the transaction, without checking it for validity. */
@ -143,7 +143,8 @@ open class TransactionBuilder(
if (missing.isNotEmpty()) if (missing.isNotEmpty())
throw IllegalStateException("Missing signatures on the transaction for the public keys: ${missing.toStringsShort()}") throw IllegalStateException("Missing signatures on the transaction for the public keys: ${missing.toStringsShort()}")
} }
return SignedTransaction(toWireTransaction().serialize(), ArrayList(currentSigs)) val wtx = toWireTransaction()
return SignedTransaction(wtx.serialize(), ArrayList(currentSigs), wtx.id)
} }
open fun addInputState(stateAndRef: StateAndRef<*>) { open fun addInputState(stateAndRef: StateAndRef<*>) {

View File

@ -96,7 +96,7 @@ abstract class AbstractStateReplacementProtocol<T> {
if (it.sig == null) throw StateReplacementException(it.error!!) if (it.sig == null) throw StateReplacementException(it.error!!)
else { else {
check(it.sig.by == party.owningKey) { "Not signed by the required participant" } check(it.sig.by == party.owningKey) { "Not signed by the required participant" }
it.sig.verifyWithECDSA(stx.txBits) it.sig.verifyWithECDSA(stx.id)
it.sig it.sig
} }
} }
@ -154,7 +154,7 @@ abstract class AbstractStateReplacementProtocol<T> {
// TODO: This step should not be necessary, as signatures are re-checked in verifySignatures. // TODO: This step should not be necessary, as signatures are re-checked in verifySignatures.
val allSignatures = swapSignatures.unwrap { signatures -> val allSignatures = swapSignatures.unwrap { signatures ->
signatures.forEach { it.verifyWithECDSA(stx.txBits) } signatures.forEach { it.verifyWithECDSA(stx.id) }
signatures signatures
} }
@ -199,7 +199,7 @@ abstract class AbstractStateReplacementProtocol<T> {
private fun sign(stx: SignedTransaction): DigitalSignature.WithKey { private fun sign(stx: SignedTransaction): DigitalSignature.WithKey {
val myKey = serviceHub.legalIdentityKey val myKey = serviceHub.legalIdentityKey
return myKey.signWithECDSA(stx.txBits) return myKey.signWithECDSA(stx.id)
} }
} }

View File

@ -66,7 +66,7 @@ object NotaryProtocol {
progressTracker.currentStep = VALIDATING progressTracker.currentStep = VALIDATING
when (notaryResult) { when (notaryResult) {
is Result.Success -> { is Result.Success -> {
validateSignature(notaryResult.sig, stx.txBits) validateSignature(notaryResult.sig, stx.id.bits)
notaryResult.sig notaryResult.sig
} }
is Result.Error -> { is Result.Error -> {
@ -79,7 +79,7 @@ object NotaryProtocol {
} }
} }
private fun validateSignature(sig: DigitalSignature.LegallyIdentifiable, data: SerializedBytes<WireTransaction>) { private fun validateSignature(sig: DigitalSignature.LegallyIdentifiable, data: ByteArray) {
check(sig.signer == notaryParty) { "Notary result not signed by the correct service" } check(sig.signer == notaryParty) { "Notary result not signed by the correct service" }
sig.verifyWithECDSA(data) sig.verifyWithECDSA(data)
} }
@ -108,7 +108,7 @@ object NotaryProtocol {
beforeCommit(stx, reqIdentity) beforeCommit(stx, reqIdentity)
commitInputStates(wtx, reqIdentity) commitInputStates(wtx, reqIdentity)
val sig = sign(stx.txBits) val sig = sign(stx.id.bits)
Result.Success(sig) Result.Success(sig)
} catch(e: NotaryException) { } catch(e: NotaryException) {
Result.Error(e.error) Result.Error(e.error)
@ -140,12 +140,12 @@ object NotaryProtocol {
uniquenessProvider.commit(tx.inputs, tx.id, reqIdentity) uniquenessProvider.commit(tx.inputs, tx.id, reqIdentity)
} catch (e: UniquenessException) { } catch (e: UniquenessException) {
val conflictData = e.error.serialize() val conflictData = e.error.serialize()
val signedConflict = SignedData(conflictData, sign(conflictData)) val signedConflict = SignedData(conflictData, sign(conflictData.bits))
throw NotaryException(NotaryError.Conflict(tx, signedConflict)) throw NotaryException(NotaryError.Conflict(tx, signedConflict))
} }
} }
private fun <T : Any> sign(bits: SerializedBytes<T>): DigitalSignature.LegallyIdentifiable { private fun sign(bits: ByteArray): DigitalSignature.LegallyIdentifiable {
val myNodeInfo = serviceHub.myInfo val myNodeInfo = serviceHub.myInfo
val myIdentity = myNodeInfo.notaryIdentity val myIdentity = myNodeInfo.notaryIdentity
val mySigningKey = serviceHub.notaryIdentityKey val mySigningKey = serviceHub.notaryIdentityKey

View File

@ -170,7 +170,7 @@ object TwoPartyDealProtocol {
open fun computeOurSignature(partialTX: SignedTransaction): DigitalSignature.WithKey { open fun computeOurSignature(partialTX: SignedTransaction): DigitalSignature.WithKey {
progressTracker.currentStep = SIGNING progressTracker.currentStep = SIGNING
return myKeyPair.signWithECDSA(partialTX.txBits) return myKeyPair.signWithECDSA(partialTX.id)
} }
@Suspendable @Suspendable

View File

@ -34,7 +34,7 @@ class TransactionTests {
timestamp = null timestamp = null
) )
val bits: SerializedBytes<WireTransaction> = wtx.serialized val bits: SerializedBytes<WireTransaction> = wtx.serialized
fun make(vararg keys: KeyPair) = SignedTransaction(bits, keys.map { it.signWithECDSA(bits) }) fun make(vararg keys: KeyPair) = SignedTransaction(bits, keys.map { it.signWithECDSA(wtx.id.bits) }, wtx.id)
assertFailsWith<IllegalArgumentException> { make().verifySignatures() } assertFailsWith<IllegalArgumentException> { make().verifySignatures() }
assertEquals( assertEquals(

View File

@ -67,7 +67,7 @@ class TransactionSerializationTests {
signedTX.verifySignatures() signedTX.verifySignatures()
// Corrupt the data and ensure the signature catches the problem. // Corrupt the data and ensure the signature catches the problem.
signedTX.txBits.bits[5] = 0 signedTX.id.bits[5] = 0
assertFailsWith(SignatureException::class) { assertFailsWith(SignatureException::class) {
signedTX.verifySignatures() signedTX.verifySignatures()
} }

View File

@ -81,7 +81,8 @@ class SignedTransactionGenerator: Generator<SignedTransaction>(SignedTransaction
val wireTransaction = WiredTransactionGenerator().generate(random, status) val wireTransaction = WiredTransactionGenerator().generate(random, status)
return SignedTransaction( return SignedTransaction(
txBits = wireTransaction.serialized, txBits = wireTransaction.serialized,
sigs = listOf(NullSignature) sigs = listOf(NullSignature),
id = wireTransaction.id
) )
} }
} }

View File

@ -141,7 +141,7 @@ object TwoPartyTradeProtocol {
open fun calculateOurSignature(partialTX: SignedTransaction): DigitalSignature.WithKey { open fun calculateOurSignature(partialTX: SignedTransaction): DigitalSignature.WithKey {
progressTracker.currentStep = SIGNING progressTracker.currentStep = SIGNING
return myKeyPair.signWithECDSA(partialTX.txBits) return myKeyPair.signWithECDSA(partialTX.id)
} }
@Suspendable @Suspendable

View File

@ -54,7 +54,7 @@ class NotaryServiceTests {
val future = runNotaryClient(stx) val future = runNotaryClient(stx)
val signature = future.get() val signature = future.get()
signature.verifyWithECDSA(stx.txBits) signature.verifyWithECDSA(stx.id)
} }
@Test fun `should sign a unique transaction without a timestamp`() { @Test fun `should sign a unique transaction without a timestamp`() {
@ -67,7 +67,7 @@ class NotaryServiceTests {
val future = runNotaryClient(stx) val future = runNotaryClient(stx)
val signature = future.get() val signature = future.get()
signature.verifyWithECDSA(stx.txBits) signature.verifyWithECDSA(stx.id)
} }
@Test fun `should report error for transaction with an invalid timestamp`() { @Test fun `should report error for transaction with an invalid timestamp`() {

View File

@ -1,13 +1,11 @@
package com.r3corda.node.services.persistence package com.r3corda.node.services.persistence
import com.google.common.primitives.Ints
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.contracts.StateRef import com.r3corda.core.contracts.StateRef
import com.r3corda.core.contracts.TransactionType import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.NullPublicKey import com.r3corda.core.crypto.NullPublicKey
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
@ -144,6 +142,6 @@ class DBTransactionStorageTests {
type = TransactionType.General(), type = TransactionType.General(),
timestamp = null timestamp = null
) )
return SignedTransaction(wtx.serialized, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1)))) return SignedTransaction(wtx.serialized, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1))), wtx.id)
} }
} }

View File

@ -279,10 +279,10 @@ data class TestLedgerDSLInterpreter private constructor (
override fun verifies(): EnforceVerifyOrFail { override fun verifies(): EnforceVerifyOrFail {
try { try {
services.recordTransactions(transactionsUnverified.map { SignedTransaction(it.serialized, listOf(NullSignature)) }) services.recordTransactions(transactionsUnverified.map { SignedTransaction(it.serialized, listOf(NullSignature), it.id) })
for ((key, value) in transactionWithLocations) { for ((key, value) in transactionWithLocations) {
value.transaction.toLedgerTransaction(services).verify() value.transaction.toLedgerTransaction(services).verify()
services.recordTransactions(SignedTransaction(value.transaction.serialized, listOf(NullSignature))) services.recordTransactions(SignedTransaction(value.transaction.serialized, listOf(NullSignature), value.transaction.id))
} }
return EnforceVerifyOrFail.Token return EnforceVerifyOrFail.Token
} catch (exception: TransactionVerificationException) { } catch (exception: TransactionVerificationException) {
@ -326,9 +326,9 @@ fun signAll(transactionsToSign: List<WireTransaction>, extraKeys: List<KeyPair>)
} }
wtx.mustSign.forEach { wtx.mustSign.forEach {
val key = keyLookup[it] ?: throw IllegalArgumentException("Missing required key for ${it.toStringShort()}") val key = keyLookup[it] ?: throw IllegalArgumentException("Missing required key for ${it.toStringShort()}")
signatures += key.signWithECDSA(bits) signatures += key.signWithECDSA(wtx.id)
} }
SignedTransaction(bits, signatures) SignedTransaction(bits, signatures, wtx.id)
} }
/** /**