mirror of
https://github.com/corda/corda.git
synced 2024-12-19 04:57:58 +00:00
Observable toposort for transactions (#3027)
This commit is contained in:
parent
468c0c7404
commit
e2b4943bbb
@ -47,31 +47,11 @@ class ResolveTransactionsFlow(txHashesArg: Set<SecureHash>,
|
||||
* Topologically sorts the given transactions such that dependencies are listed before dependers. */
|
||||
@JvmStatic
|
||||
fun topologicalSort(transactions: Collection<SignedTransaction>): List<SignedTransaction> {
|
||||
// Construct txhash -> dependent-txs map
|
||||
val forwardGraph = HashMap<SecureHash, HashSet<SignedTransaction>>()
|
||||
transactions.forEach { stx ->
|
||||
stx.inputs.forEach { (txhash) ->
|
||||
// Note that we use a LinkedHashSet here to make the traversal deterministic (as long as the input list is)
|
||||
forwardGraph.getOrPut(txhash) { LinkedHashSet() }.add(stx)
|
||||
}
|
||||
val sort = TopologicalSort()
|
||||
for (tx in transactions) {
|
||||
sort.add(tx)
|
||||
}
|
||||
|
||||
val visited = HashSet<SecureHash>(transactions.size)
|
||||
val result = ArrayList<SignedTransaction>(transactions.size)
|
||||
|
||||
fun visit(transaction: SignedTransaction) {
|
||||
if (transaction.id !in visited) {
|
||||
visited.add(transaction.id)
|
||||
forwardGraph[transaction.id]?.forEach(::visit)
|
||||
result.add(transaction)
|
||||
}
|
||||
}
|
||||
|
||||
transactions.forEach(::visit)
|
||||
|
||||
result.reverse()
|
||||
require(result.size == transactions.size)
|
||||
return result
|
||||
return sort.complete()
|
||||
}
|
||||
}
|
||||
|
||||
|
117
core/src/main/kotlin/net/corda/core/internal/TopologicalSort.kt
Normal file
117
core/src/main/kotlin/net/corda/core/internal/TopologicalSort.kt
Normal file
@ -0,0 +1,117 @@
|
||||
package net.corda.core.internal
|
||||
|
||||
import net.corda.core.contracts.StateRef
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import rx.Observable
|
||||
|
||||
/**
|
||||
* Provides a way to topologically sort SignedTransactions. This means that given any two transactions T1 and T2 in the
|
||||
* list returned by [complete] if T1 is a dependency of T2 then T1 will occur earlier than T2.
|
||||
*/
|
||||
class TopologicalSort {
|
||||
private val forwardGraph = HashMap<SecureHash, LinkedHashSet<SignedTransaction>>()
|
||||
private val transactions = ArrayList<SignedTransaction>()
|
||||
|
||||
/**
|
||||
* Add a transaction to the to-be-sorted set of transactions.
|
||||
*/
|
||||
fun add(stx: SignedTransaction) {
|
||||
for (input in stx.inputs) {
|
||||
// Note that we use a LinkedHashSet here to make the traversal deterministic (as long as the input list is)
|
||||
forwardGraph.getOrPut(input.txhash) { LinkedHashSet() }.add(stx)
|
||||
}
|
||||
transactions.add(stx)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the sorted list of signed transactions.
|
||||
*/
|
||||
fun complete(): List<SignedTransaction> {
|
||||
val visited = HashSet<SecureHash>(transactions.size)
|
||||
val result = ArrayList<SignedTransaction>(transactions.size)
|
||||
|
||||
fun visit(transaction: SignedTransaction) {
|
||||
if (transaction.id !in visited) {
|
||||
visited.add(transaction.id)
|
||||
forwardGraph[transaction.id]?.forEach(::visit)
|
||||
result.add(transaction)
|
||||
}
|
||||
}
|
||||
|
||||
transactions.forEach(::visit)
|
||||
return result.reversed()
|
||||
}
|
||||
}
|
||||
|
||||
private fun getOutputStateRefs(stx: SignedTransaction): List<StateRef> {
|
||||
return stx.coreTransaction.outputs.mapIndexed { i, _ -> StateRef(stx.id, i) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Topologically sort a SignedTransaction Observable on the fly by buffering transactions until all dependencies are met.
|
||||
* @param initialUnspentRefs the list of unspent references that may be spent by transactions in the observable. This is
|
||||
* the initial set of references the sort uses to decide whether to buffer transactions or not. For example if this
|
||||
* is empty then the Observable should start with issue transactions that don't have inputs.
|
||||
*/
|
||||
fun Observable<SignedTransaction>.topologicalSort(initialUnspentRefs: Collection<StateRef>): Observable<SignedTransaction> {
|
||||
data class State(
|
||||
val unspentRefs: HashSet<StateRef>,
|
||||
val bufferedTopologicalSort: TopologicalSort,
|
||||
val bufferedInputs: HashSet<StateRef>,
|
||||
val bufferedOutputs: HashSet<StateRef>
|
||||
)
|
||||
|
||||
var state = State(
|
||||
unspentRefs = HashSet(initialUnspentRefs),
|
||||
bufferedTopologicalSort = TopologicalSort(),
|
||||
bufferedInputs = HashSet(),
|
||||
bufferedOutputs = HashSet()
|
||||
)
|
||||
|
||||
return concatMapIterable { stx ->
|
||||
val results = ArrayList<SignedTransaction>()
|
||||
if (state.unspentRefs.containsAll(stx.inputs)) {
|
||||
// Dependencies are satisfied
|
||||
state.unspentRefs.removeAll(stx.inputs)
|
||||
state.unspentRefs.addAll(getOutputStateRefs(stx))
|
||||
results.add(stx)
|
||||
} else {
|
||||
// Dependencies are not satisfied, buffer
|
||||
state.bufferedTopologicalSort.add(stx)
|
||||
state.bufferedInputs.addAll(stx.inputs)
|
||||
for (outputRef in getOutputStateRefs(stx)) {
|
||||
if (!state.bufferedInputs.remove(outputRef)) {
|
||||
state.bufferedOutputs.add(outputRef)
|
||||
}
|
||||
}
|
||||
for (inputRef in stx.inputs) {
|
||||
if (!state.bufferedOutputs.remove(inputRef)) {
|
||||
state.bufferedInputs.add(inputRef)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (state.unspentRefs.containsAll(state.bufferedInputs)) {
|
||||
// Buffer satisfied
|
||||
results.addAll(state.bufferedTopologicalSort.complete())
|
||||
state.unspentRefs.removeAll(state.bufferedInputs)
|
||||
state.unspentRefs.addAll(state.bufferedOutputs)
|
||||
state = State(
|
||||
unspentRefs = state.unspentRefs,
|
||||
bufferedTopologicalSort = TopologicalSort(),
|
||||
bufferedInputs = HashSet(),
|
||||
bufferedOutputs = HashSet()
|
||||
)
|
||||
results
|
||||
} else {
|
||||
// Buffer not satisfied
|
||||
state = State(
|
||||
unspentRefs = state.unspentRefs,
|
||||
bufferedTopologicalSort = state.bufferedTopologicalSort,
|
||||
bufferedInputs = state.bufferedInputs,
|
||||
bufferedOutputs = state.bufferedOutputs
|
||||
)
|
||||
results
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,106 @@
|
||||
package net.corda.core.internal
|
||||
|
||||
import net.corda.client.mock.Generator
|
||||
import net.corda.core.contracts.ContractState
|
||||
import net.corda.core.contracts.StateRef
|
||||
import net.corda.core.contracts.TransactionState
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.SignatureMetadata
|
||||
import net.corda.core.crypto.TransactionSignature
|
||||
import net.corda.core.crypto.sign
|
||||
import net.corda.core.identity.AbstractParty
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.transactions.CoreTransaction
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.testing.core.SerializationEnvironmentRule
|
||||
import net.corda.testing.core.TestIdentity
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import rx.Observable
|
||||
import java.util.*
|
||||
|
||||
class TopologicalSortTest {
|
||||
class DummyTransaction(
|
||||
override val id: SecureHash,
|
||||
override val inputs: List<StateRef>,
|
||||
val numberOfOutputs: Int,
|
||||
override val notary: Party
|
||||
) : CoreTransaction() {
|
||||
override val outputs: List<TransactionState<ContractState>> = (1..numberOfOutputs).map {
|
||||
TransactionState(DummyState(), "", notary)
|
||||
}
|
||||
}
|
||||
|
||||
class DummyState : ContractState {
|
||||
override val participants: List<AbstractParty> = emptyList()
|
||||
}
|
||||
|
||||
@Rule
|
||||
@JvmField
|
||||
val testSerialization = SerializationEnvironmentRule()
|
||||
|
||||
@Test
|
||||
fun topologicalObservableSort() {
|
||||
val testIdentity = TestIdentity.fresh("asd")
|
||||
|
||||
val N = 10
|
||||
// generate random tx DAG
|
||||
val ids = (1..N).map { SecureHash.sha256("$it") }
|
||||
val forwardsGenerators = (0 until ids.size).map { i ->
|
||||
Generator.sampleBernoulli(ids.subList(i + 1, ids.size), 0.8).map { outputs -> ids[i] to outputs }
|
||||
}
|
||||
val transactions = Generator.sequence(forwardsGenerators).map { forwardGraph ->
|
||||
val backGraph = forwardGraph.flatMap { it.second.map { output -> it.first to output } }.fold(HashMap<SecureHash, HashSet<SecureHash>>()) { backGraph, edge ->
|
||||
backGraph.getOrPut(edge.second) { HashSet() }.add(edge.first)
|
||||
backGraph
|
||||
}
|
||||
val outrefCounts = HashMap<SecureHash, Int>()
|
||||
val transactions = ArrayList<SignedTransaction>()
|
||||
for ((id, outputs) in forwardGraph) {
|
||||
val inputs = (backGraph[id]?.toList() ?: emptyList()).map { inputTxId ->
|
||||
val ref = outrefCounts.compute(inputTxId) { _, count ->
|
||||
if (count == null) {
|
||||
0
|
||||
} else {
|
||||
count + 1
|
||||
}
|
||||
}!!
|
||||
StateRef(inputTxId, ref)
|
||||
}
|
||||
val tx = DummyTransaction(id, inputs, outputs.size, testIdentity.party)
|
||||
val bits = tx.serialize().bytes
|
||||
val sig = TransactionSignature(testIdentity.keyPair.private.sign(bits).bytes, testIdentity.publicKey, SignatureMetadata(0, 0))
|
||||
val stx = SignedTransaction(tx, listOf(sig))
|
||||
transactions.add(stx)
|
||||
}
|
||||
transactions
|
||||
}
|
||||
|
||||
// Swap two random items
|
||||
transactions.combine(Generator.intRange(0, N - 1), Generator.intRange(0, N - 2)) { txs, i, j ->
|
||||
val k = 0 // if (i == j) i + 1 else j
|
||||
val tmp = txs[i]
|
||||
txs[i] = txs[k]
|
||||
txs[k] = tmp
|
||||
txs
|
||||
}
|
||||
|
||||
val random = SplittableRandom()
|
||||
for (i in 1..100) {
|
||||
val txs = transactions.generateOrFail(random)
|
||||
val ordered = Observable.from(txs).topologicalSort(emptyList()).toList().toBlocking().first()
|
||||
checkTopologicallyOrdered(ordered)
|
||||
}
|
||||
}
|
||||
|
||||
fun checkTopologicallyOrdered(txs: List<SignedTransaction>) {
|
||||
val outputs = HashSet<StateRef>()
|
||||
for (tx in txs) {
|
||||
if (!outputs.containsAll(tx.inputs)) {
|
||||
throw IllegalStateException("Transaction $tx's inputs ${tx.inputs} are not satisfied by $outputs")
|
||||
}
|
||||
outputs.addAll(tx.coreTransaction.outputs.mapIndexed { i, _ -> StateRef(tx.id, i) })
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user