diff --git a/integration/conftest.py b/integration/conftest.py index dc0107eea..33e7998c1 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -393,7 +393,7 @@ def alice( finalize=False, ) ) - await_client_ready(process) + pytest_twisted.blockon(await_client_ready(process)) # 1. Create a new RW directory cap: cli(process, "create-alias", "test") @@ -424,7 +424,7 @@ alice-key ssh-rsa {ssh_public_key} {rwcap} # 4. Restart the node with new SFTP config. pytest_twisted.blockon(process.restart_async(reactor, request)) - await_client_ready(process) + pytest_twisted.blockon(await_client_ready(process)) print(f"Alice pid: {process.transport.pid}") return process @@ -439,7 +439,7 @@ def bob(reactor, temp_dir, introducer_furl, flog_gatherer, storage_nodes, reques storage=False, ) ) - await_client_ready(process) + pytest_twisted.blockon(await_client_ready(process)) return process diff --git a/integration/test_get_put.py b/integration/test_get_put.py index 1b6c30072..f121d6284 100644 --- a/integration/test_get_put.py +++ b/integration/test_get_put.py @@ -4,7 +4,6 @@ and stdout. """ from subprocess import Popen, PIPE, check_output, check_call -import sys import pytest from pytest_twisted import ensureDeferred @@ -50,6 +49,7 @@ def test_put_from_stdin(alice, get_put_alias, tmpdir): assert read_bytes(tempfile) == DATA +@run_in_thread def test_get_to_stdout(alice, get_put_alias, tmpdir): """ It's possible to upload a file, and then download it to stdout. @@ -67,6 +67,7 @@ def test_get_to_stdout(alice, get_put_alias, tmpdir): assert p.wait() == 0 +@run_in_thread def test_large_file(alice, get_put_alias, tmp_path): """ It's possible to upload and download a larger file. @@ -85,10 +86,6 @@ def test_large_file(alice, get_put_alias, tmp_path): assert outfile.read_bytes() == tempfile.read_bytes() -@pytest.mark.skipif( - sys.platform.startswith("win"), - reason="reconfigure() has issues on Windows" -) @ensureDeferred async def test_upload_download_immutable_different_default_max_segment_size(alice, get_put_alias, tmpdir, request): """ diff --git a/integration/test_servers_of_happiness.py b/integration/test_servers_of_happiness.py index b9de0c075..c63642066 100644 --- a/integration/test_servers_of_happiness.py +++ b/integration/test_servers_of_happiness.py @@ -31,7 +31,7 @@ def test_upload_immutable(reactor, temp_dir, introducer_furl, flog_gatherer, sto happy=7, total=10, ) - util.await_client_ready(edna) + yield util.await_client_ready(edna) node_dir = join(temp_dir, 'edna') diff --git a/integration/test_tor.py b/integration/test_tor.py index c78fa8098..901858347 100644 --- a/integration/test_tor.py +++ b/integration/test_tor.py @@ -42,8 +42,8 @@ if PY2: def test_onion_service_storage(reactor, request, temp_dir, flog_gatherer, tor_network, tor_introducer_furl): carol = yield _create_anonymous_node(reactor, 'carol', 8008, request, temp_dir, flog_gatherer, tor_network, tor_introducer_furl) dave = yield _create_anonymous_node(reactor, 'dave', 8009, request, temp_dir, flog_gatherer, tor_network, tor_introducer_furl) - util.await_client_ready(carol, minimum_number_of_servers=2) - util.await_client_ready(dave, minimum_number_of_servers=2) + yield util.await_client_ready(carol, minimum_number_of_servers=2) + yield util.await_client_ready(dave, minimum_number_of_servers=2) # ensure both nodes are connected to "a grid" by uploading # something via carol, and retrieve it using dave. diff --git a/integration/test_web.py b/integration/test_web.py index 95a09a5f5..b3c4a8e5f 100644 --- a/integration/test_web.py +++ b/integration/test_web.py @@ -18,6 +18,7 @@ import allmydata.uri from allmydata.util import jsonbytes as json from . import util +from .util import run_in_thread import requests import html5lib @@ -25,6 +26,7 @@ from bs4 import BeautifulSoup from pytest_twisted import ensureDeferred +@run_in_thread def test_index(alice): """ we can download the index file @@ -32,6 +34,7 @@ def test_index(alice): util.web_get(alice, u"") +@run_in_thread def test_index_json(alice): """ we can download the index file as json @@ -41,6 +44,7 @@ def test_index_json(alice): json.loads(data) +@run_in_thread def test_upload_download(alice): """ upload a file, then download it via readcap @@ -70,6 +74,7 @@ def test_upload_download(alice): assert str(data, "utf-8") == FILE_CONTENTS +@run_in_thread def test_put(alice): """ use PUT to create a file @@ -89,6 +94,7 @@ def test_put(alice): assert cap.needed_shares == int(cfg.get_config("client", "shares.needed")) +@run_in_thread def test_helper_status(storage_nodes): """ successfully GET the /helper_status page @@ -101,6 +107,7 @@ def test_helper_status(storage_nodes): assert str(dom.h1.string) == u"Helper Status" +@run_in_thread def test_deep_stats(alice): """ create a directory, do deep-stats on it and prove the /operations/ @@ -417,6 +424,7 @@ async def test_directory_deep_check(reactor, request, alice): assert dom is not None, "Operation never completed" +@run_in_thread def test_storage_info(storage_nodes): """ retrieve and confirm /storage URI for one storage node @@ -428,6 +436,7 @@ def test_storage_info(storage_nodes): ) +@run_in_thread def test_storage_info_json(storage_nodes): """ retrieve and confirm /storage?t=json URI for one storage node @@ -442,6 +451,7 @@ def test_storage_info_json(storage_nodes): assert data[u"stats"][u"storage_server.reserved_space"] == 1000000000 +@run_in_thread def test_introducer_info(introducer): """ retrieve and confirm /introducer URI for the introducer @@ -460,6 +470,7 @@ def test_introducer_info(introducer): assert "subscription_summary" in data +@run_in_thread def test_mkdir_with_children(alice): """ create a directory using ?t=mkdir-with-children diff --git a/integration/util.py b/integration/util.py index 04c925abf..05fef8fed 100644 --- a/integration/util.py +++ b/integration/util.py @@ -430,6 +430,31 @@ class FileShouldVanishException(Exception): ) +def run_in_thread(f): + """Decorator for integration tests that runs code in a thread. + + Because we're using pytest_twisted, tests that rely on the reactor are + expected to return a Deferred and use async APIs so the reactor can run. + + In the case of the integration test suite, it launches nodes in the + background using Twisted APIs. The nodes stdout and stderr is read via + Twisted code. If the reactor doesn't run, reads don't happen, and + eventually the buffers fill up, and the nodes block when they try to flush + logs. + + We can switch to Twisted APIs (treq instead of requests etc.), but + sometimes it's easier or expedient to just have a blocking test. So this + decorator allows you to run the test in a thread, and the reactor can keep + running in the main thread. + + See https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3597 for tracking bug. + """ + @wraps(f) + def test(*args, **kwargs): + return deferToThread(lambda: f(*args, **kwargs)) + return test + + def await_file_contents(path, contents, timeout=15, error_if=None): """ wait up to `timeout` seconds for the file at `path` (any path-like @@ -555,6 +580,7 @@ def web_post(tahoe, uri_fragment, **kwargs): return resp.content +@run_in_thread def await_client_ready(tahoe, timeout=10, liveness=60*2, minimum_number_of_servers=1): """ Uses the status API to wait for a client-type node (in `tahoe`, a @@ -622,30 +648,6 @@ def generate_ssh_key(path): f.write(s.encode("ascii")) -def run_in_thread(f): - """Decorator for integration tests that runs code in a thread. - - Because we're using pytest_twisted, tests that rely on the reactor are - expected to return a Deferred and use async APIs so the reactor can run. - - In the case of the integration test suite, it launches nodes in the - background using Twisted APIs. The nodes stdout and stderr is read via - Twisted code. If the reactor doesn't run, reads don't happen, and - eventually the buffers fill up, and the nodes block when they try to flush - logs. - - We can switch to Twisted APIs (treq instead of requests etc.), but - sometimes it's easier or expedient to just have a blocking test. So this - decorator allows you to run the test in a thread, and the reactor can keep - running in the main thread. - - See https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3597 for tracking bug. - """ - @wraps(f) - def test(*args, **kwargs): - return deferToThread(lambda: f(*args, **kwargs)) - return test - @frozen class CHK: """ @@ -792,16 +794,11 @@ async def reconfigure(reactor, request, node: TahoeProcess, ) if changed: - # TODO reconfigure() seems to have issues on Windows. If you need to - # use it there, delete this assert and try to figure out what's going - # on... - assert not sys.platform.startswith("win") - # restart the node print(f"Restarting {node.node_dir} for ZFEC reconfiguration") await node.restart_async(reactor, request) print("Restarted. Waiting for ready state.") - await_client_ready(node) + await await_client_ready(node) print("Ready.") else: print("Config unchanged, not restarting.") diff --git a/newsfragments/3935.minor b/newsfragments/3935.minor new file mode 100644 index 000000000..e69de29bb diff --git a/newsfragments/3988.minor b/newsfragments/3988.minor new file mode 100644 index 000000000..e69de29bb diff --git a/src/allmydata/storage/http_client.py b/src/allmydata/storage/http_client.py index 90bda7fc0..1d798fecc 100644 --- a/src/allmydata/storage/http_client.py +++ b/src/allmydata/storage/http_client.py @@ -311,9 +311,7 @@ class StorageClient(object): @classmethod def from_nurl( - cls, - nurl: DecodedURL, - reactor, + cls, nurl: DecodedURL, reactor, pool: Optional[HTTPConnectionPool] = None ) -> StorageClient: """ Create a ``StorageClient`` for the given NURL. @@ -322,8 +320,9 @@ class StorageClient(object): assert nurl.scheme == "pb" swissnum = nurl.path[0].encode("ascii") certificate_hash = nurl.user.encode("ascii") - pool = HTTPConnectionPool(reactor) - pool.maxPersistentPerHost = 20 + if pool is None: + pool = HTTPConnectionPool(reactor) + pool.maxPersistentPerHost = 20 if cls.TEST_MODE_REGISTER_HTTP_POOL is not None: cls.TEST_MODE_REGISTER_HTTP_POOL(pool) diff --git a/src/allmydata/storage_client.py b/src/allmydata/storage_client.py index 837cc06d3..c88613803 100644 --- a/src/allmydata/storage_client.py +++ b/src/allmydata/storage_client.py @@ -33,8 +33,7 @@ Ported to Python 3. from __future__ import annotations from six import ensure_text - -from typing import Union, Any +from typing import Union, Callable, Any, Optional from os import urandom import re import time @@ -44,6 +43,7 @@ from configparser import NoSectionError import attr from hyperlink import DecodedURL +from twisted.web.client import HTTPConnectionPool from zope.interface import ( Attribute, Interface, @@ -83,7 +83,7 @@ from allmydata.util.observer import ObserverList from allmydata.util.rrefutil import add_version_to_remote_reference from allmydata.util.hashutil import permute_server_hash from allmydata.util.dictutil import BytesKeyDict, UnicodeKeyDict -from allmydata.util.deferredutil import async_to_deferred +from allmydata.util.deferredutil import async_to_deferred, race from allmydata.storage.http_client import ( StorageClient, StorageClientImmutables, StorageClientGeneral, ClientException as HTTPClientException, StorageClientMutables, @@ -1019,6 +1019,36 @@ class NativeStorageServer(service.MultiService): self._reconnector.reset() +@async_to_deferred +async def _pick_a_http_server( + reactor, + nurls: list[DecodedURL], + request: Callable[[Any, DecodedURL], defer.Deferred[Any]] +) -> Optional[DecodedURL]: + """Pick the first server we successfully send a request to. + + Fires with ``None`` if no server was found, or with the ``DecodedURL`` of + the first successfully-connected server. + """ + queries = race([ + request(reactor, nurl).addCallback(lambda _, nurl=nurl: nurl) + for nurl in nurls + ]) + + try: + _, nurl = await queries + return nurl + except Exception as e: + # Logging errors breaks a bunch of tests, and it's not a _bug_ to + # have a failed connection, it's often expected and transient. More + # of a warning, really? + log.msg( + "Failed to connect to a storage server advertised by NURL: {}".format( + e) + ) + return None + + @implementer(IServer) class HTTPNativeStorageServer(service.MultiService): """ @@ -1045,12 +1075,11 @@ class HTTPNativeStorageServer(service.MultiService): self._short_description, self._long_description ) = _parse_announcement(server_id, furl, announcement) - # TODO need some way to do equivalent of Happy Eyeballs for multiple NURLs? - # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3935 - nurl = DecodedURL.from_text(announcement[ANONYMOUS_STORAGE_NURLS][0]) - self._istorage_server = _HTTPStorageServer.from_http_client( - StorageClient.from_nurl(nurl, reactor) - ) + self._nurls = [ + DecodedURL.from_text(u) + for u in announcement[ANONYMOUS_STORAGE_NURLS] + ] + self._istorage_server = None self._connection_status = connection_status.ConnectionStatus.unstarted() self._version = None @@ -1167,7 +1196,46 @@ class HTTPNativeStorageServer(service.MultiService): def try_to_connect(self): self._connect() - def _connect(self): + @async_to_deferred + async def _connect(self): + if self._istorage_server is None: + # We haven't selected a server yet, so let's do so. + + # TODO This is somewhat inefficient on startup: it takes two successful + # version() calls before we are live talking to a server, it could only + # be one. See https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3992 + + # TODO Another problem with this scheme is that while picking + # the HTTP server to talk to, we don't have connection status + # updates... https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3978 + def request(reactor, nurl: DecodedURL): + # Since we're just using this one off to check if the NURL + # works, no need for persistent pool or other fanciness. + pool = HTTPConnectionPool(reactor, persistent=False) + pool.retryAutomatically = False + return StorageClientGeneral( + StorageClient.from_nurl(nurl, reactor, pool) + ).get_version() + + # LoopingCall.stop() doesn't cancel Deferreds, unfortunately: + # https://github.com/twisted/twisted/issues/11814 Thus we want + # store the Deferred so it gets cancelled. + picking = _pick_a_http_server(reactor, self._nurls, request) + self._connecting_deferred = picking + try: + nurl = await picking + finally: + self._connecting_deferred = None + + if nurl is None: + # We failed to find a server to connect to. Perhaps the next + # iteration of the loop will succeed. + return + else: + self._istorage_server = _HTTPStorageServer.from_http_client( + StorageClient.from_nurl(nurl, reactor) + ) + result = self._istorage_server.get_version() def remove_connecting_deferred(result): @@ -1181,6 +1249,14 @@ class HTTPNativeStorageServer(service.MultiService): self._failed_to_connect ) + # We don't want to do another iteration of the loop until this + # iteration has finished, so wait here: + try: + if self._connecting_deferred is not None: + await self._connecting_deferred + except Exception as e: + log.msg(f"Failed to connect to a HTTP storage server: {e}", level=log.CURIOUS) + def stopService(self): if self._connecting_deferred is not None: self._connecting_deferred.cancel() @@ -1354,7 +1430,7 @@ class _HTTPBucketWriter(object): return self.finished -def _ignore_404(failure: Failure) -> Union[Failure, None]: +def _ignore_404(failure: Failure) -> Optional[Failure]: """ Useful for advise_corrupt_share(), since it swallows unknown share numbers in Foolscap. diff --git a/src/allmydata/test/test_deferredutil.py b/src/allmydata/test/test_deferredutil.py index a37dfdd6f..34358d0c8 100644 --- a/src/allmydata/test/test_deferredutil.py +++ b/src/allmydata/test/test_deferredutil.py @@ -1,23 +1,18 @@ """ Tests for allmydata.util.deferredutil. - -Ported to Python 3. """ -from __future__ import unicode_literals -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from future.utils import PY2 -if PY2: - from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401 +from __future__ import annotations from twisted.trial import unittest from twisted.internet import defer, reactor +from twisted.internet.defer import Deferred from twisted.python.failure import Failure +from hypothesis.strategies import integers +from hypothesis import given from allmydata.util import deferredutil +from allmydata.util.deferredutil import race, MultiFailure class DeferredUtilTests(unittest.TestCase, deferredutil.WaitForDelayedCallsMixin): @@ -157,3 +152,148 @@ class AsyncToDeferred(unittest.TestCase): result = f(1, 0) self.assertIsInstance(self.failureResultOf(result).value, ZeroDivisionError) + + + +def _setupRaceState(numDeferreds: int) -> tuple[list[int], list[Deferred[object]]]: + """ + Create a list of Deferreds and a corresponding list of integers + tracking how many times each Deferred has been cancelled. Without + additional steps the Deferreds will never fire. + """ + cancelledState = [0] * numDeferreds + + ds: list[Deferred[object]] = [] + for n in range(numDeferreds): + + def cancel(d: Deferred, n: int = n) -> None: + cancelledState[n] += 1 + + ds.append(Deferred(canceller=cancel)) + + return cancelledState, ds + + +class RaceTests(unittest.SynchronousTestCase): + """ + Tests for L{race}. + """ + + @given( + beforeWinner=integers(min_value=0, max_value=3), + afterWinner=integers(min_value=0, max_value=3), + ) + def test_success(self, beforeWinner: int, afterWinner: int) -> None: + """ + When one of the L{Deferred}s passed to L{race} fires successfully, + the L{Deferred} return by L{race} fires with the index of that + L{Deferred} and its result and cancels the rest of the L{Deferred}s. + @param beforeWinner: A randomly selected number of Deferreds to + appear before the "winning" Deferred in the list passed in. + @param beforeWinner: A randomly selected number of Deferreds to + appear after the "winning" Deferred in the list passed in. + """ + cancelledState, ds = _setupRaceState(beforeWinner + 1 + afterWinner) + + raceResult = race(ds) + expected = object() + ds[beforeWinner].callback(expected) + + # The result should be the index and result of the only Deferred that + # fired. + self.assertEqual( + self.successResultOf(raceResult), + (beforeWinner, expected), + ) + # All Deferreds except the winner should have been cancelled once. + expectedCancelledState = [1] * beforeWinner + [0] + [1] * afterWinner + self.assertEqual( + cancelledState, + expectedCancelledState, + ) + + @given( + beforeWinner=integers(min_value=0, max_value=3), + afterWinner=integers(min_value=0, max_value=3), + ) + def test_failure(self, beforeWinner: int, afterWinner: int) -> None: + """ + When all of the L{Deferred}s passed to L{race} fire with failures, + the L{Deferred} return by L{race} fires with L{MultiFailure} wrapping + all of their failures. + @param beforeWinner: A randomly selected number of Deferreds to + appear before the "winning" Deferred in the list passed in. + @param beforeWinner: A randomly selected number of Deferreds to + appear after the "winning" Deferred in the list passed in. + """ + cancelledState, ds = _setupRaceState(beforeWinner + 1 + afterWinner) + + failure = Failure(Exception("The test demands failures.")) + raceResult = race(ds) + for d in ds: + d.errback(failure) + + actualFailure = self.failureResultOf(raceResult, MultiFailure) + self.assertEqual( + actualFailure.value.failures, + [failure] * len(ds), + ) + self.assertEqual( + cancelledState, + [0] * len(ds), + ) + + @given( + beforeWinner=integers(min_value=0, max_value=3), + afterWinner=integers(min_value=0, max_value=3), + ) + def test_resultAfterCancel(self, beforeWinner: int, afterWinner: int) -> None: + """ + If one of the Deferreds fires after it was cancelled its result + goes nowhere. In particular, it does not cause any errors to be + logged. + """ + # Ensure we have a Deferred to win and at least one other Deferred + # that can ignore cancellation. + ds: list[Deferred[None]] = [ + Deferred() for n in range(beforeWinner + 2 + afterWinner) + ] + + raceResult = race(ds) + ds[beforeWinner].callback(None) + ds[beforeWinner + 1].callback(None) + + self.successResultOf(raceResult) + self.assertEqual(len(self.flushLoggedErrors()), 0) + + def test_resultFromCancel(self) -> None: + """ + If one of the input Deferreds has a cancel function that fires it + with success, nothing bad happens. + """ + winner: Deferred[object] = Deferred() + ds: list[Deferred[object]] = [ + winner, + Deferred(canceller=lambda d: d.callback(object())), + ] + expected = object() + raceResult = race(ds) + winner.callback(expected) + + self.assertEqual(self.successResultOf(raceResult), (0, expected)) + + @given( + numDeferreds=integers(min_value=1, max_value=3), + ) + def test_cancel(self, numDeferreds: int) -> None: + """ + If the result of L{race} is cancelled then all of the L{Deferred}s + passed in are cancelled. + """ + cancelledState, ds = _setupRaceState(numDeferreds) + + raceResult = race(ds) + raceResult.cancel() + + self.assertEqual(cancelledState, [1] * numDeferreds) + self.failureResultOf(raceResult, MultiFailure) diff --git a/src/allmydata/test/test_storage_client.py b/src/allmydata/test/test_storage_client.py index 109122da6..91668e7ca 100644 --- a/src/allmydata/test/test_storage_client.py +++ b/src/allmydata/test/test_storage_client.py @@ -1,22 +1,16 @@ """ -Ported from Python 3. +Tests for allmydata.storage_client. """ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals -from future.utils import PY2 -if PY2: - from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401 - -from six import ensure_text +from __future__ import annotations from json import ( loads, ) - import hashlib +from typing import Union, Any, Optional + +from hyperlink import DecodedURL from fixtures import ( TempDir, ) @@ -60,6 +54,7 @@ from twisted.internet.defer import ( from twisted.python.filepath import ( FilePath, ) +from twisted.internet.task import Clock from foolscap.api import ( Tub, @@ -94,7 +89,8 @@ from allmydata.storage_client import ( StorageFarmBroker, _FoolscapStorage, _NullStorage, - ANONYMOUS_STORAGE_NURLS + _pick_a_http_server, + ANONYMOUS_STORAGE_NURLS, ) from ..storage.server import ( StorageServer, @@ -478,7 +474,7 @@ class StoragePluginWebPresence(AsyncTestCase): # config validation policy). "tub.port": tubport_endpoint, "tub.location": tubport_location, - "web.port": ensure_text(webport_endpoint), + "web.port": str(webport_endpoint), }, storage_plugin=self.storage_plugin, basedir=self.basedir, @@ -781,3 +777,56 @@ storage: StorageFarmBroker._should_we_use_http(node_config, announcement), expected_http_usage ) + + +class PickHTTPServerTests(unittest.SynchronousTestCase): + """Tests for ``_pick_a_http_server``.""" + + def pick_result(self, url_to_results: dict[DecodedURL, tuple[float, Union[Exception, Any]]]) -> Optional[DecodedURL]: + """ + Given mapping of URLs to (delay, result), return the URL of the + first selected server, or None. + """ + clock = Clock() + + def request(reactor, url): + delay, value = url_to_results[url] + result = Deferred() + def add_result_value(): + if isinstance(value, Exception): + result.errback(value) + else: + result.callback(value) + reactor.callLater(delay, add_result_value) + return result + + d = _pick_a_http_server(clock, list(url_to_results.keys()), request) + for i in range(100): + clock.advance(0.1) + return self.successResultOf(d) + + def test_first_successful_connect_is_picked(self): + """ + Given multiple good URLs, the first one that connects is chosen. + """ + earliest_url = DecodedURL.from_text("http://a") + latest_url = DecodedURL.from_text("http://b") + bad_url = DecodedURL.from_text("http://bad") + result = self.pick_result({ + latest_url: (2, None), + earliest_url: (1, None), + bad_url: (0.5, RuntimeError()), + }) + self.assertEqual(result, earliest_url) + + def test_failures_are_turned_into_none(self): + """ + If the requests all fail, ``_pick_a_http_server`` returns ``None``. + """ + eventually_good_url = DecodedURL.from_text("http://good") + bad_url = DecodedURL.from_text("http://bad") + result = self.pick_result({ + eventually_good_url: (1, ZeroDivisionError()), + bad_url: (0.1, RuntimeError()) + }) + self.assertEqual(result, None) diff --git a/src/allmydata/util/deferredutil.py b/src/allmydata/util/deferredutil.py index 782663e8b..83de411ce 100644 --- a/src/allmydata/util/deferredutil.py +++ b/src/allmydata/util/deferredutil.py @@ -1,15 +1,18 @@ """ Utilities for working with Twisted Deferreds. - -Ported to Python 3. """ +from __future__ import annotations + import time from functools import wraps from typing import ( Callable, Any, + Sequence, + TypeVar, + Optional, ) from foolscap.api import eventually @@ -17,6 +20,7 @@ from eliot.twisted import ( inline_callbacks, ) from twisted.internet import defer, reactor, error +from twisted.internet.defer import Deferred from twisted.python.failure import Failure from allmydata.util import log @@ -234,3 +238,95 @@ def async_to_deferred(f): return defer.Deferred.fromCoroutine(f(*args, **kwargs)) return not_async + + +class MultiFailure(Exception): + """ + More than one failure occurred. + """ + + def __init__(self, failures: Sequence[Failure]) -> None: + super(MultiFailure, self).__init__() + self.failures = failures + + +_T = TypeVar("_T") + +# Eventually this should be in Twisted upstream: +# https://github.com/twisted/twisted/pull/11818 +def race(ds: Sequence[Deferred[_T]]) -> Deferred[tuple[int, _T]]: + """ + Select the first available result from the sequence of Deferreds and + cancel the rest. + @return: A cancellable L{Deferred} that fires with the index and output of + the element of C{ds} to have a success result first, or that fires + with L{MultiFailure} holding a list of their failures if they all + fail. + """ + # Keep track of the Deferred for the action which completed first. When + # it completes, all of the other Deferreds will get cancelled but this one + # shouldn't be. Even though it "completed" it isn't really done - the + # caller will still be using it for something. If we cancelled it, + # cancellation could propagate down to them. + winner: Optional[Deferred] = None + + # The cancellation function for the Deferred this function returns. + def cancel(result: Deferred) -> None: + # If it is cancelled then we cancel all of the Deferreds for the + # individual actions because there is no longer the possibility of + # delivering any of their results anywhere. We don't have to fire + # `result` because the Deferred will do that for us. + for d in to_cancel: + d.cancel() + + # The Deferred that this function will return. It will fire with the + # index and output of the action that completes first, or None if all of + # the actions fail. If it is cancelled, all of the actions will be + # cancelled. + final_result: Deferred[tuple[int, _T]] = Deferred(canceller=cancel) + + # A callback for an individual action. + def succeeded(this_output: _T, this_index: int) -> None: + # If it is the first action to succeed then it becomes the "winner", + # its index/output become the externally visible result, and the rest + # of the action Deferreds get cancelled. If it is not the first + # action to succeed (because some action did not support + # cancellation), just ignore the result. It is uncommon for this + # callback to be entered twice. The only way it can happen is if one + # of the input Deferreds has a cancellation function that fires the + # Deferred with a success result. + nonlocal winner + if winner is None: + # This is the first success. Act on it. + winner = to_cancel[this_index] + + # Cancel the rest. + for d in to_cancel: + if d is not winner: + d.cancel() + + # Fire our Deferred + final_result.callback((this_index, this_output)) + + # Keep track of how many actions have failed. If they all fail we need to + # deliver failure notification on our externally visible result. + failure_state = [] + + def failed(failure: Failure, this_index: int) -> None: + failure_state.append((this_index, failure)) + if len(failure_state) == len(to_cancel): + # Every operation failed. + failure_state.sort() + failures = [f for (ignored, f) in failure_state] + final_result.errback(MultiFailure(failures)) + + # Copy the sequence of Deferreds so we know it doesn't get mutated out + # from under us. + to_cancel = list(ds) + for index, d in enumerate(ds): + # Propagate the position of this action as well as the argument to f + # to the success callback so we can cancel the right Deferreds and + # propagate the result outwards. + d.addCallbacks(succeeded, failed, callbackArgs=(index,), errbackArgs=(index,)) + + return final_result diff --git a/tox.ini b/tox.ini index 3e2dacbb2..382ba973e 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ python = 3.8: py38-coverage 3.9: py39-coverage 3.10: py310-coverage + 3.11: py311-coverage pypy-3.8: pypy38 pypy-3.9: pypy39 @@ -17,7 +18,7 @@ python = twisted = 1 [tox] -envlist = typechecks,codechecks,py{38,39,310}-{coverage},pypy27,pypy38,pypy39,integration +envlist = typechecks,codechecks,py{38,39,310,311}-{coverage},pypy27,pypy38,pypy39,integration minversion = 2.4 [testenv]