Refactor transaction serialization caching (#1078)

* Cache deserialized rather than serialized WireTransaction. Prevent repeated deserialization when adding signatures to the SignedTransaction.

* Added a test to check that stx copying and signature collection still works properly after (de)serialization
This commit is contained in:
Andrius Dagys 2017-07-24 14:48:39 +01:00 committed by GitHub
parent 3e199e51fc
commit d2eb5507f9
14 changed files with 71 additions and 52 deletions

View File

@ -60,7 +60,7 @@ object DefaultKryoCustomizer {
instantiatorStrategy = CustomInstantiatorStrategy()
register(Arrays.asList("").javaClass, ArraysAsListSerializer())
register(SignedTransaction::class.java, ImmutableClassSerializer(SignedTransaction::class))
register(SignedTransaction::class.java, SignedTransactionSerializer)
register(WireTransaction::class.java, WireTransactionSerializer)
register(SerializedBytes::class.java, SerializedBytesSerializer)

View File

@ -6,12 +6,10 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.MapReferenceResolver
import com.google.common.annotations.VisibleForTesting
import net.corda.core.contracts.*
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.MetaData
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureType
import net.corda.core.crypto.*
import net.corda.core.crypto.composite.CompositeKey
import net.corda.core.identity.Party
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction
import net.i2p.crypto.eddsa.EdDSAPrivateKey
import net.i2p.crypto.eddsa.EdDSAPublicKey
@ -277,11 +275,27 @@ object WireTransactionSerializer : Serializer<WireTransaction>() {
}
}
@ThreadSafe
object SignedTransactionSerializer : Serializer<SignedTransaction>() {
override fun write(kryo: Kryo, output: Output, obj: SignedTransaction) {
kryo.writeClassAndObject(output, obj.txBits)
kryo.writeClassAndObject(output, obj.sigs)
}
@Suppress("UNCHECKED_CAST")
override fun read(kryo: Kryo, input: Input, type: Class<SignedTransaction>): SignedTransaction {
return SignedTransaction(
kryo.readClassAndObject(input) as SerializedBytes<WireTransaction>,
kryo.readClassAndObject(input) as List<DigitalSignature.WithKey>
)
}
}
/** For serialising an ed25519 private key */
@ThreadSafe
object Ed25519PrivateKeySerializer : Serializer<EdDSAPrivateKey>() {
override fun write(kryo: Kryo, output: Output, obj: EdDSAPrivateKey) {
check(obj.params == Crypto.EDDSA_ED25519_SHA512.algSpec )
check(obj.params == Crypto.EDDSA_ED25519_SHA512.algSpec)
output.writeBytesWithLength(obj.seed)
}

View File

@ -135,7 +135,3 @@ class SerializedBytes<T : Any>(bytes: ByteArray) : OpaqueBytes(bytes) {
// It's OK to use lazy here because SerializedBytes is configured to use the ImmutableClassSerializer.
val hash: SecureHash by lazy { bytes.sha256() }
}
// The more specific deserialize version results in the bytes being cached, which is faster.
@JvmName("SerializedBytesWireTransaction")
fun SerializedBytes<WireTransaction>.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): WireTransaction = WireTransaction.deserialize(this, serializationFactory, context)

View File

@ -10,6 +10,8 @@ import net.corda.core.crypto.isFulfilledBy
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.toNonEmptySet
import java.security.PublicKey
@ -34,20 +36,21 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
val sigs: List<DigitalSignature.WithKey>
) : NamedByHash {
// DOCEND 1
constructor(wtx: WireTransaction, sigs: List<DigitalSignature.WithKey>) : this(wtx.serialize(), sigs) {
cachedTransaction = wtx
}
init {
require(sigs.isNotEmpty())
}
// TODO: This needs to be reworked to ensure that the inner WireTransaction is only ever deserialised sandboxed.
/** Cache the deserialized form of the transaction. This is useful when building a transaction or collecting signatures. */
@Volatile @Transient private var cachedTransaction: WireTransaction? = null
/** Lazily calculated access to the deserialised/hashed transaction data. */
val tx: WireTransaction by lazy { WireTransaction.deserialize(txBits) }
val tx: WireTransaction get() = cachedTransaction ?: txBits.deserialize().apply { cachedTransaction = this }
/**
* The Merkle root of the inner [WireTransaction]. Note that this is _not_ the same as the simple hash of
* [txBits], which would not use the Merkle tree structure. If the difference isn't clear, please consult
* the user guide section "Transaction tear-offs" to learn more about Merkle trees.
*/
/** The id of the contained [WireTransaction]. */
override val id: SecureHash get() = tx.id
@CordaSerializable
@ -87,7 +90,6 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
val needed = getMissingSignatures() - allowedToBeMissing
if (needed.isNotEmpty())
throw SignaturesMissingException(needed.toNonEmptySet(), getMissingKeyDescriptions(needed), id)
check(tx.id == id)
return tx
}
@ -131,10 +133,21 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
}
/** Returns the same transaction but with an additional (unchecked) signature. */
fun withAdditionalSignature(sig: DigitalSignature.WithKey) = copy(sigs = sigs + sig)
fun withAdditionalSignature(sig: DigitalSignature.WithKey) = copyWithCache(listOf(sig))
/** Returns the same transaction but with an additional (unchecked) signatures. */
fun withAdditionalSignatures(sigList: Iterable<DigitalSignature.WithKey>) = copy(sigs = sigs + sigList)
fun withAdditionalSignatures(sigList: Iterable<DigitalSignature.WithKey>) = copyWithCache(sigList)
/**
* Creates a copy of the SignedTransaction that includes the provided [sigList]. Also propagates the [cachedTransaction]
* so the contained transaction does not need to be deserialized again.
*/
private fun copyWithCache(sigList: Iterable<DigitalSignature.WithKey>): SignedTransaction {
val cached = cachedTransaction
return copy(sigs = sigs + sigList).apply {
cachedTransaction = cached
}
}
/** Alias for [withAdditionalSignature] to let you use Kotlin operator overloading. */
operator fun plus(sig: DigitalSignature.WithKey) = withAdditionalSignature(sig)

View File

@ -178,7 +178,7 @@ open class TransactionBuilder(
throw IllegalStateException("Missing signatures on the transaction for the public keys: ${missing.joinToString()}")
}
val wtx = toWireTransaction()
return SignedTransaction(wtx.serialize(), ArrayList(currentSigs))
return SignedTransaction(wtx, ArrayList(currentSigs))
}
/**

View File

@ -8,18 +8,14 @@ import net.corda.core.crypto.keys
import net.corda.core.identity.Party
import net.corda.core.internal.Emoji
import net.corda.core.node.ServicesForResolution
import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
import java.security.PublicKey
import java.security.SignatureException
import java.util.function.Predicate
/**
* A transaction ready for serialisation, without any signatures attached. A WireTransaction is usually wrapped
* by a [SignedTransaction] that carries the signatures over this payload. The hash of the wire transaction is
* the identity of the transaction, that is, it's possible for two [SignedTransaction]s with different sets of
* signatures to have the same identity hash.
* by a [SignedTransaction] that carries the signatures over this payload.
* The identity of the transaction is the Merkle tree root of its components (see [MerkleTree]).
*/
class WireTransaction(
/** Pointers to the input states on the ledger, identified by (tx identity hash, output index). */
@ -38,20 +34,9 @@ class WireTransaction(
checkInvariants()
}
// Cache the serialised form of the transaction and its hash to give us fast access to it.
@Volatile @Transient private var cachedBytes: SerializedBytes<WireTransaction>? = null
val serialized: SerializedBytes<WireTransaction> get() = cachedBytes ?: serialize().apply { cachedBytes = this }
/** The transaction id is represented by the root hash of Merkle tree over the transaction components. */
override val id: SecureHash by lazy { merkleTree.hash }
companion object {
fun deserialize(data: SerializedBytes<WireTransaction>, serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): WireTransaction {
val wtx = data.deserialize<WireTransaction>(serializationFactory, context)
wtx.cachedBytes = data
return wtx
}
}
/**
* Looks up identities and attachments from storage to generate a [LedgerTransaction]. A transaction is expected to
* have been fully resolved using the resolution flow by this point.

View File

@ -19,8 +19,7 @@ import kotlin.test.assertFailsWith
class TransactionTests : TestDependencyInjectionBase() {
private fun makeSigned(wtx: WireTransaction, vararg keys: KeyPair): SignedTransaction {
val bytes: SerializedBytes<WireTransaction> = wtx.serialized
return SignedTransaction(bytes, keys.map { it.sign(wtx.id.bytes) })
return SignedTransaction(wtx, keys.map { it.sign(wtx.id.bytes) })
}
@Test

View File

@ -6,6 +6,7 @@ import net.corda.contracts.asset.Cash
import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash.Companion.zeroHash
import net.corda.core.identity.Party
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.transactions.WireTransaction
import net.corda.testing.*
@ -110,7 +111,7 @@ class PartialMerkleTreeTest : TestDependencyInjectionBase() {
val mt = testTx.buildFilteredTransaction(Predicate(::filtering))
val leaves = mt.filteredLeaves
val d = WireTransaction.deserialize(testTx.serialized)
val d = testTx.serialize().deserialize()
assertEquals(testTx.id, d.id)
assertEquals(1, leaves.commands.size)
assertEquals(1, leaves.outputs.size)

View File

@ -104,4 +104,18 @@ class TransactionSerializationTests : TestDependencyInjectionBase() {
val stx = notaryServices.addSignature(ptx)
assertEquals(TEST_TX_TIME, stx.tx.timeWindow?.midpoint)
}
@Test
fun storeAndLoadWhenSigning() {
val ptx = megaCorpServices.signInitialTransaction(tx)
ptx.verifySignaturesExcept(notaryServices.key.public)
val stored = ptx.serialize()
val loaded = stored.deserialize()
assertEquals(loaded, ptx)
val final = notaryServices.addSignature(loaded)
final.verifyRequiredSignatures()
}
}

View File

@ -84,7 +84,7 @@ class SignedTransactionGenerator : Generator<SignedTransaction>(SignedTransactio
override fun generate(random: SourceOfRandomness, status: GenerationStatus): SignedTransaction {
val wireTransaction = WiredTransactionGenerator().generate(random, status)
return SignedTransaction(
txBits = wireTransaction.serialized,
wtx = wireTransaction,
sigs = listOf(NullSignature)
)
}

View File

@ -605,7 +605,6 @@ class TwoPartyTradeFlowTests {
vararg extraSigningNodes: AbstractNode): Map<SecureHash, SignedTransaction> {
val signed = wtxToSign.map {
val bits = it.serialize()
val id = it.id
val sigs = mutableListOf<DigitalSignature.WithKey>()
sigs.add(node.services.keyManagementService.sign(id.bytes, node.services.legalIdentityKey))
@ -613,7 +612,7 @@ class TwoPartyTradeFlowTests {
extraSigningNodes.forEach { currentNode ->
sigs.add(currentNode.services.keyManagementService.sign(id.bytes, currentNode.info.legalIdentity.owningKey))
}
SignedTransaction(bits, sigs)
SignedTransaction(it, sigs)
}
return node.database.transaction {
node.services.recordTransactions(signed)

View File

@ -214,6 +214,6 @@ class RequeryConfigurationTest : TestDependencyInjectionBase() {
type = TransactionType.General,
timeWindow = null
)
return SignedTransaction(wtx.serialized, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1))))
return SignedTransaction(wtx, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1))))
}
}

View File

@ -152,6 +152,6 @@ class DBTransactionStorageTests : TestDependencyInjectionBase() {
type = TransactionType.General,
timeWindow = null
)
return SignedTransaction(wtx.serialized, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1))))
return SignedTransaction(wtx, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1))))
}
}

View File

@ -284,7 +284,7 @@ data class TestLedgerDSLInterpreter private constructor(
override fun verifies(): EnforceVerifyOrFail {
try {
val usedInputs = mutableSetOf<StateRef>()
services.recordTransactions(transactionsUnverified.map { SignedTransaction(it.serialized, listOf(NullSignature)) })
services.recordTransactions(transactionsUnverified.map { SignedTransaction(it, listOf(NullSignature)) })
for ((_, value) in transactionWithLocations) {
val wtx = value.transaction
val ltx = wtx.toLedgerTransaction(services)
@ -296,7 +296,7 @@ data class TestLedgerDSLInterpreter private constructor(
throw DoubleSpentInputs(txIds)
}
usedInputs.addAll(wtx.inputs)
services.recordTransactions(SignedTransaction(wtx.serialized, listOf(NullSignature)))
services.recordTransactions(SignedTransaction(wtx, listOf(NullSignature)))
}
return EnforceVerifyOrFail.Token
} catch (exception: TransactionVerificationException) {
@ -330,8 +330,6 @@ data class TestLedgerDSLInterpreter private constructor(
*/
fun signAll(transactionsToSign: List<WireTransaction>, extraKeys: List<KeyPair>) = transactionsToSign.map { wtx ->
check(wtx.mustSign.isNotEmpty())
val bits = wtx.serialize()
require(bits == wtx.serialized)
val signatures = ArrayList<DigitalSignature.WithKey>()
val keyLookup = HashMap<PublicKey, KeyPair>()
@ -342,7 +340,7 @@ fun signAll(transactionsToSign: List<WireTransaction>, extraKeys: List<KeyPair>)
val key = keyLookup[it] ?: throw IllegalArgumentException("Missing required key for ${it.toStringShort()}")
signatures += key.sign(wtx.id)
}
SignedTransaction(bits, signatures)
SignedTransaction(wtx, signatures)
}
/**