From 75da037d673a3db8db50a9472758f89960591f2d Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 8 Mar 2023 14:25:04 -0500 Subject: [PATCH] Add race() implementation from https://github.com/twisted/twisted/pull/11818 --- src/allmydata/test/test_deferredutil.py | 160 ++++++++++++++++++++++-- src/allmydata/util/deferredutil.py | 100 ++++++++++++++- 2 files changed, 247 insertions(+), 13 deletions(-) diff --git a/src/allmydata/test/test_deferredutil.py b/src/allmydata/test/test_deferredutil.py index a37dfdd6f..47121b4cb 100644 --- a/src/allmydata/test/test_deferredutil.py +++ b/src/allmydata/test/test_deferredutil.py @@ -1,23 +1,16 @@ """ Tests for allmydata.util.deferredutil. - -Ported to Python 3. """ -from __future__ import unicode_literals -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from future.utils import PY2 -if PY2: - from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401 - from twisted.trial import unittest from twisted.internet import defer, reactor +from twisted.internet.defer import Deferred from twisted.python.failure import Failure +from hypothesis.strategies import integers +from hypothesis import given from allmydata.util import deferredutil +from allmydata.util.deferredutil import race, MultiFailure class DeferredUtilTests(unittest.TestCase, deferredutil.WaitForDelayedCallsMixin): @@ -157,3 +150,148 @@ class AsyncToDeferred(unittest.TestCase): result = f(1, 0) self.assertIsInstance(self.failureResultOf(result).value, ZeroDivisionError) + + + +def _setupRaceState(numDeferreds: int) -> tuple[list[int], list[Deferred[object]]]: + """ + Create a list of Deferreds and a corresponding list of integers + tracking how many times each Deferred has been cancelled. Without + additional steps the Deferreds will never fire. + """ + cancelledState = [0] * numDeferreds + + ds: list[Deferred[object]] = [] + for n in range(numDeferreds): + + def cancel(d: Deferred, n: int = n) -> None: + cancelledState[n] += 1 + + ds.append(Deferred(canceller=cancel)) + + return cancelledState, ds + + +class RaceTests(unittest.SynchronousTestCase): + """ + Tests for L{race}. + """ + + @given( + beforeWinner=integers(min_value=0, max_value=3), + afterWinner=integers(min_value=0, max_value=3), + ) + def test_success(self, beforeWinner: int, afterWinner: int) -> None: + """ + When one of the L{Deferred}s passed to L{race} fires successfully, + the L{Deferred} return by L{race} fires with the index of that + L{Deferred} and its result and cancels the rest of the L{Deferred}s. + @param beforeWinner: A randomly selected number of Deferreds to + appear before the "winning" Deferred in the list passed in. + @param beforeWinner: A randomly selected number of Deferreds to + appear after the "winning" Deferred in the list passed in. + """ + cancelledState, ds = _setupRaceState(beforeWinner + 1 + afterWinner) + + raceResult = race(ds) + expected = object() + ds[beforeWinner].callback(expected) + + # The result should be the index and result of the only Deferred that + # fired. + self.assertEqual( + self.successResultOf(raceResult), + (beforeWinner, expected), + ) + # All Deferreds except the winner should have been cancelled once. + expectedCancelledState = [1] * beforeWinner + [0] + [1] * afterWinner + self.assertEqual( + cancelledState, + expectedCancelledState, + ) + + @given( + beforeWinner=integers(min_value=0, max_value=3), + afterWinner=integers(min_value=0, max_value=3), + ) + def test_failure(self, beforeWinner: int, afterWinner: int) -> None: + """ + When all of the L{Deferred}s passed to L{race} fire with failures, + the L{Deferred} return by L{race} fires with L{MultiFailure} wrapping + all of their failures. + @param beforeWinner: A randomly selected number of Deferreds to + appear before the "winning" Deferred in the list passed in. + @param beforeWinner: A randomly selected number of Deferreds to + appear after the "winning" Deferred in the list passed in. + """ + cancelledState, ds = _setupRaceState(beforeWinner + 1 + afterWinner) + + failure = Failure(Exception("The test demands failures.")) + raceResult = race(ds) + for d in ds: + d.errback(failure) + + actualFailure = self.failureResultOf(raceResult, MultiFailure) + self.assertEqual( + actualFailure.value.failures, + [failure] * len(ds), + ) + self.assertEqual( + cancelledState, + [0] * len(ds), + ) + + @given( + beforeWinner=integers(min_value=0, max_value=3), + afterWinner=integers(min_value=0, max_value=3), + ) + def test_resultAfterCancel(self, beforeWinner: int, afterWinner: int) -> None: + """ + If one of the Deferreds fires after it was cancelled its result + goes nowhere. In particular, it does not cause any errors to be + logged. + """ + # Ensure we have a Deferred to win and at least one other Deferred + # that can ignore cancellation. + ds: list[Deferred[None]] = [ + Deferred() for n in range(beforeWinner + 2 + afterWinner) + ] + + raceResult = race(ds) + ds[beforeWinner].callback(None) + ds[beforeWinner + 1].callback(None) + + self.successResultOf(raceResult) + self.assertEqual(len(self.flushLoggedErrors()), 0) + + def test_resultFromCancel(self) -> None: + """ + If one of the input Deferreds has a cancel function that fires it + with success, nothing bad happens. + """ + winner: Deferred[object] = Deferred() + ds: list[Deferred[object]] = [ + winner, + Deferred(canceller=lambda d: d.callback(object())), + ] + expected = object() + raceResult = race(ds) + winner.callback(expected) + + self.assertEqual(self.successResultOf(raceResult), (0, expected)) + + @given( + numDeferreds=integers(min_value=1, max_value=3), + ) + def test_cancel(self, numDeferreds: int) -> None: + """ + If the result of L{race} is cancelled then all of the L{Deferred}s + passed in are cancelled. + """ + cancelledState, ds = _setupRaceState(numDeferreds) + + raceResult = race(ds) + raceResult.cancel() + + self.assertEqual(cancelledState, [1] * numDeferreds) + self.failureResultOf(raceResult, MultiFailure) diff --git a/src/allmydata/util/deferredutil.py b/src/allmydata/util/deferredutil.py index 782663e8b..83de411ce 100644 --- a/src/allmydata/util/deferredutil.py +++ b/src/allmydata/util/deferredutil.py @@ -1,15 +1,18 @@ """ Utilities for working with Twisted Deferreds. - -Ported to Python 3. """ +from __future__ import annotations + import time from functools import wraps from typing import ( Callable, Any, + Sequence, + TypeVar, + Optional, ) from foolscap.api import eventually @@ -17,6 +20,7 @@ from eliot.twisted import ( inline_callbacks, ) from twisted.internet import defer, reactor, error +from twisted.internet.defer import Deferred from twisted.python.failure import Failure from allmydata.util import log @@ -234,3 +238,95 @@ def async_to_deferred(f): return defer.Deferred.fromCoroutine(f(*args, **kwargs)) return not_async + + +class MultiFailure(Exception): + """ + More than one failure occurred. + """ + + def __init__(self, failures: Sequence[Failure]) -> None: + super(MultiFailure, self).__init__() + self.failures = failures + + +_T = TypeVar("_T") + +# Eventually this should be in Twisted upstream: +# https://github.com/twisted/twisted/pull/11818 +def race(ds: Sequence[Deferred[_T]]) -> Deferred[tuple[int, _T]]: + """ + Select the first available result from the sequence of Deferreds and + cancel the rest. + @return: A cancellable L{Deferred} that fires with the index and output of + the element of C{ds} to have a success result first, or that fires + with L{MultiFailure} holding a list of their failures if they all + fail. + """ + # Keep track of the Deferred for the action which completed first. When + # it completes, all of the other Deferreds will get cancelled but this one + # shouldn't be. Even though it "completed" it isn't really done - the + # caller will still be using it for something. If we cancelled it, + # cancellation could propagate down to them. + winner: Optional[Deferred] = None + + # The cancellation function for the Deferred this function returns. + def cancel(result: Deferred) -> None: + # If it is cancelled then we cancel all of the Deferreds for the + # individual actions because there is no longer the possibility of + # delivering any of their results anywhere. We don't have to fire + # `result` because the Deferred will do that for us. + for d in to_cancel: + d.cancel() + + # The Deferred that this function will return. It will fire with the + # index and output of the action that completes first, or None if all of + # the actions fail. If it is cancelled, all of the actions will be + # cancelled. + final_result: Deferred[tuple[int, _T]] = Deferred(canceller=cancel) + + # A callback for an individual action. + def succeeded(this_output: _T, this_index: int) -> None: + # If it is the first action to succeed then it becomes the "winner", + # its index/output become the externally visible result, and the rest + # of the action Deferreds get cancelled. If it is not the first + # action to succeed (because some action did not support + # cancellation), just ignore the result. It is uncommon for this + # callback to be entered twice. The only way it can happen is if one + # of the input Deferreds has a cancellation function that fires the + # Deferred with a success result. + nonlocal winner + if winner is None: + # This is the first success. Act on it. + winner = to_cancel[this_index] + + # Cancel the rest. + for d in to_cancel: + if d is not winner: + d.cancel() + + # Fire our Deferred + final_result.callback((this_index, this_output)) + + # Keep track of how many actions have failed. If they all fail we need to + # deliver failure notification on our externally visible result. + failure_state = [] + + def failed(failure: Failure, this_index: int) -> None: + failure_state.append((this_index, failure)) + if len(failure_state) == len(to_cancel): + # Every operation failed. + failure_state.sort() + failures = [f for (ignored, f) in failure_state] + final_result.errback(MultiFailure(failures)) + + # Copy the sequence of Deferreds so we know it doesn't get mutated out + # from under us. + to_cancel = list(ds) + for index, d in enumerate(ds): + # Propagate the position of this action as well as the argument to f + # to the success callback so we can cancel the right Deferreds and + # propagate the result outwards. + d.addCallbacks(succeeded, failed, callbackArgs=(index,), errbackArgs=(index,)) + + return final_result