mirror of
https://github.com/corda/corda.git
synced 2025-04-12 21:53:17 +00:00
Testing: rework TransactionGroupDSL to work with WireTransactions instead of LedgerTransactions and simplify original hash/serialised bits tracking.
This commit is contained in:
parent
fc000ec03c
commit
ea18e239d9
@ -57,17 +57,30 @@ import java.util.*
|
||||
data class WireTransaction(val inputs: List<StateRef>,
|
||||
val outputs: List<ContractState>,
|
||||
val commands: List<Command>) {
|
||||
fun toLedgerTransaction(identityService: IdentityService, originalHash: SecureHash): LedgerTransaction {
|
||||
|
||||
// Cache the serialised form of the transaction and its hash to give us fast access to it.
|
||||
@Volatile @Transient private var cachedBits: SerializedBytes<WireTransaction>? = null
|
||||
val serialized: SerializedBytes<WireTransaction> get() = cachedBits ?: serialize().apply { cachedBits = this }
|
||||
val id: SecureHash get() = serialized.hash
|
||||
companion object {
|
||||
fun deserialize(bits: SerializedBytes<WireTransaction>): WireTransaction {
|
||||
val wtx = bits.deserialize()
|
||||
wtx.cachedBits = bits
|
||||
return wtx
|
||||
}
|
||||
}
|
||||
|
||||
fun toLedgerTransaction(identityService: IdentityService): LedgerTransaction {
|
||||
val authenticatedArgs = commands.map {
|
||||
val institutions = it.pubkeys.mapNotNull { pk -> identityService.partyFromKey(pk) }
|
||||
AuthenticatedObject(it.pubkeys, institutions, it.data)
|
||||
}
|
||||
return LedgerTransaction(inputs, outputs, authenticatedArgs, originalHash)
|
||||
return LedgerTransaction(inputs, outputs, authenticatedArgs, id)
|
||||
}
|
||||
|
||||
/** Serialises and returns this transaction as a [SignedWireTransaction] with no signatures attached. */
|
||||
fun toSignedTransaction(withSigs: List<DigitalSignature.WithKey> = emptyList()): SignedWireTransaction {
|
||||
return SignedWireTransaction(serialize(), withSigs)
|
||||
fun toSignedTransaction(withSigs: List<DigitalSignature.WithKey>): SignedWireTransaction {
|
||||
return SignedWireTransaction(serialized, withSigs)
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
@ -85,7 +98,7 @@ data class SignedWireTransaction(val txBits: SerializedBytes<WireTransaction>, v
|
||||
init { check(sigs.isNotEmpty()) }
|
||||
|
||||
/** Lazily calculated access to the deserialised/hashed transaction data. */
|
||||
val tx: WireTransaction by lazy { txBits.deserialize() }
|
||||
val tx: WireTransaction by lazy { WireTransaction.deserialize(txBits) }
|
||||
|
||||
/** A transaction ID is the hash of the [WireTransaction]. Thus adding or removing a signature does not change it. */
|
||||
val id: SecureHash get() = txBits.hash
|
||||
@ -124,7 +137,7 @@ data class SignedWireTransaction(val txBits: SerializedBytes<WireTransaction>, v
|
||||
*/
|
||||
fun verifyToLedgerTransaction(identityService: IdentityService): LedgerTransaction {
|
||||
verify()
|
||||
return tx.toLedgerTransaction(identityService, id)
|
||||
return tx.toLedgerTransaction(identityService)
|
||||
}
|
||||
|
||||
/** Returns the same transaction but with an additional (unchecked) signature */
|
||||
@ -279,13 +292,6 @@ data class LedgerTransaction(
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun <T : ContractState> outRef(index: Int) = StateAndRef(outputs[index] as T, StateRef(hash, index))
|
||||
|
||||
fun <T : ContractState> outRef(state: T): StateAndRef<T> {
|
||||
val i = outputs.indexOf(state)
|
||||
if (i == -1)
|
||||
throw IllegalArgumentException("State not found in this transaction")
|
||||
return outRef(i)
|
||||
}
|
||||
|
||||
fun toWireTransaction(): WireTransaction {
|
||||
val wtx = WireTransaction(inputs, outputs, commands.map { Command(it.value, it.signers) })
|
||||
check(wtx.serialize().hash == hash)
|
||||
|
@ -74,8 +74,8 @@ class TransactionGroupTests {
|
||||
val e = assertFailsWith(TransactionConflictException::class) {
|
||||
verify()
|
||||
}
|
||||
assertEquals(StateRef(t.hash, 0), e.conflictRef)
|
||||
assertEquals(setOf(conflict1, conflict2), setOf(e.tx1, e.tx2))
|
||||
assertEquals(StateRef(t.id, 0), e.conflictRef)
|
||||
assertEquals(setOf(conflict1, conflict2), setOf(e.tx1.toWireTransaction(), e.tx2.toWireTransaction()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,9 +97,11 @@ class TransactionGroupTests {
|
||||
// We have to do this manually without the DSL because transactionGroup { } won't let us create a tx that
|
||||
// points nowhere.
|
||||
val ref = StateRef(SecureHash.randomSHA256(), 0)
|
||||
tg.txns.add(LedgerTransaction(
|
||||
listOf(ref), listOf(A_THOUSAND_POUNDS), listOf(AuthenticatedObject(listOf(BOB), emptyList(), Cash.Commands.Move())), SecureHash.randomSHA256())
|
||||
)
|
||||
tg.txns += TransactionBuilder().apply {
|
||||
addInputState(ref)
|
||||
addOutputState(A_THOUSAND_POUNDS)
|
||||
addCommand(Cash.Commands.Move(), BOB)
|
||||
}.toWireTransaction()
|
||||
|
||||
val e = assertFailsWith(TransactionResolutionException::class) {
|
||||
tg.verify()
|
||||
|
@ -88,7 +88,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
|
||||
|
||||
assertEquals(aliceResult.get(), bobResult.get())
|
||||
|
||||
txns.add(aliceResult.get().second)
|
||||
txns.add(aliceResult.get().first)
|
||||
verify()
|
||||
}
|
||||
}
|
||||
@ -179,7 +179,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
|
||||
|
||||
// Bob is now finished and has the same transaction as Alice.
|
||||
val tx = bobFuture.get()
|
||||
txns.add(tx.second)
|
||||
txns.add(tx.first)
|
||||
verify()
|
||||
|
||||
assertTrue(smm.stateMachines.isEmpty())
|
||||
|
@ -12,10 +12,7 @@ package core.testutils
|
||||
|
||||
import contracts.*
|
||||
import core.*
|
||||
import core.crypto.DummyPublicKey
|
||||
import core.crypto.NullPublicKey
|
||||
import core.crypto.SecureHash
|
||||
import core.crypto.generateKeyPair
|
||||
import core.crypto.*
|
||||
import core.serialization.serialize
|
||||
import core.visualiser.GraphVisualiser
|
||||
import java.security.PublicKey
|
||||
@ -44,7 +41,7 @@ val BOB = BOB_KEY.public
|
||||
val MEGA_CORP = Party("MegaCorp", MEGA_CORP_PUBKEY)
|
||||
val MINI_CORP = Party("MiniCorp", MINI_CORP_PUBKEY)
|
||||
|
||||
val ALL_KEYS = listOf(MEGA_CORP_KEY, MINI_CORP_KEY, ALICE_KEY, BOB_KEY)
|
||||
val ALL_TEST_KEYS = listOf(MEGA_CORP_KEY, MINI_CORP_KEY, ALICE_KEY, BOB_KEY)
|
||||
|
||||
val TEST_KEYS_TO_CORP_MAP: Map<PublicKey, Party> = mapOf(
|
||||
MEGA_CORP_PUBKEY to MEGA_CORP,
|
||||
@ -208,21 +205,14 @@ open class TransactionForTest : AbstractTransactionForTest() {
|
||||
fun transaction(body: TransactionForTest.() -> Unit) = TransactionForTest().apply { body() }
|
||||
|
||||
class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
|
||||
open inner class LedgerTransactionDSL : AbstractTransactionForTest() {
|
||||
open inner class WireTransactionDSL : AbstractTransactionForTest() {
|
||||
private val inStates = ArrayList<StateRef>()
|
||||
|
||||
fun input(label: String) {
|
||||
inStates.add(label.outputRef)
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts to a [LedgerTransaction] with the test institution map, and just assigns a random hash
|
||||
* (i.e. pretend it was signed)
|
||||
*/
|
||||
fun toLedgerTransaction(): LedgerTransaction {
|
||||
val wtx = WireTransaction(inStates, outStates.map { it.state }, commands)
|
||||
return wtx.toLedgerTransaction(MockIdentityService, wtx.serialize().hash)
|
||||
}
|
||||
fun toWireTransaction() = WireTransaction(inStates, outStates.map { it.state }, commands)
|
||||
}
|
||||
|
||||
val String.output: T get() = labelToOutputs[this] ?: throw IllegalArgumentException("State with label '$this' was not found")
|
||||
@ -230,23 +220,23 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
|
||||
|
||||
fun <C : ContractState> lookup(label: String) = StateAndRef(label.output as C, label.outputRef)
|
||||
|
||||
private inner class InternalLedgerTransactionDSL : LedgerTransactionDSL() {
|
||||
fun finaliseAndInsertLabels(): LedgerTransaction {
|
||||
val ltx = toLedgerTransaction()
|
||||
private inner class InternalWireTransactionDSL : WireTransactionDSL() {
|
||||
fun finaliseAndInsertLabels(): WireTransaction {
|
||||
val wtx = toWireTransaction()
|
||||
for ((index, labelledState) in outStates.withIndex()) {
|
||||
if (labelledState.label != null) {
|
||||
labelToRefs[labelledState.label] = StateRef(ltx.hash, index)
|
||||
labelToRefs[labelledState.label] = StateRef(wtx.id, index)
|
||||
if (stateType.isInstance(labelledState.state)) {
|
||||
labelToOutputs[labelledState.label] = labelledState.state as T
|
||||
}
|
||||
outputsToLabels[labelledState.state] = labelledState.label
|
||||
}
|
||||
}
|
||||
return ltx
|
||||
return wtx
|
||||
}
|
||||
}
|
||||
|
||||
private val rootTxns = ArrayList<LedgerTransaction>()
|
||||
private val rootTxns = ArrayList<WireTransaction>()
|
||||
private val labelToRefs = HashMap<String, StateRef>()
|
||||
private val labelToOutputs = HashMap<String, T>()
|
||||
private val outputsToLabels = HashMap<ContractState, String>()
|
||||
@ -257,42 +247,45 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
|
||||
fun transaction(vararg outputStates: LabeledOutput) {
|
||||
val outs = outputStates.map { it.state }
|
||||
val wtx = WireTransaction(emptyList(), outs, emptyList())
|
||||
val ltx = wtx.toLedgerTransaction(MockIdentityService, SecureHash.randomSHA256())
|
||||
for ((index, state) in outputStates.withIndex()) {
|
||||
val label = state.label!!
|
||||
labelToRefs[label] = StateRef(ltx.hash, index)
|
||||
labelToRefs[label] = StateRef(wtx.id, index)
|
||||
outputsToLabels[state.state] = label
|
||||
labelToOutputs[label] = state.state as T
|
||||
}
|
||||
rootTxns.add(ltx)
|
||||
rootTxns.add(wtx)
|
||||
}
|
||||
|
||||
@Deprecated("Does not nest ", level = DeprecationLevel.ERROR)
|
||||
fun roots(body: Roots.() -> Unit) {}
|
||||
@Deprecated("Use the vararg form of transaction inside roots", level = DeprecationLevel.ERROR)
|
||||
fun transaction(body: LedgerTransactionDSL.() -> Unit) {}
|
||||
fun transaction(body: WireTransactionDSL.() -> Unit) {}
|
||||
}
|
||||
fun roots(body: Roots.() -> Unit) = Roots().apply { body() }
|
||||
|
||||
val txns = ArrayList<LedgerTransaction>()
|
||||
private val txnToLabelMap = HashMap<LedgerTransaction, String>()
|
||||
val txns = ArrayList<WireTransaction>()
|
||||
private val txnToLabelMap = HashMap<SecureHash, String>()
|
||||
|
||||
fun transaction(label: String? = null, body: LedgerTransactionDSL.() -> Unit): LedgerTransaction {
|
||||
val forTest = InternalLedgerTransactionDSL()
|
||||
fun transaction(label: String? = null, body: WireTransactionDSL.() -> Unit): WireTransaction {
|
||||
val forTest = InternalWireTransactionDSL()
|
||||
forTest.body()
|
||||
val ltx = forTest.finaliseAndInsertLabels()
|
||||
txns.add(ltx)
|
||||
val wtx = forTest.finaliseAndInsertLabels()
|
||||
txns.add(wtx)
|
||||
if (label != null)
|
||||
txnToLabelMap[ltx] = label
|
||||
return ltx
|
||||
txnToLabelMap[wtx.id] = label
|
||||
return wtx
|
||||
}
|
||||
|
||||
fun labelForTransaction(ltx: LedgerTransaction): String? = txnToLabelMap[ltx]
|
||||
fun labelForTransaction(tx: WireTransaction): String? = txnToLabelMap[tx.id]
|
||||
fun labelForTransaction(tx: LedgerTransaction): String? = txnToLabelMap[tx.hash]
|
||||
|
||||
@Deprecated("Does not nest ", level = DeprecationLevel.ERROR)
|
||||
fun transactionGroup(body: TransactionGroupDSL<T>.() -> Unit) {}
|
||||
|
||||
fun toTransactionGroup() = TransactionGroup(txns.map { it }.toSet(), rootTxns.toSet())
|
||||
fun toTransactionGroup() = TransactionGroup(
|
||||
txns.map { it.toLedgerTransaction(MockIdentityService) }.toSet(),
|
||||
rootTxns.map { it.toLedgerTransaction(MockIdentityService) }.toSet()
|
||||
)
|
||||
|
||||
class Failed(val index: Int, cause: Throwable) : Exception("Transaction $index didn't verify", cause)
|
||||
|
||||
@ -302,8 +295,8 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
|
||||
group.verify(MockContractFactory)
|
||||
} catch (e: TransactionVerificationException) {
|
||||
// Let the developer know the index of the transaction that failed.
|
||||
val ltx: LedgerTransaction = txns.find { it.hash == e.tx.origHash }!!
|
||||
throw Failed(txns.indexOf(ltx) + 1, e)
|
||||
val wtx: WireTransaction = txns.find { it.id == e.tx.origHash }!!
|
||||
throw Failed(txns.indexOf(wtx) + 1, e)
|
||||
}
|
||||
}
|
||||
|
||||
@ -323,7 +316,17 @@ class TransactionGroupDSL<T : ContractState>(private val stateType: Class<T>) {
|
||||
}
|
||||
|
||||
fun signAll(): List<SignedWireTransaction> {
|
||||
return txns.map { it.toSignedTransaction(andSignWithKeys = ALL_KEYS, allowUnusedKeys = true) }
|
||||
return txns.map { wtx ->
|
||||
val allPubKeys = wtx.commands.flatMap { it.pubkeys }.toSet()
|
||||
val bits = wtx.serialize()
|
||||
require(bits == wtx.serialized)
|
||||
val sigs = ArrayList<DigitalSignature.WithKey>()
|
||||
for (key in ALL_TEST_KEYS) {
|
||||
if (allPubKeys.contains(key.public))
|
||||
sigs += key.signWithECDSA(bits)
|
||||
}
|
||||
wtx.toSignedTransaction(sigs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user