diff --git a/src/allmydata/storage/http_client.py b/src/allmydata/storage/http_client.py index 9b44d2a73..9f5d6cce2 100644 --- a/src/allmydata/storage/http_client.py +++ b/src/allmydata/storage/http_client.py @@ -56,6 +56,7 @@ from .http_common import ( get_content_type, CBOR_MIME_TYPE, get_spki_hash, + response_is_not_html, ) from ..interfaces import VersionMessage from .common import si_b2a, si_to_human_readable @@ -399,13 +400,17 @@ class StorageClientFactory: treq_client = HTTPClient(agent) https_url = DecodedURL().replace(scheme="https", host=nurl.host, port=nurl.port) swissnum = nurl.path[0].encode("ascii") + response_check = lambda _: None + if self.TEST_MODE_REGISTER_HTTP_POOL is not None: + response_check = response_is_not_html + return StorageClient( https_url, swissnum, treq_client, pool, reactor, - self.TEST_MODE_REGISTER_HTTP_POOL is not None, + response_check, ) @@ -424,7 +429,7 @@ class StorageClient(object): _pool: HTTPConnectionPool _clock: IReactorTime # Are we running unit tests? - _test_mode: bool + _analyze_response: Callable[[IResponse], None] = lambda _: None def relative_url(self, path: str) -> DecodedURL: """Get a URL relative to the base URL.""" @@ -531,12 +536,7 @@ class StorageClient(object): response = await self._treq.request( method, url, headers=headers, timeout=timeout, **kwargs ) - - if self._test_mode and response.code != 404: - # We're doing API queries, HTML is never correct except in 404, but - # it's the default for Twisted's web server so make sure nothing - # unexpected happened. - assert get_content_type(response.headers) != "text/html" + self._analyze_response(response) return response diff --git a/src/allmydata/storage/http_common.py b/src/allmydata/storage/http_common.py index f16a16785..650d905e9 100644 --- a/src/allmydata/storage/http_common.py +++ b/src/allmydata/storage/http_common.py @@ -12,6 +12,7 @@ from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat from werkzeug.http import parse_options_header from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse CBOR_MIME_TYPE = "application/cbor" @@ -27,6 +28,18 @@ def get_content_type(headers: Headers) -> Optional[str]: return content_type +def response_is_not_html(response: IResponse) -> None: + """ + During tests, this is registered so we can ensure the web server + doesn't give us text/html. + + HTML is never correct except in 404, but it's the default for + Twisted's web server so we assert nothing unexpected happened. + """ + if response.code != 404: + assert get_content_type(response.headers) != "text/html" + + def swissnum_auth_header(swissnum: bytes) -> bytes: """Return value for ``Authorization`` header.""" return b"Tahoe-LAFS " + b64encode(swissnum).strip() diff --git a/src/allmydata/test/test_storage_http.py b/src/allmydata/test/test_storage_http.py index 2b4023bc5..30f6a527d 100644 --- a/src/allmydata/test/test_storage_http.py +++ b/src/allmydata/test/test_storage_http.py @@ -43,7 +43,11 @@ from testtools.matchers import Equals from zope.interface import implementer from .common import SyncTestCase -from ..storage.http_common import get_content_type, CBOR_MIME_TYPE +from ..storage.http_common import ( + get_content_type, + CBOR_MIME_TYPE, + response_is_not_html, +) from ..storage.common import si_b2a from ..storage.lease import LeaseInfo from ..storage.server import StorageServer @@ -316,7 +320,6 @@ def result_of(d): + "This is probably a test design issue." ) - class CustomHTTPServerTests(SyncTestCase): """ Tests that use a custom HTTP server. @@ -342,7 +345,7 @@ class CustomHTTPServerTests(SyncTestCase): # fixed if https://github.com/twisted/treq/issues/226 were ever # fixed. clock=treq._agent._memoryReactor, - test_mode=True, + analyze_response=response_is_not_html, ) self._http_server.clock = self.client._clock @@ -560,7 +563,7 @@ class HttpTestFixture(Fixture): treq=self.treq, pool=None, clock=self.clock, - test_mode=True, + analyze_response=response_is_not_html, ) def result_of_with_flush(self, d): @@ -674,7 +677,7 @@ class GenericHTTPAPITests(SyncTestCase): treq=StubTreq(self.http.http_server.get_resource()), pool=None, clock=self.http.clock, - test_mode=True, + analyze_response=response_is_not_html, ) ) with assert_fails_with_http_code(self, http.UNAUTHORIZED):