From 3d0b17bc1c197a73b212b6b5eac0ad3b3ee43297 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 27 Feb 2023 11:37:18 -0500 Subject: [PATCH] Make cancellation more likely to happen. --- src/allmydata/storage_client.py | 88 ++++++++++++++--------- src/allmydata/test/test_storage_client.py | 28 ++++---- 2 files changed, 72 insertions(+), 44 deletions(-) diff --git a/src/allmydata/storage_client.py b/src/allmydata/storage_client.py index a2726fe09..549062d63 100644 --- a/src/allmydata/storage_client.py +++ b/src/allmydata/storage_client.py @@ -47,7 +47,7 @@ from zope.interface import ( ) from twisted.python.failure import Failure from twisted.web import http -from twisted.internet.task import LoopingCall, deferLater +from twisted.internet.task import LoopingCall from twisted.internet import defer, reactor from twisted.application import service from twisted.plugin import ( @@ -935,42 +935,52 @@ class NativeStorageServer(service.MultiService): self._reconnector.reset() -async def _pick_a_http_server( +def _pick_a_http_server( reactor, nurls: list[DecodedURL], request: Callable[[Any, DecodedURL], defer.Deferred[Any]] -) -> DecodedURL: - """Pick the first server we successfully send a request to.""" - while True: - result : defer.Deferred[Optional[DecodedURL]] = defer.Deferred() +) -> defer.Deferred[Optional[DecodedURL]]: + """Pick the first server we successfully send a request to. - def succeeded(nurl: DecodedURL, result=result): - # Only need the first successful NURL: - if result.called: - return - result.callback(nurl) + Fires with ``None`` if no server was found, or with the ``DecodedURL`` of + the first successfully-connected server. + """ - def failed(failure, failures=[], result=result): - # Logging errors breaks a bunch of tests, and it's not a _bug_ to - # have a failed connection, it's often expected and transient. More - # of a warning, really? - log.msg("Failed to connect to NURL: {}".format(failure)) - failures.append(None) - if len(failures) == len(nurls): - # All our potential NURLs failed... - result.callback(None) + to_cancel : list[defer.Deferred] = [] - for index, nurl in enumerate(nurls): - request(reactor, nurl).addCallback( - lambda _, nurl=nurl: nurl).addCallbacks(succeeded, failed) + def cancel(result: Optional[defer.Deferred]): + for d in to_cancel: + if not d.called: + d.cancel() + if result is not None: + result.errback(defer.CancelledError()) - first_nurl = await result - if first_nurl is None: - # Failed to connect to any of the NURLs, try again in a few - # seconds: - await deferLater(reactor, 5, lambda: None) - else: - return first_nurl + result : defer.Deferred[Optional[DecodedURL]] = defer.Deferred(canceller=cancel) + + def succeeded(nurl: DecodedURL, result=result): + # Only need the first successful NURL: + if result.called: + return + result.callback(nurl) + # No point in continuing other requests if we're connected: + cancel(None) + + def failed(failure, failures=[], result=result): + # Logging errors breaks a bunch of tests, and it's not a _bug_ to + # have a failed connection, it's often expected and transient. More + # of a warning, really? + log.msg("Failed to connect to NURL: {}".format(failure)) + failures.append(None) + if len(failures) == len(nurls): + # All our potential NURLs failed... + result.callback(None) + + for index, nurl in enumerate(nurls): + d = request(reactor, nurl) + to_cancel.append(d) + d.addCallback(lambda _, nurl=nurl: nurl).addCallbacks(succeeded, failed) + + return result @implementer(IServer) @@ -1117,8 +1127,22 @@ class HTTPNativeStorageServer(service.MultiService): StorageClient.from_nurl(nurl, reactor) ).get_version() - nurl = await _pick_a_http_server(reactor, self._nurls, request) - self._istorage_server = _HTTPStorageServer.from_http_client( + # LoopingCall.stop() doesn't cancel Deferreds, unfortunately: + # https://github.com/twisted/twisted/issues/11814 Thus we want + # store the Deferred so it gets cancelled. + picking = _pick_a_http_server(reactor, self._nurls, request) + self._connecting_deferred = picking + try: + nurl = await picking + finally: + self._connecting_deferred = None + + if nurl is None: + # We failed to find a server to connect to. Perhaps the next + # iteration of the loop will succeed. + return + else: + self._istorage_server = _HTTPStorageServer.from_http_client( StorageClient.from_nurl(nurl, reactor) ) diff --git a/src/allmydata/test/test_storage_client.py b/src/allmydata/test/test_storage_client.py index d7420b62f..a51e44a82 100644 --- a/src/allmydata/test/test_storage_client.py +++ b/src/allmydata/test/test_storage_client.py @@ -83,7 +83,6 @@ from allmydata.webish import ( WebishServer, ) from allmydata.util import base32, yamlutil -from allmydata.util.deferredutil import async_to_deferred from allmydata.storage_client import ( IFoolscapStorageServer, NativeStorageServer, @@ -741,7 +740,7 @@ storage: class PickHTTPServerTests(unittest.SynchronousTestCase): """Tests for ``_pick_a_http_server``.""" - def loop_until_result(self, url_to_results: dict[DecodedURL, list[tuple[float, Union[Exception, Any]]]]) -> Deferred[DecodedURL]: + def loop_until_result(self, url_to_results: dict[DecodedURL, list[tuple[float, Union[Exception, Any]]]]) -> tuple[int, DecodedURL]: """ Given mapping of URLs to list of (delay, result), return the URL of the first selected server. @@ -759,12 +758,15 @@ class PickHTTPServerTests(unittest.SynchronousTestCase): reactor.callLater(delay, add_result_value) return result - d = async_to_deferred(_pick_a_http_server)( - clock, list(url_to_results.keys()), request - ) - for i in range(1000): - clock.advance(0.1) - return d + iterations = 0 + while True: + iterations += 1 + d = _pick_a_http_server(clock, list(url_to_results.keys()), request) + for i in range(100): + clock.advance(0.1) + result = self.successResultOf(d) + if result is not None: + return iterations, result def test_first_successful_connect_is_picked(self): """ @@ -772,11 +774,12 @@ class PickHTTPServerTests(unittest.SynchronousTestCase): """ earliest_url = DecodedURL.from_text("http://a") latest_url = DecodedURL.from_text("http://b") - d = self.loop_until_result({ + iterations, result = self.loop_until_result({ latest_url: [(2, None)], earliest_url: [(1, None)] }) - self.assertEqual(self.successResultOf(d), earliest_url) + self.assertEqual(iterations, 1) + self.assertEqual(result, earliest_url) def test_failures_are_retried(self): """ @@ -785,10 +788,11 @@ class PickHTTPServerTests(unittest.SynchronousTestCase): """ eventually_good_url = DecodedURL.from_text("http://good") bad_url = DecodedURL.from_text("http://bad") - d = self.loop_until_result({ + iterations, result = self.loop_until_result({ eventually_good_url: [ (1, ZeroDivisionError()), (0.1, ZeroDivisionError()), (1, None) ], bad_url: [(0.1, RuntimeError()), (0.1, RuntimeError()), (0.1, RuntimeError())] }) - self.assertEqual(self.successResultOf(d), eventually_good_url) + self.assertEqual(iterations, 3) + self.assertEqual(result, eventually_good_url)