Observable toposort for transactions (#3027)

This commit is contained in:
Andras Slemmer 2018-06-06 12:58:23 +01:00 committed by Mike Hearn
parent 468c0c7404
commit e2b4943bbb
3 changed files with 227 additions and 24 deletions

View File

@ -47,31 +47,11 @@ class ResolveTransactionsFlow(txHashesArg: Set<SecureHash>,
* Topologically sorts the given transactions such that dependencies are listed before dependers. */
@JvmStatic
fun topologicalSort(transactions: Collection<SignedTransaction>): List<SignedTransaction> {
// Construct txhash -> dependent-txs map
val forwardGraph = HashMap<SecureHash, HashSet<SignedTransaction>>()
transactions.forEach { stx ->
stx.inputs.forEach { (txhash) ->
// Note that we use a LinkedHashSet here to make the traversal deterministic (as long as the input list is)
forwardGraph.getOrPut(txhash) { LinkedHashSet() }.add(stx)
}
val sort = TopologicalSort()
for (tx in transactions) {
sort.add(tx)
}
val visited = HashSet<SecureHash>(transactions.size)
val result = ArrayList<SignedTransaction>(transactions.size)
fun visit(transaction: SignedTransaction) {
if (transaction.id !in visited) {
visited.add(transaction.id)
forwardGraph[transaction.id]?.forEach(::visit)
result.add(transaction)
}
}
transactions.forEach(::visit)
result.reverse()
require(result.size == transactions.size)
return result
return sort.complete()
}
}

View File

@ -0,0 +1,117 @@
package net.corda.core.internal
import net.corda.core.contracts.StateRef
import net.corda.core.crypto.SecureHash
import net.corda.core.transactions.SignedTransaction
import rx.Observable
/**
* Provides a way to topologically sort SignedTransactions. This means that given any two transactions T1 and T2 in the
* list returned by [complete] if T1 is a dependency of T2 then T1 will occur earlier than T2.
*/
class TopologicalSort {
private val forwardGraph = HashMap<SecureHash, LinkedHashSet<SignedTransaction>>()
private val transactions = ArrayList<SignedTransaction>()
/**
* Add a transaction to the to-be-sorted set of transactions.
*/
fun add(stx: SignedTransaction) {
for (input in stx.inputs) {
// Note that we use a LinkedHashSet here to make the traversal deterministic (as long as the input list is)
forwardGraph.getOrPut(input.txhash) { LinkedHashSet() }.add(stx)
}
transactions.add(stx)
}
/**
* Return the sorted list of signed transactions.
*/
fun complete(): List<SignedTransaction> {
val visited = HashSet<SecureHash>(transactions.size)
val result = ArrayList<SignedTransaction>(transactions.size)
fun visit(transaction: SignedTransaction) {
if (transaction.id !in visited) {
visited.add(transaction.id)
forwardGraph[transaction.id]?.forEach(::visit)
result.add(transaction)
}
}
transactions.forEach(::visit)
return result.reversed()
}
}
private fun getOutputStateRefs(stx: SignedTransaction): List<StateRef> {
return stx.coreTransaction.outputs.mapIndexed { i, _ -> StateRef(stx.id, i) }
}
/**
* Topologically sort a SignedTransaction Observable on the fly by buffering transactions until all dependencies are met.
* @param initialUnspentRefs the list of unspent references that may be spent by transactions in the observable. This is
* the initial set of references the sort uses to decide whether to buffer transactions or not. For example if this
* is empty then the Observable should start with issue transactions that don't have inputs.
*/
fun Observable<SignedTransaction>.topologicalSort(initialUnspentRefs: Collection<StateRef>): Observable<SignedTransaction> {
data class State(
val unspentRefs: HashSet<StateRef>,
val bufferedTopologicalSort: TopologicalSort,
val bufferedInputs: HashSet<StateRef>,
val bufferedOutputs: HashSet<StateRef>
)
var state = State(
unspentRefs = HashSet(initialUnspentRefs),
bufferedTopologicalSort = TopologicalSort(),
bufferedInputs = HashSet(),
bufferedOutputs = HashSet()
)
return concatMapIterable { stx ->
val results = ArrayList<SignedTransaction>()
if (state.unspentRefs.containsAll(stx.inputs)) {
// Dependencies are satisfied
state.unspentRefs.removeAll(stx.inputs)
state.unspentRefs.addAll(getOutputStateRefs(stx))
results.add(stx)
} else {
// Dependencies are not satisfied, buffer
state.bufferedTopologicalSort.add(stx)
state.bufferedInputs.addAll(stx.inputs)
for (outputRef in getOutputStateRefs(stx)) {
if (!state.bufferedInputs.remove(outputRef)) {
state.bufferedOutputs.add(outputRef)
}
}
for (inputRef in stx.inputs) {
if (!state.bufferedOutputs.remove(inputRef)) {
state.bufferedInputs.add(inputRef)
}
}
}
if (state.unspentRefs.containsAll(state.bufferedInputs)) {
// Buffer satisfied
results.addAll(state.bufferedTopologicalSort.complete())
state.unspentRefs.removeAll(state.bufferedInputs)
state.unspentRefs.addAll(state.bufferedOutputs)
state = State(
unspentRefs = state.unspentRefs,
bufferedTopologicalSort = TopologicalSort(),
bufferedInputs = HashSet(),
bufferedOutputs = HashSet()
)
results
} else {
// Buffer not satisfied
state = State(
unspentRefs = state.unspentRefs,
bufferedTopologicalSort = state.bufferedTopologicalSort,
bufferedInputs = state.bufferedInputs,
bufferedOutputs = state.bufferedOutputs
)
results
}
}
}

View File

@ -0,0 +1,106 @@
package net.corda.core.internal
import net.corda.client.mock.Generator
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.crypto.TransactionSignature
import net.corda.core.crypto.sign
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party
import net.corda.core.serialization.serialize
import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.TestIdentity
import org.junit.Rule
import org.junit.Test
import rx.Observable
import java.util.*
class TopologicalSortTest {
class DummyTransaction(
override val id: SecureHash,
override val inputs: List<StateRef>,
val numberOfOutputs: Int,
override val notary: Party
) : CoreTransaction() {
override val outputs: List<TransactionState<ContractState>> = (1..numberOfOutputs).map {
TransactionState(DummyState(), "", notary)
}
}
class DummyState : ContractState {
override val participants: List<AbstractParty> = emptyList()
}
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
@Test
fun topologicalObservableSort() {
val testIdentity = TestIdentity.fresh("asd")
val N = 10
// generate random tx DAG
val ids = (1..N).map { SecureHash.sha256("$it") }
val forwardsGenerators = (0 until ids.size).map { i ->
Generator.sampleBernoulli(ids.subList(i + 1, ids.size), 0.8).map { outputs -> ids[i] to outputs }
}
val transactions = Generator.sequence(forwardsGenerators).map { forwardGraph ->
val backGraph = forwardGraph.flatMap { it.second.map { output -> it.first to output } }.fold(HashMap<SecureHash, HashSet<SecureHash>>()) { backGraph, edge ->
backGraph.getOrPut(edge.second) { HashSet() }.add(edge.first)
backGraph
}
val outrefCounts = HashMap<SecureHash, Int>()
val transactions = ArrayList<SignedTransaction>()
for ((id, outputs) in forwardGraph) {
val inputs = (backGraph[id]?.toList() ?: emptyList()).map { inputTxId ->
val ref = outrefCounts.compute(inputTxId) { _, count ->
if (count == null) {
0
} else {
count + 1
}
}!!
StateRef(inputTxId, ref)
}
val tx = DummyTransaction(id, inputs, outputs.size, testIdentity.party)
val bits = tx.serialize().bytes
val sig = TransactionSignature(testIdentity.keyPair.private.sign(bits).bytes, testIdentity.publicKey, SignatureMetadata(0, 0))
val stx = SignedTransaction(tx, listOf(sig))
transactions.add(stx)
}
transactions
}
// Swap two random items
transactions.combine(Generator.intRange(0, N - 1), Generator.intRange(0, N - 2)) { txs, i, j ->
val k = 0 // if (i == j) i + 1 else j
val tmp = txs[i]
txs[i] = txs[k]
txs[k] = tmp
txs
}
val random = SplittableRandom()
for (i in 1..100) {
val txs = transactions.generateOrFail(random)
val ordered = Observable.from(txs).topologicalSort(emptyList()).toList().toBlocking().first()
checkTopologicallyOrdered(ordered)
}
}
fun checkTopologicallyOrdered(txs: List<SignedTransaction>) {
val outputs = HashSet<StateRef>()
for (tx in txs) {
if (!outputs.containsAll(tx.inputs)) {
throw IllegalStateException("Transaction $tx's inputs ${tx.inputs} are not satisfied by $outputs")
}
outputs.addAll(tx.coreTransaction.outputs.mapIndexed { i, _ -> StateRef(tx.id, i) })
}
}
}