more-generic testing hook

This commit is contained in:
meejah 2023-06-22 01:29:55 -06:00
parent 9cf69c5253
commit 122e0a73a9
2 changed files with 16 additions and 10 deletions

View File

@ -427,7 +427,7 @@ class StorageClient(object):
_pool: HTTPConnectionPool _pool: HTTPConnectionPool
_clock: IReactorTime _clock: IReactorTime
# Are we running unit tests? # Are we running unit tests?
_test_mode: bool _analyze_response: Callable[[IResponse], None] = lambda _: None
def relative_url(self, path: str) -> DecodedURL: def relative_url(self, path: str) -> DecodedURL:
"""Get a URL relative to the base URL.""" """Get a URL relative to the base URL."""
@ -534,12 +534,7 @@ class StorageClient(object):
response = await self._treq.request( response = await self._treq.request(
method, url, headers=headers, timeout=timeout, **kwargs method, url, headers=headers, timeout=timeout, **kwargs
) )
self._analyze_response(response)
if self._test_mode and response.code != 404:
# We're doing API queries, HTML is never correct except in 404, but
# it's the default for Twisted's web server so make sure nothing
# unexpected happened.
assert get_content_type(response.headers) != "text/html"
return response return response

View File

@ -316,6 +316,17 @@ def result_of(d):
+ "This is probably a test design issue." + "This is probably a test design issue."
) )
def response_is_not_html(response):
"""
During tests, this is registered so we can ensure the web server
doesn't give us text/html.
HTML is never correct except in 404, but it's the default for
Twisted's web server so we assert nothing unexpected happened.
"""
if response.code != 404:
assert get_content_type(response.headers) != "text/html"
class CustomHTTPServerTests(SyncTestCase): class CustomHTTPServerTests(SyncTestCase):
""" """
@ -342,7 +353,7 @@ class CustomHTTPServerTests(SyncTestCase):
# fixed if https://github.com/twisted/treq/issues/226 were ever # fixed if https://github.com/twisted/treq/issues/226 were ever
# fixed. # fixed.
clock=treq._agent._memoryReactor, clock=treq._agent._memoryReactor,
test_mode=True, analyze_response=response_is_not_html,
) )
self._http_server.clock = self.client._clock self._http_server.clock = self.client._clock
@ -560,7 +571,7 @@ class HttpTestFixture(Fixture):
treq=self.treq, treq=self.treq,
pool=None, pool=None,
clock=self.clock, clock=self.clock,
test_mode=True, analyze_response=response_is_not_html,
) )
def result_of_with_flush(self, d): def result_of_with_flush(self, d):
@ -674,7 +685,7 @@ class GenericHTTPAPITests(SyncTestCase):
treq=StubTreq(self.http.http_server.get_resource()), treq=StubTreq(self.http.http_server.get_resource()),
pool=None, pool=None,
clock=self.http.clock, clock=self.http.clock,
test_mode=True, analyze_response=response_is_not_html,
) )
) )
with assert_fails_with_http_code(self, http.UNAUTHORIZED): with assert_fails_with_http_code(self, http.UNAUTHORIZED):