diff --git a/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt b/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt index d9a6f2c715..00c5a56d70 100644 --- a/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt @@ -47,31 +47,11 @@ class ResolveTransactionsFlow(txHashesArg: Set, * Topologically sorts the given transactions such that dependencies are listed before dependers. */ @JvmStatic fun topologicalSort(transactions: Collection): List { - // Construct txhash -> dependent-txs map - val forwardGraph = HashMap>() - transactions.forEach { stx -> - stx.inputs.forEach { (txhash) -> - // Note that we use a LinkedHashSet here to make the traversal deterministic (as long as the input list is) - forwardGraph.getOrPut(txhash) { LinkedHashSet() }.add(stx) - } + val sort = TopologicalSort() + for (tx in transactions) { + sort.add(tx) } - - val visited = HashSet(transactions.size) - val result = ArrayList(transactions.size) - - fun visit(transaction: SignedTransaction) { - if (transaction.id !in visited) { - visited.add(transaction.id) - forwardGraph[transaction.id]?.forEach(::visit) - result.add(transaction) - } - } - - transactions.forEach(::visit) - - result.reverse() - require(result.size == transactions.size) - return result + return sort.complete() } } diff --git a/core/src/main/kotlin/net/corda/core/internal/TopologicalSort.kt b/core/src/main/kotlin/net/corda/core/internal/TopologicalSort.kt new file mode 100644 index 0000000000..b0fd2d0abf --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/TopologicalSort.kt @@ -0,0 +1,117 @@ +package net.corda.core.internal + +import net.corda.core.contracts.StateRef +import net.corda.core.crypto.SecureHash +import net.corda.core.transactions.SignedTransaction +import rx.Observable + +/** + * Provides a way to topologically sort SignedTransactions. This means that given any two transactions T1 and T2 in the + * list returned by [complete] if T1 is a dependency of T2 then T1 will occur earlier than T2. + */ +class TopologicalSort { + private val forwardGraph = HashMap>() + private val transactions = ArrayList() + + /** + * Add a transaction to the to-be-sorted set of transactions. + */ + fun add(stx: SignedTransaction) { + for (input in stx.inputs) { + // Note that we use a LinkedHashSet here to make the traversal deterministic (as long as the input list is) + forwardGraph.getOrPut(input.txhash) { LinkedHashSet() }.add(stx) + } + transactions.add(stx) + } + + /** + * Return the sorted list of signed transactions. + */ + fun complete(): List { + val visited = HashSet(transactions.size) + val result = ArrayList(transactions.size) + + fun visit(transaction: SignedTransaction) { + if (transaction.id !in visited) { + visited.add(transaction.id) + forwardGraph[transaction.id]?.forEach(::visit) + result.add(transaction) + } + } + + transactions.forEach(::visit) + return result.reversed() + } +} + +private fun getOutputStateRefs(stx: SignedTransaction): List { + return stx.coreTransaction.outputs.mapIndexed { i, _ -> StateRef(stx.id, i) } +} + +/** + * Topologically sort a SignedTransaction Observable on the fly by buffering transactions until all dependencies are met. + * @param initialUnspentRefs the list of unspent references that may be spent by transactions in the observable. This is + * the initial set of references the sort uses to decide whether to buffer transactions or not. For example if this + * is empty then the Observable should start with issue transactions that don't have inputs. + */ +fun Observable.topologicalSort(initialUnspentRefs: Collection): Observable { + data class State( + val unspentRefs: HashSet, + val bufferedTopologicalSort: TopologicalSort, + val bufferedInputs: HashSet, + val bufferedOutputs: HashSet + ) + + var state = State( + unspentRefs = HashSet(initialUnspentRefs), + bufferedTopologicalSort = TopologicalSort(), + bufferedInputs = HashSet(), + bufferedOutputs = HashSet() + ) + + return concatMapIterable { stx -> + val results = ArrayList() + if (state.unspentRefs.containsAll(stx.inputs)) { + // Dependencies are satisfied + state.unspentRefs.removeAll(stx.inputs) + state.unspentRefs.addAll(getOutputStateRefs(stx)) + results.add(stx) + } else { + // Dependencies are not satisfied, buffer + state.bufferedTopologicalSort.add(stx) + state.bufferedInputs.addAll(stx.inputs) + for (outputRef in getOutputStateRefs(stx)) { + if (!state.bufferedInputs.remove(outputRef)) { + state.bufferedOutputs.add(outputRef) + } + } + for (inputRef in stx.inputs) { + if (!state.bufferedOutputs.remove(inputRef)) { + state.bufferedInputs.add(inputRef) + } + } + } + if (state.unspentRefs.containsAll(state.bufferedInputs)) { + // Buffer satisfied + results.addAll(state.bufferedTopologicalSort.complete()) + state.unspentRefs.removeAll(state.bufferedInputs) + state.unspentRefs.addAll(state.bufferedOutputs) + state = State( + unspentRefs = state.unspentRefs, + bufferedTopologicalSort = TopologicalSort(), + bufferedInputs = HashSet(), + bufferedOutputs = HashSet() + ) + results + } else { + // Buffer not satisfied + state = State( + unspentRefs = state.unspentRefs, + bufferedTopologicalSort = state.bufferedTopologicalSort, + bufferedInputs = state.bufferedInputs, + bufferedOutputs = state.bufferedOutputs + ) + results + } + } +} diff --git a/core/src/test/kotlin/net/corda/core/internal/TopologicalSortTest.kt b/core/src/test/kotlin/net/corda/core/internal/TopologicalSortTest.kt new file mode 100644 index 0000000000..1c4f76cad9 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/internal/TopologicalSortTest.kt @@ -0,0 +1,106 @@ +package net.corda.core.internal + +import net.corda.client.mock.Generator +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionState +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.SignatureMetadata +import net.corda.core.crypto.TransactionSignature +import net.corda.core.crypto.sign +import net.corda.core.identity.AbstractParty +import net.corda.core.identity.Party +import net.corda.core.serialization.serialize +import net.corda.core.transactions.CoreTransaction +import net.corda.core.transactions.SignedTransaction +import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.core.TestIdentity +import org.junit.Rule +import org.junit.Test +import rx.Observable +import java.util.* + +class TopologicalSortTest { + class DummyTransaction( + override val id: SecureHash, + override val inputs: List, + val numberOfOutputs: Int, + override val notary: Party + ) : CoreTransaction() { + override val outputs: List> = (1..numberOfOutputs).map { + TransactionState(DummyState(), "", notary) + } + } + + class DummyState : ContractState { + override val participants: List = emptyList() + } + + @Rule + @JvmField + val testSerialization = SerializationEnvironmentRule() + + @Test + fun topologicalObservableSort() { + val testIdentity = TestIdentity.fresh("asd") + + val N = 10 + // generate random tx DAG + val ids = (1..N).map { SecureHash.sha256("$it") } + val forwardsGenerators = (0 until ids.size).map { i -> + Generator.sampleBernoulli(ids.subList(i + 1, ids.size), 0.8).map { outputs -> ids[i] to outputs } + } + val transactions = Generator.sequence(forwardsGenerators).map { forwardGraph -> + val backGraph = forwardGraph.flatMap { it.second.map { output -> it.first to output } }.fold(HashMap>()) { backGraph, edge -> + backGraph.getOrPut(edge.second) { HashSet() }.add(edge.first) + backGraph + } + val outrefCounts = HashMap() + val transactions = ArrayList() + for ((id, outputs) in forwardGraph) { + val inputs = (backGraph[id]?.toList() ?: emptyList()).map { inputTxId -> + val ref = outrefCounts.compute(inputTxId) { _, count -> + if (count == null) { + 0 + } else { + count + 1 + } + }!! + StateRef(inputTxId, ref) + } + val tx = DummyTransaction(id, inputs, outputs.size, testIdentity.party) + val bits = tx.serialize().bytes + val sig = TransactionSignature(testIdentity.keyPair.private.sign(bits).bytes, testIdentity.publicKey, SignatureMetadata(0, 0)) + val stx = SignedTransaction(tx, listOf(sig)) + transactions.add(stx) + } + transactions + } + + // Swap two random items + transactions.combine(Generator.intRange(0, N - 1), Generator.intRange(0, N - 2)) { txs, i, j -> + val k = 0 // if (i == j) i + 1 else j + val tmp = txs[i] + txs[i] = txs[k] + txs[k] = tmp + txs + } + + val random = SplittableRandom() + for (i in 1..100) { + val txs = transactions.generateOrFail(random) + val ordered = Observable.from(txs).topologicalSort(emptyList()).toList().toBlocking().first() + checkTopologicallyOrdered(ordered) + } + } + + fun checkTopologicallyOrdered(txs: List) { + val outputs = HashSet() + for (tx in txs) { + if (!outputs.containsAll(tx.inputs)) { + throw IllegalStateException("Transaction $tx's inputs ${tx.inputs} are not satisfied by $outputs") + } + outputs.addAll(tx.coreTransaction.outputs.mapIndexed { i, _ -> StateRef(tx.id, i) }) + } + } +} \ No newline at end of file