Refactor transaction serialization caching (#1078)

* Cache deserialized rather than serialized WireTransaction. Prevent repeated deserialization when adding signatures to the SignedTransaction.

* Added a test to check that stx copying and signature collection still works properly after (de)serialization
This commit is contained in:
Andrius Dagys 2017-07-24 14:48:39 +01:00 committed by GitHub
parent 3e199e51fc
commit d2eb5507f9
14 changed files with 71 additions and 52 deletions

View File

@ -60,7 +60,7 @@ object DefaultKryoCustomizer {
instantiatorStrategy = CustomInstantiatorStrategy() instantiatorStrategy = CustomInstantiatorStrategy()
register(Arrays.asList("").javaClass, ArraysAsListSerializer()) register(Arrays.asList("").javaClass, ArraysAsListSerializer())
register(SignedTransaction::class.java, ImmutableClassSerializer(SignedTransaction::class)) register(SignedTransaction::class.java, SignedTransactionSerializer)
register(WireTransaction::class.java, WireTransactionSerializer) register(WireTransaction::class.java, WireTransactionSerializer)
register(SerializedBytes::class.java, SerializedBytesSerializer) register(SerializedBytes::class.java, SerializedBytesSerializer)

View File

@ -6,12 +6,10 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.MapReferenceResolver import com.esotericsoftware.kryo.util.MapReferenceResolver
import com.google.common.annotations.VisibleForTesting import com.google.common.annotations.VisibleForTesting
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.Crypto import net.corda.core.crypto.*
import net.corda.core.crypto.MetaData
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureType
import net.corda.core.crypto.composite.CompositeKey import net.corda.core.crypto.composite.CompositeKey
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPrivateKey
import net.i2p.crypto.eddsa.EdDSAPublicKey import net.i2p.crypto.eddsa.EdDSAPublicKey
@ -277,11 +275,27 @@ object WireTransactionSerializer : Serializer<WireTransaction>() {
} }
} }
@ThreadSafe
object SignedTransactionSerializer : Serializer<SignedTransaction>() {
override fun write(kryo: Kryo, output: Output, obj: SignedTransaction) {
kryo.writeClassAndObject(output, obj.txBits)
kryo.writeClassAndObject(output, obj.sigs)
}
@Suppress("UNCHECKED_CAST")
override fun read(kryo: Kryo, input: Input, type: Class<SignedTransaction>): SignedTransaction {
return SignedTransaction(
kryo.readClassAndObject(input) as SerializedBytes<WireTransaction>,
kryo.readClassAndObject(input) as List<DigitalSignature.WithKey>
)
}
}
/** For serialising an ed25519 private key */ /** For serialising an ed25519 private key */
@ThreadSafe @ThreadSafe
object Ed25519PrivateKeySerializer : Serializer<EdDSAPrivateKey>() { object Ed25519PrivateKeySerializer : Serializer<EdDSAPrivateKey>() {
override fun write(kryo: Kryo, output: Output, obj: EdDSAPrivateKey) { override fun write(kryo: Kryo, output: Output, obj: EdDSAPrivateKey) {
check(obj.params == Crypto.EDDSA_ED25519_SHA512.algSpec ) check(obj.params == Crypto.EDDSA_ED25519_SHA512.algSpec)
output.writeBytesWithLength(obj.seed) output.writeBytesWithLength(obj.seed)
} }

View File

@ -135,7 +135,3 @@ class SerializedBytes<T : Any>(bytes: ByteArray) : OpaqueBytes(bytes) {
// It's OK to use lazy here because SerializedBytes is configured to use the ImmutableClassSerializer. // It's OK to use lazy here because SerializedBytes is configured to use the ImmutableClassSerializer.
val hash: SecureHash by lazy { bytes.sha256() } val hash: SecureHash by lazy { bytes.sha256() }
} }
// The more specific deserialize version results in the bytes being cached, which is faster.
@JvmName("SerializedBytesWireTransaction")
fun SerializedBytes<WireTransaction>.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): WireTransaction = WireTransaction.deserialize(this, serializationFactory, context)

View File

@ -10,6 +10,8 @@ import net.corda.core.crypto.isFulfilledBy
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.toNonEmptySet import net.corda.core.utilities.toNonEmptySet
import java.security.PublicKey import java.security.PublicKey
@ -34,20 +36,21 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
val sigs: List<DigitalSignature.WithKey> val sigs: List<DigitalSignature.WithKey>
) : NamedByHash { ) : NamedByHash {
// DOCEND 1 // DOCEND 1
constructor(wtx: WireTransaction, sigs: List<DigitalSignature.WithKey>) : this(wtx.serialize(), sigs) {
cachedTransaction = wtx
}
init { init {
require(sigs.isNotEmpty()) require(sigs.isNotEmpty())
} }
// TODO: This needs to be reworked to ensure that the inner WireTransaction is only ever deserialised sandboxed. /** Cache the deserialized form of the transaction. This is useful when building a transaction or collecting signatures. */
@Volatile @Transient private var cachedTransaction: WireTransaction? = null
/** Lazily calculated access to the deserialised/hashed transaction data. */ /** Lazily calculated access to the deserialised/hashed transaction data. */
val tx: WireTransaction by lazy { WireTransaction.deserialize(txBits) } val tx: WireTransaction get() = cachedTransaction ?: txBits.deserialize().apply { cachedTransaction = this }
/** /** The id of the contained [WireTransaction]. */
* The Merkle root of the inner [WireTransaction]. Note that this is _not_ the same as the simple hash of
* [txBits], which would not use the Merkle tree structure. If the difference isn't clear, please consult
* the user guide section "Transaction tear-offs" to learn more about Merkle trees.
*/
override val id: SecureHash get() = tx.id override val id: SecureHash get() = tx.id
@CordaSerializable @CordaSerializable
@ -87,7 +90,6 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
val needed = getMissingSignatures() - allowedToBeMissing val needed = getMissingSignatures() - allowedToBeMissing
if (needed.isNotEmpty()) if (needed.isNotEmpty())
throw SignaturesMissingException(needed.toNonEmptySet(), getMissingKeyDescriptions(needed), id) throw SignaturesMissingException(needed.toNonEmptySet(), getMissingKeyDescriptions(needed), id)
check(tx.id == id)
return tx return tx
} }
@ -131,10 +133,21 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
} }
/** Returns the same transaction but with an additional (unchecked) signature. */ /** Returns the same transaction but with an additional (unchecked) signature. */
fun withAdditionalSignature(sig: DigitalSignature.WithKey) = copy(sigs = sigs + sig) fun withAdditionalSignature(sig: DigitalSignature.WithKey) = copyWithCache(listOf(sig))
/** Returns the same transaction but with an additional (unchecked) signatures. */ /** Returns the same transaction but with an additional (unchecked) signatures. */
fun withAdditionalSignatures(sigList: Iterable<DigitalSignature.WithKey>) = copy(sigs = sigs + sigList) fun withAdditionalSignatures(sigList: Iterable<DigitalSignature.WithKey>) = copyWithCache(sigList)
/**
* Creates a copy of the SignedTransaction that includes the provided [sigList]. Also propagates the [cachedTransaction]
* so the contained transaction does not need to be deserialized again.
*/
private fun copyWithCache(sigList: Iterable<DigitalSignature.WithKey>): SignedTransaction {
val cached = cachedTransaction
return copy(sigs = sigs + sigList).apply {
cachedTransaction = cached
}
}
/** Alias for [withAdditionalSignature] to let you use Kotlin operator overloading. */ /** Alias for [withAdditionalSignature] to let you use Kotlin operator overloading. */
operator fun plus(sig: DigitalSignature.WithKey) = withAdditionalSignature(sig) operator fun plus(sig: DigitalSignature.WithKey) = withAdditionalSignature(sig)

View File

@ -178,7 +178,7 @@ open class TransactionBuilder(
throw IllegalStateException("Missing signatures on the transaction for the public keys: ${missing.joinToString()}") throw IllegalStateException("Missing signatures on the transaction for the public keys: ${missing.joinToString()}")
} }
val wtx = toWireTransaction() val wtx = toWireTransaction()
return SignedTransaction(wtx.serialize(), ArrayList(currentSigs)) return SignedTransaction(wtx, ArrayList(currentSigs))
} }
/** /**

View File

@ -8,18 +8,14 @@ import net.corda.core.crypto.keys
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.Emoji import net.corda.core.internal.Emoji
import net.corda.core.node.ServicesForResolution import net.corda.core.node.ServicesForResolution
import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
import java.security.PublicKey import java.security.PublicKey
import java.security.SignatureException import java.security.SignatureException
import java.util.function.Predicate import java.util.function.Predicate
/** /**
* A transaction ready for serialisation, without any signatures attached. A WireTransaction is usually wrapped * A transaction ready for serialisation, without any signatures attached. A WireTransaction is usually wrapped
* by a [SignedTransaction] that carries the signatures over this payload. The hash of the wire transaction is * by a [SignedTransaction] that carries the signatures over this payload.
* the identity of the transaction, that is, it's possible for two [SignedTransaction]s with different sets of * The identity of the transaction is the Merkle tree root of its components (see [MerkleTree]).
* signatures to have the same identity hash.
*/ */
class WireTransaction( class WireTransaction(
/** Pointers to the input states on the ledger, identified by (tx identity hash, output index). */ /** Pointers to the input states on the ledger, identified by (tx identity hash, output index). */
@ -38,20 +34,9 @@ class WireTransaction(
checkInvariants() checkInvariants()
} }
// Cache the serialised form of the transaction and its hash to give us fast access to it. /** The transaction id is represented by the root hash of Merkle tree over the transaction components. */
@Volatile @Transient private var cachedBytes: SerializedBytes<WireTransaction>? = null
val serialized: SerializedBytes<WireTransaction> get() = cachedBytes ?: serialize().apply { cachedBytes = this }
override val id: SecureHash by lazy { merkleTree.hash } override val id: SecureHash by lazy { merkleTree.hash }
companion object {
fun deserialize(data: SerializedBytes<WireTransaction>, serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): WireTransaction {
val wtx = data.deserialize<WireTransaction>(serializationFactory, context)
wtx.cachedBytes = data
return wtx
}
}
/** /**
* Looks up identities and attachments from storage to generate a [LedgerTransaction]. A transaction is expected to * Looks up identities and attachments from storage to generate a [LedgerTransaction]. A transaction is expected to
* have been fully resolved using the resolution flow by this point. * have been fully resolved using the resolution flow by this point.

View File

@ -19,8 +19,7 @@ import kotlin.test.assertFailsWith
class TransactionTests : TestDependencyInjectionBase() { class TransactionTests : TestDependencyInjectionBase() {
private fun makeSigned(wtx: WireTransaction, vararg keys: KeyPair): SignedTransaction { private fun makeSigned(wtx: WireTransaction, vararg keys: KeyPair): SignedTransaction {
val bytes: SerializedBytes<WireTransaction> = wtx.serialized return SignedTransaction(wtx, keys.map { it.sign(wtx.id.bytes) })
return SignedTransaction(bytes, keys.map { it.sign(wtx.id.bytes) })
} }
@Test @Test

View File

@ -6,6 +6,7 @@ import net.corda.contracts.asset.Cash
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash.Companion.zeroHash import net.corda.core.crypto.SecureHash.Companion.zeroHash
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.testing.* import net.corda.testing.*
@ -110,7 +111,7 @@ class PartialMerkleTreeTest : TestDependencyInjectionBase() {
val mt = testTx.buildFilteredTransaction(Predicate(::filtering)) val mt = testTx.buildFilteredTransaction(Predicate(::filtering))
val leaves = mt.filteredLeaves val leaves = mt.filteredLeaves
val d = WireTransaction.deserialize(testTx.serialized) val d = testTx.serialize().deserialize()
assertEquals(testTx.id, d.id) assertEquals(testTx.id, d.id)
assertEquals(1, leaves.commands.size) assertEquals(1, leaves.commands.size)
assertEquals(1, leaves.outputs.size) assertEquals(1, leaves.outputs.size)

View File

@ -104,4 +104,18 @@ class TransactionSerializationTests : TestDependencyInjectionBase() {
val stx = notaryServices.addSignature(ptx) val stx = notaryServices.addSignature(ptx)
assertEquals(TEST_TX_TIME, stx.tx.timeWindow?.midpoint) assertEquals(TEST_TX_TIME, stx.tx.timeWindow?.midpoint)
} }
@Test
fun storeAndLoadWhenSigning() {
val ptx = megaCorpServices.signInitialTransaction(tx)
ptx.verifySignaturesExcept(notaryServices.key.public)
val stored = ptx.serialize()
val loaded = stored.deserialize()
assertEquals(loaded, ptx)
val final = notaryServices.addSignature(loaded)
final.verifyRequiredSignatures()
}
} }

View File

@ -84,7 +84,7 @@ class SignedTransactionGenerator : Generator<SignedTransaction>(SignedTransactio
override fun generate(random: SourceOfRandomness, status: GenerationStatus): SignedTransaction { override fun generate(random: SourceOfRandomness, status: GenerationStatus): SignedTransaction {
val wireTransaction = WiredTransactionGenerator().generate(random, status) val wireTransaction = WiredTransactionGenerator().generate(random, status)
return SignedTransaction( return SignedTransaction(
txBits = wireTransaction.serialized, wtx = wireTransaction,
sigs = listOf(NullSignature) sigs = listOf(NullSignature)
) )
} }

View File

@ -605,7 +605,6 @@ class TwoPartyTradeFlowTests {
vararg extraSigningNodes: AbstractNode): Map<SecureHash, SignedTransaction> { vararg extraSigningNodes: AbstractNode): Map<SecureHash, SignedTransaction> {
val signed = wtxToSign.map { val signed = wtxToSign.map {
val bits = it.serialize()
val id = it.id val id = it.id
val sigs = mutableListOf<DigitalSignature.WithKey>() val sigs = mutableListOf<DigitalSignature.WithKey>()
sigs.add(node.services.keyManagementService.sign(id.bytes, node.services.legalIdentityKey)) sigs.add(node.services.keyManagementService.sign(id.bytes, node.services.legalIdentityKey))
@ -613,7 +612,7 @@ class TwoPartyTradeFlowTests {
extraSigningNodes.forEach { currentNode -> extraSigningNodes.forEach { currentNode ->
sigs.add(currentNode.services.keyManagementService.sign(id.bytes, currentNode.info.legalIdentity.owningKey)) sigs.add(currentNode.services.keyManagementService.sign(id.bytes, currentNode.info.legalIdentity.owningKey))
} }
SignedTransaction(bits, sigs) SignedTransaction(it, sigs)
} }
return node.database.transaction { return node.database.transaction {
node.services.recordTransactions(signed) node.services.recordTransactions(signed)

View File

@ -214,6 +214,6 @@ class RequeryConfigurationTest : TestDependencyInjectionBase() {
type = TransactionType.General, type = TransactionType.General,
timeWindow = null timeWindow = null
) )
return SignedTransaction(wtx.serialized, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1)))) return SignedTransaction(wtx, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1))))
} }
} }

View File

@ -152,6 +152,6 @@ class DBTransactionStorageTests : TestDependencyInjectionBase() {
type = TransactionType.General, type = TransactionType.General,
timeWindow = null timeWindow = null
) )
return SignedTransaction(wtx.serialized, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1)))) return SignedTransaction(wtx, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1))))
} }
} }

View File

@ -284,7 +284,7 @@ data class TestLedgerDSLInterpreter private constructor(
override fun verifies(): EnforceVerifyOrFail { override fun verifies(): EnforceVerifyOrFail {
try { try {
val usedInputs = mutableSetOf<StateRef>() val usedInputs = mutableSetOf<StateRef>()
services.recordTransactions(transactionsUnverified.map { SignedTransaction(it.serialized, listOf(NullSignature)) }) services.recordTransactions(transactionsUnverified.map { SignedTransaction(it, listOf(NullSignature)) })
for ((_, value) in transactionWithLocations) { for ((_, value) in transactionWithLocations) {
val wtx = value.transaction val wtx = value.transaction
val ltx = wtx.toLedgerTransaction(services) val ltx = wtx.toLedgerTransaction(services)
@ -296,7 +296,7 @@ data class TestLedgerDSLInterpreter private constructor(
throw DoubleSpentInputs(txIds) throw DoubleSpentInputs(txIds)
} }
usedInputs.addAll(wtx.inputs) usedInputs.addAll(wtx.inputs)
services.recordTransactions(SignedTransaction(wtx.serialized, listOf(NullSignature))) services.recordTransactions(SignedTransaction(wtx, listOf(NullSignature)))
} }
return EnforceVerifyOrFail.Token return EnforceVerifyOrFail.Token
} catch (exception: TransactionVerificationException) { } catch (exception: TransactionVerificationException) {
@ -330,8 +330,6 @@ data class TestLedgerDSLInterpreter private constructor(
*/ */
fun signAll(transactionsToSign: List<WireTransaction>, extraKeys: List<KeyPair>) = transactionsToSign.map { wtx -> fun signAll(transactionsToSign: List<WireTransaction>, extraKeys: List<KeyPair>) = transactionsToSign.map { wtx ->
check(wtx.mustSign.isNotEmpty()) check(wtx.mustSign.isNotEmpty())
val bits = wtx.serialize()
require(bits == wtx.serialized)
val signatures = ArrayList<DigitalSignature.WithKey>() val signatures = ArrayList<DigitalSignature.WithKey>()
val keyLookup = HashMap<PublicKey, KeyPair>() val keyLookup = HashMap<PublicKey, KeyPair>()
@ -342,7 +340,7 @@ fun signAll(transactionsToSign: List<WireTransaction>, extraKeys: List<KeyPair>)
val key = keyLookup[it] ?: throw IllegalArgumentException("Missing required key for ${it.toStringShort()}") val key = keyLookup[it] ?: throw IllegalArgumentException("Missing required key for ${it.toStringShort()}")
signatures += key.sign(wtx.id) signatures += key.sign(wtx.id)
} }
SignedTransaction(bits, signatures) SignedTransaction(wtx, signatures)
} }
/** /**