Refactor to use race().

This commit is contained in:
Itamar Turner-Trauring
2023-03-08 14:36:37 -05:00
parent 75da037d67
commit 0093edcd93

View File

@ -82,7 +82,7 @@ from allmydata.util.observer import ObserverList
from allmydata.util.rrefutil import add_version_to_remote_reference
from allmydata.util.hashutil import permute_server_hash
from allmydata.util.dictutil import BytesKeyDict, UnicodeKeyDict
from allmydata.util.deferredutil import async_to_deferred
from allmydata.util.deferredutil import async_to_deferred, race
from allmydata.storage.http_client import (
StorageClient, StorageClientImmutables, StorageClientGeneral,
ClientException as HTTPClientException, StorageClientMutables,
@ -1017,42 +1017,23 @@ def _pick_a_http_server(
Fires with ``None`` if no server was found, or with the ``DecodedURL`` of
the first successfully-connected server.
"""
queries = race([
request(reactor, nurl).addCallback(lambda _, nurl=nurl: nurl)
for nurl in nurls
])
to_cancel : list[defer.Deferred] = []
def cancel(result: Optional[defer.Deferred]):
for d in to_cancel:
if not d.called:
d.cancel()
if result is not None:
result.errback(defer.CancelledError())
result : defer.Deferred[Optional[DecodedURL]] = defer.Deferred(canceller=cancel)
def succeeded(nurl: DecodedURL, result=result):
# Only need the first successful NURL:
if result.called:
return
result.callback(nurl)
# No point in continuing other requests if we're connected:
cancel(None)
def failed(failure, failures=[], result=result):
def failed(failure: Failure):
# Logging errors breaks a bunch of tests, and it's not a _bug_ to
# have a failed connection, it's often expected and transient. More
# of a warning, really?
log.msg("Failed to connect to NURL: {}".format(failure))
failures.append(None)
if len(failures) == len(nurls):
# All our potential NURLs failed...
result.callback(None)
return None
for index, nurl in enumerate(nurls):
d = request(reactor, nurl)
to_cancel.append(d)
d.addCallback(lambda _, nurl=nurl: nurl).addCallbacks(succeeded, failed)
def succeeded(result: tuple[int, DecodedURL]):
_, nurl = result
return nurl
return result
return queries.addCallbacks(succeeded, failed)
@implementer(IServer)