mirror of
https://github.com/corda/corda.git
synced 2025-06-02 23:50:54 +00:00
Observable toposort for transactions (#3027)
This commit is contained in:
parent
468c0c7404
commit
e2b4943bbb
@ -47,31 +47,11 @@ class ResolveTransactionsFlow(txHashesArg: Set<SecureHash>,
|
|||||||
* Topologically sorts the given transactions such that dependencies are listed before dependers. */
|
* Topologically sorts the given transactions such that dependencies are listed before dependers. */
|
||||||
@JvmStatic
|
@JvmStatic
|
||||||
fun topologicalSort(transactions: Collection<SignedTransaction>): List<SignedTransaction> {
|
fun topologicalSort(transactions: Collection<SignedTransaction>): List<SignedTransaction> {
|
||||||
// Construct txhash -> dependent-txs map
|
val sort = TopologicalSort()
|
||||||
val forwardGraph = HashMap<SecureHash, HashSet<SignedTransaction>>()
|
for (tx in transactions) {
|
||||||
transactions.forEach { stx ->
|
sort.add(tx)
|
||||||
stx.inputs.forEach { (txhash) ->
|
|
||||||
// Note that we use a LinkedHashSet here to make the traversal deterministic (as long as the input list is)
|
|
||||||
forwardGraph.getOrPut(txhash) { LinkedHashSet() }.add(stx)
|
|
||||||
}
|
}
|
||||||
}
|
return sort.complete()
|
||||||
|
|
||||||
val visited = HashSet<SecureHash>(transactions.size)
|
|
||||||
val result = ArrayList<SignedTransaction>(transactions.size)
|
|
||||||
|
|
||||||
fun visit(transaction: SignedTransaction) {
|
|
||||||
if (transaction.id !in visited) {
|
|
||||||
visited.add(transaction.id)
|
|
||||||
forwardGraph[transaction.id]?.forEach(::visit)
|
|
||||||
result.add(transaction)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
transactions.forEach(::visit)
|
|
||||||
|
|
||||||
result.reverse()
|
|
||||||
require(result.size == transactions.size)
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
117
core/src/main/kotlin/net/corda/core/internal/TopologicalSort.kt
Normal file
117
core/src/main/kotlin/net/corda/core/internal/TopologicalSort.kt
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
package net.corda.core.internal
|
||||||
|
|
||||||
|
import net.corda.core.contracts.StateRef
|
||||||
|
import net.corda.core.crypto.SecureHash
|
||||||
|
import net.corda.core.transactions.SignedTransaction
|
||||||
|
import rx.Observable
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provides a way to topologically sort SignedTransactions. This means that given any two transactions T1 and T2 in the
|
||||||
|
* list returned by [complete] if T1 is a dependency of T2 then T1 will occur earlier than T2.
|
||||||
|
*/
|
||||||
|
class TopologicalSort {
|
||||||
|
private val forwardGraph = HashMap<SecureHash, LinkedHashSet<SignedTransaction>>()
|
||||||
|
private val transactions = ArrayList<SignedTransaction>()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a transaction to the to-be-sorted set of transactions.
|
||||||
|
*/
|
||||||
|
fun add(stx: SignedTransaction) {
|
||||||
|
for (input in stx.inputs) {
|
||||||
|
// Note that we use a LinkedHashSet here to make the traversal deterministic (as long as the input list is)
|
||||||
|
forwardGraph.getOrPut(input.txhash) { LinkedHashSet() }.add(stx)
|
||||||
|
}
|
||||||
|
transactions.add(stx)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the sorted list of signed transactions.
|
||||||
|
*/
|
||||||
|
fun complete(): List<SignedTransaction> {
|
||||||
|
val visited = HashSet<SecureHash>(transactions.size)
|
||||||
|
val result = ArrayList<SignedTransaction>(transactions.size)
|
||||||
|
|
||||||
|
fun visit(transaction: SignedTransaction) {
|
||||||
|
if (transaction.id !in visited) {
|
||||||
|
visited.add(transaction.id)
|
||||||
|
forwardGraph[transaction.id]?.forEach(::visit)
|
||||||
|
result.add(transaction)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
transactions.forEach(::visit)
|
||||||
|
return result.reversed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun getOutputStateRefs(stx: SignedTransaction): List<StateRef> {
|
||||||
|
return stx.coreTransaction.outputs.mapIndexed { i, _ -> StateRef(stx.id, i) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Topologically sort a SignedTransaction Observable on the fly by buffering transactions until all dependencies are met.
|
||||||
|
* @param initialUnspentRefs the list of unspent references that may be spent by transactions in the observable. This is
|
||||||
|
* the initial set of references the sort uses to decide whether to buffer transactions or not. For example if this
|
||||||
|
* is empty then the Observable should start with issue transactions that don't have inputs.
|
||||||
|
*/
|
||||||
|
fun Observable<SignedTransaction>.topologicalSort(initialUnspentRefs: Collection<StateRef>): Observable<SignedTransaction> {
|
||||||
|
data class State(
|
||||||
|
val unspentRefs: HashSet<StateRef>,
|
||||||
|
val bufferedTopologicalSort: TopologicalSort,
|
||||||
|
val bufferedInputs: HashSet<StateRef>,
|
||||||
|
val bufferedOutputs: HashSet<StateRef>
|
||||||
|
)
|
||||||
|
|
||||||
|
var state = State(
|
||||||
|
unspentRefs = HashSet(initialUnspentRefs),
|
||||||
|
bufferedTopologicalSort = TopologicalSort(),
|
||||||
|
bufferedInputs = HashSet(),
|
||||||
|
bufferedOutputs = HashSet()
|
||||||
|
)
|
||||||
|
|
||||||
|
return concatMapIterable { stx ->
|
||||||
|
val results = ArrayList<SignedTransaction>()
|
||||||
|
if (state.unspentRefs.containsAll(stx.inputs)) {
|
||||||
|
// Dependencies are satisfied
|
||||||
|
state.unspentRefs.removeAll(stx.inputs)
|
||||||
|
state.unspentRefs.addAll(getOutputStateRefs(stx))
|
||||||
|
results.add(stx)
|
||||||
|
} else {
|
||||||
|
// Dependencies are not satisfied, buffer
|
||||||
|
state.bufferedTopologicalSort.add(stx)
|
||||||
|
state.bufferedInputs.addAll(stx.inputs)
|
||||||
|
for (outputRef in getOutputStateRefs(stx)) {
|
||||||
|
if (!state.bufferedInputs.remove(outputRef)) {
|
||||||
|
state.bufferedOutputs.add(outputRef)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (inputRef in stx.inputs) {
|
||||||
|
if (!state.bufferedOutputs.remove(inputRef)) {
|
||||||
|
state.bufferedInputs.add(inputRef)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (state.unspentRefs.containsAll(state.bufferedInputs)) {
|
||||||
|
// Buffer satisfied
|
||||||
|
results.addAll(state.bufferedTopologicalSort.complete())
|
||||||
|
state.unspentRefs.removeAll(state.bufferedInputs)
|
||||||
|
state.unspentRefs.addAll(state.bufferedOutputs)
|
||||||
|
state = State(
|
||||||
|
unspentRefs = state.unspentRefs,
|
||||||
|
bufferedTopologicalSort = TopologicalSort(),
|
||||||
|
bufferedInputs = HashSet(),
|
||||||
|
bufferedOutputs = HashSet()
|
||||||
|
)
|
||||||
|
results
|
||||||
|
} else {
|
||||||
|
// Buffer not satisfied
|
||||||
|
state = State(
|
||||||
|
unspentRefs = state.unspentRefs,
|
||||||
|
bufferedTopologicalSort = state.bufferedTopologicalSort,
|
||||||
|
bufferedInputs = state.bufferedInputs,
|
||||||
|
bufferedOutputs = state.bufferedOutputs
|
||||||
|
)
|
||||||
|
results
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,106 @@
|
|||||||
|
package net.corda.core.internal
|
||||||
|
|
||||||
|
import net.corda.client.mock.Generator
|
||||||
|
import net.corda.core.contracts.ContractState
|
||||||
|
import net.corda.core.contracts.StateRef
|
||||||
|
import net.corda.core.contracts.TransactionState
|
||||||
|
import net.corda.core.crypto.SecureHash
|
||||||
|
import net.corda.core.crypto.SignatureMetadata
|
||||||
|
import net.corda.core.crypto.TransactionSignature
|
||||||
|
import net.corda.core.crypto.sign
|
||||||
|
import net.corda.core.identity.AbstractParty
|
||||||
|
import net.corda.core.identity.Party
|
||||||
|
import net.corda.core.serialization.serialize
|
||||||
|
import net.corda.core.transactions.CoreTransaction
|
||||||
|
import net.corda.core.transactions.SignedTransaction
|
||||||
|
import net.corda.testing.core.SerializationEnvironmentRule
|
||||||
|
import net.corda.testing.core.TestIdentity
|
||||||
|
import org.junit.Rule
|
||||||
|
import org.junit.Test
|
||||||
|
import rx.Observable
|
||||||
|
import java.util.*
|
||||||
|
|
||||||
|
class TopologicalSortTest {
|
||||||
|
class DummyTransaction(
|
||||||
|
override val id: SecureHash,
|
||||||
|
override val inputs: List<StateRef>,
|
||||||
|
val numberOfOutputs: Int,
|
||||||
|
override val notary: Party
|
||||||
|
) : CoreTransaction() {
|
||||||
|
override val outputs: List<TransactionState<ContractState>> = (1..numberOfOutputs).map {
|
||||||
|
TransactionState(DummyState(), "", notary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class DummyState : ContractState {
|
||||||
|
override val participants: List<AbstractParty> = emptyList()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
@JvmField
|
||||||
|
val testSerialization = SerializationEnvironmentRule()
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun topologicalObservableSort() {
|
||||||
|
val testIdentity = TestIdentity.fresh("asd")
|
||||||
|
|
||||||
|
val N = 10
|
||||||
|
// generate random tx DAG
|
||||||
|
val ids = (1..N).map { SecureHash.sha256("$it") }
|
||||||
|
val forwardsGenerators = (0 until ids.size).map { i ->
|
||||||
|
Generator.sampleBernoulli(ids.subList(i + 1, ids.size), 0.8).map { outputs -> ids[i] to outputs }
|
||||||
|
}
|
||||||
|
val transactions = Generator.sequence(forwardsGenerators).map { forwardGraph ->
|
||||||
|
val backGraph = forwardGraph.flatMap { it.second.map { output -> it.first to output } }.fold(HashMap<SecureHash, HashSet<SecureHash>>()) { backGraph, edge ->
|
||||||
|
backGraph.getOrPut(edge.second) { HashSet() }.add(edge.first)
|
||||||
|
backGraph
|
||||||
|
}
|
||||||
|
val outrefCounts = HashMap<SecureHash, Int>()
|
||||||
|
val transactions = ArrayList<SignedTransaction>()
|
||||||
|
for ((id, outputs) in forwardGraph) {
|
||||||
|
val inputs = (backGraph[id]?.toList() ?: emptyList()).map { inputTxId ->
|
||||||
|
val ref = outrefCounts.compute(inputTxId) { _, count ->
|
||||||
|
if (count == null) {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
count + 1
|
||||||
|
}
|
||||||
|
}!!
|
||||||
|
StateRef(inputTxId, ref)
|
||||||
|
}
|
||||||
|
val tx = DummyTransaction(id, inputs, outputs.size, testIdentity.party)
|
||||||
|
val bits = tx.serialize().bytes
|
||||||
|
val sig = TransactionSignature(testIdentity.keyPair.private.sign(bits).bytes, testIdentity.publicKey, SignatureMetadata(0, 0))
|
||||||
|
val stx = SignedTransaction(tx, listOf(sig))
|
||||||
|
transactions.add(stx)
|
||||||
|
}
|
||||||
|
transactions
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap two random items
|
||||||
|
transactions.combine(Generator.intRange(0, N - 1), Generator.intRange(0, N - 2)) { txs, i, j ->
|
||||||
|
val k = 0 // if (i == j) i + 1 else j
|
||||||
|
val tmp = txs[i]
|
||||||
|
txs[i] = txs[k]
|
||||||
|
txs[k] = tmp
|
||||||
|
txs
|
||||||
|
}
|
||||||
|
|
||||||
|
val random = SplittableRandom()
|
||||||
|
for (i in 1..100) {
|
||||||
|
val txs = transactions.generateOrFail(random)
|
||||||
|
val ordered = Observable.from(txs).topologicalSort(emptyList()).toList().toBlocking().first()
|
||||||
|
checkTopologicallyOrdered(ordered)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun checkTopologicallyOrdered(txs: List<SignedTransaction>) {
|
||||||
|
val outputs = HashSet<StateRef>()
|
||||||
|
for (tx in txs) {
|
||||||
|
if (!outputs.containsAll(tx.inputs)) {
|
||||||
|
throw IllegalStateException("Transaction $tx's inputs ${tx.inputs} are not satisfied by $outputs")
|
||||||
|
}
|
||||||
|
outputs.addAll(tx.coreTransaction.outputs.mapIndexed { i, _ -> StateRef(tx.id, i) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user