Refactor HTTP client creation to be more centralized.

This commit is contained in:
Itamar Turner-Trauring 2023-06-06 10:58:16 -04:00
parent 940600e0ed
commit 5af0ead5b9
4 changed files with 128 additions and 108 deletions

View File

@ -16,6 +16,7 @@ from typing import (
Set,
Dict,
Callable,
ClassVar,
)
from base64 import b64encode
from io import BytesIO
@ -60,6 +61,15 @@ from .http_common import (
from .common import si_b2a, si_to_human_readable
from ..util.hashutil import timing_safe_compare
from ..util.deferredutil import async_to_deferred
from ..util.tor_provider import _Provider as TorProvider
try:
from txtorcon import Tor # type: ignore
except ImportError:
class Tor:
pass
_OPENSSL = Binding().lib
@ -302,18 +312,30 @@ class _StorageClientHTTPSPolicy:
)
@define(hash=True)
class StorageClient(object):
@define
class StorageClientFactory:
"""
Low-level HTTP client that talks to the HTTP storage server.
Create ``StorageClient`` instances, using appropriate
``twisted.web.iweb.IAgent`` for different connection methods: normal TCP,
Tor, and eventually I2P.
There is some caching involved since there might be shared setup work, e.g.
connecting to the local Tor service only needs to happen once.
"""
# If set, we're doing unit testing and we should call this with
# HTTPConnectionPool we create.
TEST_MODE_REGISTER_HTTP_POOL = None
_default_connection_handlers: dict[str, str]
_tor_provider: Optional[TorProvider]
# Cache the Tor instance created by the provider, if relevant.
_tor_instance: Optional[Tor] = None
# If set, we're doing unit testing and we should call this with any
# HTTPConnectionPool that gets passed/created to ``create_agent()``.
TEST_MODE_REGISTER_HTTP_POOL = ClassVar[
Optional[Callable[[HTTPConnectionPool], None]]
]
@classmethod
def start_test_mode(cls, callback):
def start_test_mode(cls, callback: Callable[[HTTPConnectionPool], None]) -> None:
"""Switch to testing mode.
In testing mode we register the pool with test system using the given
@ -328,66 +350,84 @@ class StorageClient(object):
"""Stop testing mode."""
cls.TEST_MODE_REGISTER_HTTP_POOL = None
# The URL is a HTTPS URL ("https://..."). To construct from a NURL, use
# ``StorageClient.from_nurl()``.
_base_url: DecodedURL
_swissnum: bytes
_treq: Union[treq, StubTreq, HTTPClient]
_pool: Optional[HTTPConnectionPool]
_clock: IReactorTime
@classmethod
def from_nurl(
cls,
async def _create_agent(
self,
nurl: DecodedURL,
reactor,
# TODO default_connection_handlers should really be a class, not a dict
# of strings...
default_connection_handlers: dict[str, str],
pool: Optional[HTTPConnectionPool] = None,
agent_factory: Optional[
Callable[[object, IPolicyForHTTPS, HTTPConnectionPool], IAgent]
] = None,
) -> StorageClient:
"""
Create a ``StorageClient`` for the given NURL.
"""
# Safety check: if we're using normal TCP connections, we better not be
# configured for Tor or I2P.
if agent_factory is None:
assert default_connection_handlers["tcp"] == "tcp"
reactor: object,
tls_context_factory: IPolicyForHTTPS,
pool: HTTPConnectionPool,
) -> IAgent:
"""Create a new ``IAgent``, possibly using Tor."""
if self.TEST_MODE_REGISTER_HTTP_POOL is not None:
self.TEST_MODE_REGISTER_HTTP_POOL(pool)
# TODO default_connection_handlers should really be an object, not a
# dict, so we can ask "is this using Tor" without poking at a
# dictionary with arbitrary strings... See
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/4032
handler = self._default_connection_handlers["tcp"]
if handler == "tcp":
return Agent(reactor, tls_context_factory, pool=pool)
if handler == "tor": # TODO or nurl.scheme == "pb+tor":
assert self._tor_provider is not None
if self._tor_instance is None:
self._tor_instance = await self._tor_provider.get_tor_instance(reactor)
return self._tor_instance.web_agent(
pool=pool, tls_context_factory=tls_context_factory
)
else:
raise RuntimeError(f"Unsupported tcp connection handler: {handler}")
async def create_storage_client(
self,
nurl: DecodedURL,
reactor: IReactorTime,
pool: Optional[HTTPConnectionPool] = None,
) -> StorageClient:
"""Create a new ``StorageClient`` for the given NURL."""
assert nurl.fragment == "v=1"
assert nurl.scheme == "pb"
swissnum = nurl.path[0].encode("ascii")
certificate_hash = nurl.user.encode("ascii")
assert nurl.scheme in ("pb", "pb+tor")
if pool is None:
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
if cls.TEST_MODE_REGISTER_HTTP_POOL is not None:
cls.TEST_MODE_REGISTER_HTTP_POOL(pool)
def default_agent_factory(
reactor: object,
tls_context_factory: IPolicyForHTTPS,
pool: HTTPConnectionPool,
) -> IAgent:
return Agent(reactor, tls_context_factory, pool=pool)
if agent_factory is None:
agent_factory = default_agent_factory
treq_client = HTTPClient(
agent_factory(
reactor,
_StorageClientHTTPSPolicy(expected_spki_hash=certificate_hash),
pool,
)
certificate_hash = nurl.user.encode("ascii")
agent = await self._create_agent(
nurl,
reactor,
_StorageClientHTTPSPolicy(expected_spki_hash=certificate_hash),
pool,
)
treq_client = HTTPClient(agent)
https_url = DecodedURL().replace(scheme="https", host=nurl.host, port=nurl.port)
swissnum = nurl.path[0].encode("ascii")
return StorageClient(
https_url,
swissnum,
treq_client,
pool,
reactor,
self.TEST_MODE_REGISTER_HTTP_POOL is not None,
)
https_url = DecodedURL().replace(scheme="https", host=nurl.host, port=nurl.port)
return cls(https_url, swissnum, treq_client, pool, reactor)
@define(hash=True)
class StorageClient(object):
"""
Low-level HTTP client that talks to the HTTP storage server.
Create using a ``StorageClientFactory`` instance.
"""
# The URL should be a HTTPS URL ("https://...")
_base_url: DecodedURL
_swissnum: bytes
_treq: Union[treq, StubTreq, HTTPClient]
_pool: HTTPConnectionPool
_clock: IReactorTime
# Are we running unit tests?
_test_mode: bool
def relative_url(self, path: str) -> DecodedURL:
"""Get a URL relative to the base URL."""
@ -495,12 +535,11 @@ class StorageClient(object):
method, url, headers=headers, timeout=timeout, **kwargs
)
if self.TEST_MODE_REGISTER_HTTP_POOL is not None:
if response.code != 404:
# We're doing API queries, HTML is never correct except in 404, but
# it's the default for Twisted's web server so make sure nothing
# unexpected happened.
assert get_content_type(response.headers) != "text/html"
if self._test_mode and response.code != 404:
# We're doing API queries, HTML is never correct except in 404, but
# it's the default for Twisted's web server so make sure nothing
# unexpected happened.
assert get_content_type(response.headers) != "text/html"
return response
@ -529,8 +568,7 @@ class StorageClient(object):
def shutdown(self) -> Deferred:
"""Shutdown any connections."""
if self._pool is not None:
return self._pool.closeCachedConnections()
return self._pool.closeCachedConnections()
@define(hash=True)

View File

@ -89,7 +89,8 @@ from allmydata.util.deferredutil import async_to_deferred, race
from allmydata.storage.http_client import (
StorageClient, StorageClientImmutables, StorageClientGeneral,
ClientException as HTTPClientException, StorageClientMutables,
ReadVector, TestWriteVectors, WriteVector, TestVector, ClientException
ReadVector, TestWriteVectors, WriteVector, TestVector, ClientException,
StorageClientFactory
)
from .node import _Config
@ -1068,8 +1069,9 @@ class HTTPNativeStorageServer(service.MultiService):
self._on_status_changed = ObserverList()
self._reactor = reactor
self._grid_manager_verifier = grid_manager_verifier
self._tor_provider = tor_provider
self._default_connection_handlers = default_connection_handlers
self._storage_client_factory = StorageClientFactory(
default_connection_handlers, tor_provider
)
furl = announcement["anonymous-storage-FURL"].encode("utf-8")
(
@ -1232,26 +1234,6 @@ class HTTPNativeStorageServer(service.MultiService):
self._connecting_deferred = connecting
return connecting
async def _agent_factory(self) -> Optional[Callable[[object, IPolicyForHTTPS, HTTPConnectionPool],IAgent]]:
"""Return a factory for ``twisted.web.iweb.IAgent``."""
# TODO default_connection_handlers should really be an object, not a
# dict, so we can ask "is this using Tor" without poking at a
# dictionary with arbitrary strings... See
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/4032
handler = self._default_connection_handlers["tcp"]
if handler == "tcp":
return None
if handler == "tor":
assert self._tor_provider is not None
tor_instance = await self._tor_provider.get_tor_instance(self._reactor)
def agent_factory(reactor: object, tls_context_factory: IPolicyForHTTPS, pool: HTTPConnectionPool) -> IAgent:
assert reactor == self._reactor
return tor_instance.web_agent(pool=pool, tls_context_factory=tls_context_factory)
return agent_factory
else:
raise RuntimeError(f"Unsupported tcp connection handler: {handler}")
@async_to_deferred
async def _pick_server_and_get_version(self):
"""
@ -1270,28 +1252,24 @@ class HTTPNativeStorageServer(service.MultiService):
# version() calls before we are live talking to a server, it could only
# be one. See https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3992
agent_factory = await self._agent_factory()
def request(reactor, nurl: DecodedURL):
@async_to_deferred
async def request(reactor, nurl: DecodedURL):
# Since we're just using this one off to check if the NURL
# works, no need for persistent pool or other fanciness.
pool = HTTPConnectionPool(reactor, persistent=False)
pool.retryAutomatically = False
return StorageClientGeneral(
StorageClient.from_nurl(
nurl, reactor, self._default_connection_handlers,
pool=pool, agent_factory=agent_factory)
).get_version()
storage_client = await self._storage_client_factory.create_storage_client(
nurl, reactor, pool
)
return await StorageClientGeneral(storage_client).get_version()
nurl = await _pick_a_http_server(reactor, self._nurls, request)
# If we've gotten this far, we've found a working NURL.
self._istorage_server = _HTTPStorageServer.from_http_client(
StorageClient.from_nurl(
nurl, reactor, self._default_connection_handlers,
agent_factory=agent_factory
)
storage_client = await self._storage_client_factory.create_storage_client(
nurl, reactor, None
)
self._istorage_server = _HTTPStorageServer.from_http_client(storage_client)
return self._istorage_server
try:

View File

@ -686,8 +686,8 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
def setUp(self):
self._http_client_pools = []
http_client.StorageClient.start_test_mode(self._got_new_http_connection_pool)
self.addCleanup(http_client.StorageClient.stop_test_mode)
http_client.StorageClientFactory.start_test_mode(self._got_new_http_connection_pool)
self.addCleanup(http_client.StorageClientFactory.stop_test_mode)
self.port_assigner = SameProcessStreamEndpointAssigner()
self.port_assigner.setUp()
self.addCleanup(self.port_assigner.tearDown)

View File

@ -58,6 +58,7 @@ from ..storage.http_server import (
)
from ..storage.http_client import (
StorageClient,
StorageClientFactory,
ClientException,
StorageClientImmutables,
ImmutableCreateResult,
@ -323,10 +324,10 @@ class CustomHTTPServerTests(SyncTestCase):
def setUp(self):
super(CustomHTTPServerTests, self).setUp()
StorageClient.start_test_mode(
StorageClientFactory.start_test_mode(
lambda pool: self.addCleanup(pool.closeCachedConnections)
)
self.addCleanup(StorageClient.stop_test_mode)
self.addCleanup(StorageClientFactory.stop_test_mode)
# Could be a fixture, but will only be used in this test class so not
# going to bother:
self._http_server = TestApp()
@ -341,6 +342,7 @@ class CustomHTTPServerTests(SyncTestCase):
# fixed if https://github.com/twisted/treq/issues/226 were ever
# fixed.
clock=treq._agent._memoryReactor,
test_mode=True,
)
self._http_server.clock = self.client._clock
@ -529,10 +531,10 @@ class HttpTestFixture(Fixture):
"""
def _setUp(self):
StorageClient.start_test_mode(
StorageClientFactory.start_test_mode(
lambda pool: self.addCleanup(pool.closeCachedConnections)
)
self.addCleanup(StorageClient.stop_test_mode)
self.addCleanup(StorageClientFactory.stop_test_mode)
self.clock = Reactor()
self.tempdir = self.useFixture(TempDir())
# The global Cooperator used by Twisted (a) used by pull producers in
@ -558,6 +560,7 @@ class HttpTestFixture(Fixture):
treq=self.treq,
pool=None,
clock=self.clock,
test_mode=True,
)
def result_of_with_flush(self, d):
@ -671,6 +674,7 @@ class GenericHTTPAPITests(SyncTestCase):
treq=StubTreq(self.http.http_server.get_resource()),
pool=None,
clock=self.http.clock,
test_mode=True,
)
)
with assert_fails_with_http_code(self, http.UNAUTHORIZED):