Testing: rework TransactionGroupDSL to work with WireTransactions instead of LedgerTransactions and simplify original hash/serialised bits tracking.

This commit is contained in:
Mike Hearn 2016-02-12 15:41:22 +01:00
parent fc000ec03c
commit ea18e239d9
4 changed files with 68 additions and 57 deletions

View File

@ -57,17 +57,30 @@ import java.util.*
data class WireTransaction(val inputs: List<StateRef>,
val outputs: List<ContractState>,
val commands: List<Command>) {
fun toLedgerTransaction(identityService: IdentityService, originalHash: SecureHash): LedgerTransaction {
// Cache the serialised form of the transaction and its hash to give us fast access to it.
@Volatile @Transient private var cachedBits: SerializedBytes<WireTransaction>? = null
val serialized: SerializedBytes<WireTransaction> get() = cachedBits ?: serialize().apply { cachedBits = this }
val id: SecureHash get() = serialized.hash
companion object {
fun deserialize(bits: SerializedBytes<WireTransaction>): WireTransaction {
val wtx = bits.deserialize()
wtx.cachedBits = bits
return wtx
}
}
fun toLedgerTransaction(identityService: IdentityService): LedgerTransaction {
val authenticatedArgs = commands.map {
val institutions = it.pubkeys.mapNotNull { pk -> identityService.partyFromKey(pk) }
AuthenticatedObject(it.pubkeys, institutions, it.data)
}
return LedgerTransaction(inputs, outputs, authenticatedArgs, originalHash)
return LedgerTransaction(inputs, outputs, authenticatedArgs, id)
}
/** Serialises and returns this transaction as a [SignedWireTransaction] with no signatures attached. */
fun toSignedTransaction(withSigs: List<DigitalSignature.WithKey> = emptyList()): SignedWireTransaction {
return SignedWireTransaction(serialize(), withSigs)
fun toSignedTransaction(withSigs: List<DigitalSignature.WithKey>): SignedWireTransaction {
return SignedWireTransaction(serialized, withSigs)
}
override fun toString(): String {
@ -85,7 +98,7 @@ data class SignedWireTransaction(val txBits: SerializedBytes<WireTransaction>, v
init { check(sigs.isNotEmpty()) }
/** Lazily calculated access to the deserialised/hashed transaction data. */
val tx: WireTransaction by lazy { txBits.deserialize() }
val tx: WireTransaction by lazy { WireTransaction.deserialize(txBits) }
/** A transaction ID is the hash of the [WireTransaction]. Thus adding or removing a signature does not change it. */
val id: SecureHash get() = txBits.hash
@ -124,7 +137,7 @@ data class SignedWireTransaction(val txBits: SerializedBytes<WireTransaction>, v
*/
fun verifyToLedgerTransaction(identityService: IdentityService): LedgerTransaction {
verify()
return tx.toLedgerTransaction(identityService, id)
return tx.toLedgerTransaction(identityService)
}
/** Returns the same transaction but with an additional (unchecked) signature */
@ -279,13 +292,6 @@ data class LedgerTransaction(
@Suppress("UNCHECKED_CAST")
fun <T : ContractState> outRef(index: Int) = StateAndRef(outputs[index] as T, StateRef(hash, index))
fun <T : ContractState> outRef(state: T): StateAndRef<T> {
val i = outputs.indexOf(state)
if (i == -1)
throw IllegalArgumentException("State not found in this transaction")
return outRef(i)
}
fun toWireTransaction(): WireTransaction {
val wtx = WireTransaction(inputs, outputs, commands.map { Command(it.value, it.signers) })
check(wtx.serialize().hash == hash)

View File

@ -74,8 +74,8 @@ class TransactionGroupTests {
val e = assertFailsWith(TransactionConflictException::class) {
verify()
}
assertEquals(StateRef(t.hash, 0), e.conflictRef)
assertEquals(setOf(conflict1, conflict2), setOf(e.tx1, e.tx2))
assertEquals(StateRef(t.id, 0), e.conflictRef)
assertEquals(setOf(conflict1, conflict2), setOf(e.tx1.toWireTransaction(), e.tx2.toWireTransaction()))
}
}
@ -97,9 +97,11 @@ class TransactionGroupTests {
// We have to do this manually without the DSL because transactionGroup { } won't let us create a tx that
// points nowhere.
val ref = StateRef(SecureHash.randomSHA256(), 0)
tg.txns.add(LedgerTransaction(
listOf(ref), listOf(A_THOUSAND_POUNDS), listOf(AuthenticatedObject(listOf(BOB), emptyList(), Cash.Commands.Move())), SecureHash.randomSHA256())
)
tg.txns += TransactionBuilder().apply {
addInputState(ref)
addOutputState(A_THOUSAND_POUNDS)
addCommand(Cash.Commands.Move(), BOB)
}.toWireTransaction()
val e = assertFailsWith(TransactionResolutionException::class) {
tg.verify()

View File

@ -88,7 +88,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
assertEquals(aliceResult.get(), bobResult.get())
txns.add(aliceResult.get().second)
txns.add(aliceResult.get().first)
verify()
}
}
@ -179,7 +179,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
// Bob is now finished and has the same transaction as Alice.
val tx = bobFuture.get()
txns.add(tx.second)
txns.add(tx.first)
verify()
assertTrue(smm.stateMachines.isEmpty())

View File

@ -12,10 +12,7 @@ package core.testutils
import contracts.*
import core.*
import core.crypto.DummyPublicKey
import core.crypto.NullPublicKey
import core.crypto.SecureHash
import core.crypto.generateKeyPair
import core.crypto.*
import core.serialization.serialize
import core.visualiser.GraphVisualiser
import java.security.PublicKey
@ -44,7 +41,7 @@ val BOB = BOB_KEY.public
val MEGA_CORP = Party("MegaCorp", MEGA_CORP_PUBKEY)
val MINI_CORP = Party("MiniCorp", MINI_CORP_PUBKEY)
val ALL_KEYS = listOf(MEGA_CORP_KEY, MINI_CORP_KEY, ALICE_KEY, BOB_KEY)
val ALL_TEST_KEYS = listOf(MEGA_CORP_KEY, MINI_CORP_KEY, ALICE_KEY, BOB_KEY)
val TEST_KEYS_TO_CORP_MAP: Map<PublicKey, Party> = mapOf(
MEGA_CORP_PUBKEY to MEGA_CORP,
@ -208,21 +205,14 @@ open class TransactionForTest : AbstractTransactionForTest() {
fun transaction(body: TransactionForTest.() -> Unit) = TransactionForTest().apply { body() }
class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
open inner class LedgerTransactionDSL : AbstractTransactionForTest() {
open inner class WireTransactionDSL : AbstractTransactionForTest() {
private val inStates = ArrayList<StateRef>()
fun input(label: String) {
inStates.add(label.outputRef)
}
/**
* Converts to a [LedgerTransaction] with the test institution map, and just assigns a random hash
* (i.e. pretend it was signed)
*/
fun toLedgerTransaction(): LedgerTransaction {
val wtx = WireTransaction(inStates, outStates.map { it.state }, commands)
return wtx.toLedgerTransaction(MockIdentityService, wtx.serialize().hash)
}
fun toWireTransaction() = WireTransaction(inStates, outStates.map { it.state }, commands)
}
val String.output: T get() = labelToOutputs[this] ?: throw IllegalArgumentException("State with label '$this' was not found")
@ -230,23 +220,23 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
fun <C : ContractState> lookup(label: String) = StateAndRef(label.output as C, label.outputRef)
private inner class InternalLedgerTransactionDSL : LedgerTransactionDSL() {
fun finaliseAndInsertLabels(): LedgerTransaction {
val ltx = toLedgerTransaction()
private inner class InternalWireTransactionDSL : WireTransactionDSL() {
fun finaliseAndInsertLabels(): WireTransaction {
val wtx = toWireTransaction()
for ((index, labelledState) in outStates.withIndex()) {
if (labelledState.label != null) {
labelToRefs[labelledState.label] = StateRef(ltx.hash, index)
labelToRefs[labelledState.label] = StateRef(wtx.id, index)
if (stateType.isInstance(labelledState.state)) {
labelToOutputs[labelledState.label] = labelledState.state as T
}
outputsToLabels[labelledState.state] = labelledState.label
}
}
return ltx
return wtx
}
}
private val rootTxns = ArrayList<LedgerTransaction>()
private val rootTxns = ArrayList<WireTransaction>()
private val labelToRefs = HashMap<String, StateRef>()
private val labelToOutputs = HashMap<String, T>()
private val outputsToLabels = HashMap<ContractState, String>()
@ -257,42 +247,45 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
fun transaction(vararg outputStates: LabeledOutput) {
val outs = outputStates.map { it.state }
val wtx = WireTransaction(emptyList(), outs, emptyList())
val ltx = wtx.toLedgerTransaction(MockIdentityService, SecureHash.randomSHA256())
for ((index, state) in outputStates.withIndex()) {
val label = state.label!!
labelToRefs[label] = StateRef(ltx.hash, index)
labelToRefs[label] = StateRef(wtx.id, index)
outputsToLabels[state.state] = label
labelToOutputs[label] = state.state as T
}
rootTxns.add(ltx)
rootTxns.add(wtx)
}
@Deprecated("Does not nest ", level = DeprecationLevel.ERROR)
fun roots(body: Roots.() -> Unit) {}
@Deprecated("Use the vararg form of transaction inside roots", level = DeprecationLevel.ERROR)
fun transaction(body: LedgerTransactionDSL.() -> Unit) {}
fun transaction(body: WireTransactionDSL.() -> Unit) {}
}
fun roots(body: Roots.() -> Unit) = Roots().apply { body() }
val txns = ArrayList<LedgerTransaction>()
private val txnToLabelMap = HashMap<LedgerTransaction, String>()
val txns = ArrayList<WireTransaction>()
private val txnToLabelMap = HashMap<SecureHash, String>()
fun transaction(label: String? = null, body: LedgerTransactionDSL.() -> Unit): LedgerTransaction {
val forTest = InternalLedgerTransactionDSL()
fun transaction(label: String? = null, body: WireTransactionDSL.() -> Unit): WireTransaction {
val forTest = InternalWireTransactionDSL()
forTest.body()
val ltx = forTest.finaliseAndInsertLabels()
txns.add(ltx)
val wtx = forTest.finaliseAndInsertLabels()
txns.add(wtx)
if (label != null)
txnToLabelMap[ltx] = label
return ltx
txnToLabelMap[wtx.id] = label
return wtx
}
fun labelForTransaction(ltx: LedgerTransaction): String? = txnToLabelMap[ltx]
fun labelForTransaction(tx: WireTransaction): String? = txnToLabelMap[tx.id]
fun labelForTransaction(tx: LedgerTransaction): String? = txnToLabelMap[tx.hash]
@Deprecated("Does not nest ", level = DeprecationLevel.ERROR)
fun transactionGroup(body: TransactionGroupDSL<T>.() -> Unit) {}
fun toTransactionGroup() = TransactionGroup(txns.map { it }.toSet(), rootTxns.toSet())
fun toTransactionGroup() = TransactionGroup(
txns.map { it.toLedgerTransaction(MockIdentityService) }.toSet(),
rootTxns.map { it.toLedgerTransaction(MockIdentityService) }.toSet()
)
class Failed(val index: Int, cause: Throwable) : Exception("Transaction $index didn't verify", cause)
@ -302,8 +295,8 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
group.verify(MockContractFactory)
} catch (e: TransactionVerificationException) {
// Let the developer know the index of the transaction that failed.
val ltx: LedgerTransaction = txns.find { it.hash == e.tx.origHash }!!
throw Failed(txns.indexOf(ltx) + 1, e)
val wtx: WireTransaction = txns.find { it.id == e.tx.origHash }!!
throw Failed(txns.indexOf(wtx) + 1, e)
}
}
@ -323,7 +316,17 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
}
fun signAll(): List<SignedWireTransaction> {
return txns.map { it.toSignedTransaction(andSignWithKeys = ALL_KEYS, allowUnusedKeys = true) }
return txns.map { wtx ->
val allPubKeys = wtx.commands.flatMap { it.pubkeys }.toSet()
val bits = wtx.serialize()
require(bits == wtx.serialized)
val sigs = ArrayList<DigitalSignature.WithKey>()
for (key in ALL_TEST_KEYS) {
if (allPubKeys.contains(key.public))
sigs += key.signWithECDSA(bits)
}
wtx.toSignedTransaction(sigs)
}
}
}