mirror of
https://github.com/tahoe-lafs/tahoe-lafs.git
synced 2024-12-19 04:57:54 +00:00
Refactor HTTP client creation to be more centralized.
This commit is contained in:
parent
940600e0ed
commit
5af0ead5b9
@ -16,6 +16,7 @@ from typing import (
|
||||
Set,
|
||||
Dict,
|
||||
Callable,
|
||||
ClassVar,
|
||||
)
|
||||
from base64 import b64encode
|
||||
from io import BytesIO
|
||||
@ -60,6 +61,15 @@ from .http_common import (
|
||||
from .common import si_b2a, si_to_human_readable
|
||||
from ..util.hashutil import timing_safe_compare
|
||||
from ..util.deferredutil import async_to_deferred
|
||||
from ..util.tor_provider import _Provider as TorProvider
|
||||
|
||||
try:
|
||||
from txtorcon import Tor # type: ignore
|
||||
except ImportError:
|
||||
|
||||
class Tor:
|
||||
pass
|
||||
|
||||
|
||||
_OPENSSL = Binding().lib
|
||||
|
||||
@ -302,18 +312,30 @@ class _StorageClientHTTPSPolicy:
|
||||
)
|
||||
|
||||
|
||||
@define(hash=True)
|
||||
class StorageClient(object):
|
||||
@define
|
||||
class StorageClientFactory:
|
||||
"""
|
||||
Low-level HTTP client that talks to the HTTP storage server.
|
||||
Create ``StorageClient`` instances, using appropriate
|
||||
``twisted.web.iweb.IAgent`` for different connection methods: normal TCP,
|
||||
Tor, and eventually I2P.
|
||||
|
||||
There is some caching involved since there might be shared setup work, e.g.
|
||||
connecting to the local Tor service only needs to happen once.
|
||||
"""
|
||||
|
||||
# If set, we're doing unit testing and we should call this with
|
||||
# HTTPConnectionPool we create.
|
||||
TEST_MODE_REGISTER_HTTP_POOL = None
|
||||
_default_connection_handlers: dict[str, str]
|
||||
_tor_provider: Optional[TorProvider]
|
||||
# Cache the Tor instance created by the provider, if relevant.
|
||||
_tor_instance: Optional[Tor] = None
|
||||
|
||||
# If set, we're doing unit testing and we should call this with any
|
||||
# HTTPConnectionPool that gets passed/created to ``create_agent()``.
|
||||
TEST_MODE_REGISTER_HTTP_POOL = ClassVar[
|
||||
Optional[Callable[[HTTPConnectionPool], None]]
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def start_test_mode(cls, callback):
|
||||
def start_test_mode(cls, callback: Callable[[HTTPConnectionPool], None]) -> None:
|
||||
"""Switch to testing mode.
|
||||
|
||||
In testing mode we register the pool with test system using the given
|
||||
@ -328,66 +350,84 @@ class StorageClient(object):
|
||||
"""Stop testing mode."""
|
||||
cls.TEST_MODE_REGISTER_HTTP_POOL = None
|
||||
|
||||
# The URL is a HTTPS URL ("https://..."). To construct from a NURL, use
|
||||
# ``StorageClient.from_nurl()``.
|
||||
_base_url: DecodedURL
|
||||
_swissnum: bytes
|
||||
_treq: Union[treq, StubTreq, HTTPClient]
|
||||
_pool: Optional[HTTPConnectionPool]
|
||||
_clock: IReactorTime
|
||||
|
||||
@classmethod
|
||||
def from_nurl(
|
||||
cls,
|
||||
async def _create_agent(
|
||||
self,
|
||||
nurl: DecodedURL,
|
||||
reactor,
|
||||
# TODO default_connection_handlers should really be a class, not a dict
|
||||
# of strings...
|
||||
default_connection_handlers: dict[str, str],
|
||||
pool: Optional[HTTPConnectionPool] = None,
|
||||
agent_factory: Optional[
|
||||
Callable[[object, IPolicyForHTTPS, HTTPConnectionPool], IAgent]
|
||||
] = None,
|
||||
) -> StorageClient:
|
||||
"""
|
||||
Create a ``StorageClient`` for the given NURL.
|
||||
"""
|
||||
# Safety check: if we're using normal TCP connections, we better not be
|
||||
# configured for Tor or I2P.
|
||||
if agent_factory is None:
|
||||
assert default_connection_handlers["tcp"] == "tcp"
|
||||
reactor: object,
|
||||
tls_context_factory: IPolicyForHTTPS,
|
||||
pool: HTTPConnectionPool,
|
||||
) -> IAgent:
|
||||
"""Create a new ``IAgent``, possibly using Tor."""
|
||||
if self.TEST_MODE_REGISTER_HTTP_POOL is not None:
|
||||
self.TEST_MODE_REGISTER_HTTP_POOL(pool)
|
||||
|
||||
# TODO default_connection_handlers should really be an object, not a
|
||||
# dict, so we can ask "is this using Tor" without poking at a
|
||||
# dictionary with arbitrary strings... See
|
||||
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/4032
|
||||
handler = self._default_connection_handlers["tcp"]
|
||||
|
||||
if handler == "tcp":
|
||||
return Agent(reactor, tls_context_factory, pool=pool)
|
||||
if handler == "tor": # TODO or nurl.scheme == "pb+tor":
|
||||
assert self._tor_provider is not None
|
||||
if self._tor_instance is None:
|
||||
self._tor_instance = await self._tor_provider.get_tor_instance(reactor)
|
||||
return self._tor_instance.web_agent(
|
||||
pool=pool, tls_context_factory=tls_context_factory
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported tcp connection handler: {handler}")
|
||||
|
||||
async def create_storage_client(
|
||||
self,
|
||||
nurl: DecodedURL,
|
||||
reactor: IReactorTime,
|
||||
pool: Optional[HTTPConnectionPool] = None,
|
||||
) -> StorageClient:
|
||||
"""Create a new ``StorageClient`` for the given NURL."""
|
||||
assert nurl.fragment == "v=1"
|
||||
assert nurl.scheme == "pb"
|
||||
swissnum = nurl.path[0].encode("ascii")
|
||||
certificate_hash = nurl.user.encode("ascii")
|
||||
assert nurl.scheme in ("pb", "pb+tor")
|
||||
if pool is None:
|
||||
pool = HTTPConnectionPool(reactor)
|
||||
pool.maxPersistentPerHost = 10
|
||||
|
||||
if cls.TEST_MODE_REGISTER_HTTP_POOL is not None:
|
||||
cls.TEST_MODE_REGISTER_HTTP_POOL(pool)
|
||||
|
||||
def default_agent_factory(
|
||||
reactor: object,
|
||||
tls_context_factory: IPolicyForHTTPS,
|
||||
pool: HTTPConnectionPool,
|
||||
) -> IAgent:
|
||||
return Agent(reactor, tls_context_factory, pool=pool)
|
||||
|
||||
if agent_factory is None:
|
||||
agent_factory = default_agent_factory
|
||||
|
||||
treq_client = HTTPClient(
|
||||
agent_factory(
|
||||
reactor,
|
||||
_StorageClientHTTPSPolicy(expected_spki_hash=certificate_hash),
|
||||
pool,
|
||||
)
|
||||
certificate_hash = nurl.user.encode("ascii")
|
||||
agent = await self._create_agent(
|
||||
nurl,
|
||||
reactor,
|
||||
_StorageClientHTTPSPolicy(expected_spki_hash=certificate_hash),
|
||||
pool,
|
||||
)
|
||||
treq_client = HTTPClient(agent)
|
||||
https_url = DecodedURL().replace(scheme="https", host=nurl.host, port=nurl.port)
|
||||
swissnum = nurl.path[0].encode("ascii")
|
||||
return StorageClient(
|
||||
https_url,
|
||||
swissnum,
|
||||
treq_client,
|
||||
pool,
|
||||
reactor,
|
||||
self.TEST_MODE_REGISTER_HTTP_POOL is not None,
|
||||
)
|
||||
|
||||
https_url = DecodedURL().replace(scheme="https", host=nurl.host, port=nurl.port)
|
||||
return cls(https_url, swissnum, treq_client, pool, reactor)
|
||||
|
||||
@define(hash=True)
|
||||
class StorageClient(object):
|
||||
"""
|
||||
Low-level HTTP client that talks to the HTTP storage server.
|
||||
|
||||
Create using a ``StorageClientFactory`` instance.
|
||||
"""
|
||||
|
||||
# The URL should be a HTTPS URL ("https://...")
|
||||
_base_url: DecodedURL
|
||||
_swissnum: bytes
|
||||
_treq: Union[treq, StubTreq, HTTPClient]
|
||||
_pool: HTTPConnectionPool
|
||||
_clock: IReactorTime
|
||||
# Are we running unit tests?
|
||||
_test_mode: bool
|
||||
|
||||
def relative_url(self, path: str) -> DecodedURL:
|
||||
"""Get a URL relative to the base URL."""
|
||||
@ -495,12 +535,11 @@ class StorageClient(object):
|
||||
method, url, headers=headers, timeout=timeout, **kwargs
|
||||
)
|
||||
|
||||
if self.TEST_MODE_REGISTER_HTTP_POOL is not None:
|
||||
if response.code != 404:
|
||||
# We're doing API queries, HTML is never correct except in 404, but
|
||||
# it's the default for Twisted's web server so make sure nothing
|
||||
# unexpected happened.
|
||||
assert get_content_type(response.headers) != "text/html"
|
||||
if self._test_mode and response.code != 404:
|
||||
# We're doing API queries, HTML is never correct except in 404, but
|
||||
# it's the default for Twisted's web server so make sure nothing
|
||||
# unexpected happened.
|
||||
assert get_content_type(response.headers) != "text/html"
|
||||
|
||||
return response
|
||||
|
||||
@ -529,8 +568,7 @@ class StorageClient(object):
|
||||
|
||||
def shutdown(self) -> Deferred:
|
||||
"""Shutdown any connections."""
|
||||
if self._pool is not None:
|
||||
return self._pool.closeCachedConnections()
|
||||
return self._pool.closeCachedConnections()
|
||||
|
||||
|
||||
@define(hash=True)
|
||||
|
@ -89,7 +89,8 @@ from allmydata.util.deferredutil import async_to_deferred, race
|
||||
from allmydata.storage.http_client import (
|
||||
StorageClient, StorageClientImmutables, StorageClientGeneral,
|
||||
ClientException as HTTPClientException, StorageClientMutables,
|
||||
ReadVector, TestWriteVectors, WriteVector, TestVector, ClientException
|
||||
ReadVector, TestWriteVectors, WriteVector, TestVector, ClientException,
|
||||
StorageClientFactory
|
||||
)
|
||||
from .node import _Config
|
||||
|
||||
@ -1068,8 +1069,9 @@ class HTTPNativeStorageServer(service.MultiService):
|
||||
self._on_status_changed = ObserverList()
|
||||
self._reactor = reactor
|
||||
self._grid_manager_verifier = grid_manager_verifier
|
||||
self._tor_provider = tor_provider
|
||||
self._default_connection_handlers = default_connection_handlers
|
||||
self._storage_client_factory = StorageClientFactory(
|
||||
default_connection_handlers, tor_provider
|
||||
)
|
||||
|
||||
furl = announcement["anonymous-storage-FURL"].encode("utf-8")
|
||||
(
|
||||
@ -1232,26 +1234,6 @@ class HTTPNativeStorageServer(service.MultiService):
|
||||
self._connecting_deferred = connecting
|
||||
return connecting
|
||||
|
||||
async def _agent_factory(self) -> Optional[Callable[[object, IPolicyForHTTPS, HTTPConnectionPool],IAgent]]:
|
||||
"""Return a factory for ``twisted.web.iweb.IAgent``."""
|
||||
# TODO default_connection_handlers should really be an object, not a
|
||||
# dict, so we can ask "is this using Tor" without poking at a
|
||||
# dictionary with arbitrary strings... See
|
||||
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/4032
|
||||
handler = self._default_connection_handlers["tcp"]
|
||||
if handler == "tcp":
|
||||
return None
|
||||
if handler == "tor":
|
||||
assert self._tor_provider is not None
|
||||
tor_instance = await self._tor_provider.get_tor_instance(self._reactor)
|
||||
|
||||
def agent_factory(reactor: object, tls_context_factory: IPolicyForHTTPS, pool: HTTPConnectionPool) -> IAgent:
|
||||
assert reactor == self._reactor
|
||||
return tor_instance.web_agent(pool=pool, tls_context_factory=tls_context_factory)
|
||||
return agent_factory
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported tcp connection handler: {handler}")
|
||||
|
||||
@async_to_deferred
|
||||
async def _pick_server_and_get_version(self):
|
||||
"""
|
||||
@ -1270,28 +1252,24 @@ class HTTPNativeStorageServer(service.MultiService):
|
||||
# version() calls before we are live talking to a server, it could only
|
||||
# be one. See https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3992
|
||||
|
||||
agent_factory = await self._agent_factory()
|
||||
|
||||
def request(reactor, nurl: DecodedURL):
|
||||
@async_to_deferred
|
||||
async def request(reactor, nurl: DecodedURL):
|
||||
# Since we're just using this one off to check if the NURL
|
||||
# works, no need for persistent pool or other fanciness.
|
||||
pool = HTTPConnectionPool(reactor, persistent=False)
|
||||
pool.retryAutomatically = False
|
||||
return StorageClientGeneral(
|
||||
StorageClient.from_nurl(
|
||||
nurl, reactor, self._default_connection_handlers,
|
||||
pool=pool, agent_factory=agent_factory)
|
||||
).get_version()
|
||||
storage_client = await self._storage_client_factory.create_storage_client(
|
||||
nurl, reactor, pool
|
||||
)
|
||||
return await StorageClientGeneral(storage_client).get_version()
|
||||
|
||||
nurl = await _pick_a_http_server(reactor, self._nurls, request)
|
||||
|
||||
# If we've gotten this far, we've found a working NURL.
|
||||
self._istorage_server = _HTTPStorageServer.from_http_client(
|
||||
StorageClient.from_nurl(
|
||||
nurl, reactor, self._default_connection_handlers,
|
||||
agent_factory=agent_factory
|
||||
)
|
||||
storage_client = await self._storage_client_factory.create_storage_client(
|
||||
nurl, reactor, None
|
||||
)
|
||||
self._istorage_server = _HTTPStorageServer.from_http_client(storage_client)
|
||||
return self._istorage_server
|
||||
|
||||
try:
|
||||
|
@ -686,8 +686,8 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
|
||||
|
||||
def setUp(self):
|
||||
self._http_client_pools = []
|
||||
http_client.StorageClient.start_test_mode(self._got_new_http_connection_pool)
|
||||
self.addCleanup(http_client.StorageClient.stop_test_mode)
|
||||
http_client.StorageClientFactory.start_test_mode(self._got_new_http_connection_pool)
|
||||
self.addCleanup(http_client.StorageClientFactory.stop_test_mode)
|
||||
self.port_assigner = SameProcessStreamEndpointAssigner()
|
||||
self.port_assigner.setUp()
|
||||
self.addCleanup(self.port_assigner.tearDown)
|
||||
|
@ -58,6 +58,7 @@ from ..storage.http_server import (
|
||||
)
|
||||
from ..storage.http_client import (
|
||||
StorageClient,
|
||||
StorageClientFactory,
|
||||
ClientException,
|
||||
StorageClientImmutables,
|
||||
ImmutableCreateResult,
|
||||
@ -323,10 +324,10 @@ class CustomHTTPServerTests(SyncTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(CustomHTTPServerTests, self).setUp()
|
||||
StorageClient.start_test_mode(
|
||||
StorageClientFactory.start_test_mode(
|
||||
lambda pool: self.addCleanup(pool.closeCachedConnections)
|
||||
)
|
||||
self.addCleanup(StorageClient.stop_test_mode)
|
||||
self.addCleanup(StorageClientFactory.stop_test_mode)
|
||||
# Could be a fixture, but will only be used in this test class so not
|
||||
# going to bother:
|
||||
self._http_server = TestApp()
|
||||
@ -341,6 +342,7 @@ class CustomHTTPServerTests(SyncTestCase):
|
||||
# fixed if https://github.com/twisted/treq/issues/226 were ever
|
||||
# fixed.
|
||||
clock=treq._agent._memoryReactor,
|
||||
test_mode=True,
|
||||
)
|
||||
self._http_server.clock = self.client._clock
|
||||
|
||||
@ -529,10 +531,10 @@ class HttpTestFixture(Fixture):
|
||||
"""
|
||||
|
||||
def _setUp(self):
|
||||
StorageClient.start_test_mode(
|
||||
StorageClientFactory.start_test_mode(
|
||||
lambda pool: self.addCleanup(pool.closeCachedConnections)
|
||||
)
|
||||
self.addCleanup(StorageClient.stop_test_mode)
|
||||
self.addCleanup(StorageClientFactory.stop_test_mode)
|
||||
self.clock = Reactor()
|
||||
self.tempdir = self.useFixture(TempDir())
|
||||
# The global Cooperator used by Twisted (a) used by pull producers in
|
||||
@ -558,6 +560,7 @@ class HttpTestFixture(Fixture):
|
||||
treq=self.treq,
|
||||
pool=None,
|
||||
clock=self.clock,
|
||||
test_mode=True,
|
||||
)
|
||||
|
||||
def result_of_with_flush(self, d):
|
||||
@ -671,6 +674,7 @@ class GenericHTTPAPITests(SyncTestCase):
|
||||
treq=StubTreq(self.http.http_server.get_resource()),
|
||||
pool=None,
|
||||
clock=self.http.clock,
|
||||
test_mode=True,
|
||||
)
|
||||
)
|
||||
with assert_fails_with_http_code(self, http.UNAUTHORIZED):
|
||||
|
Loading…
Reference in New Issue
Block a user