This commit is contained in:
Itamar Turner-Trauring 2022-11-07 11:19:00 -05:00
parent 26a9377d4c
commit c4772482ef
2 changed files with 78 additions and 6 deletions

View File

@ -20,8 +20,13 @@ from twisted.web.http_headers import Headers
from twisted.web import http
from twisted.web.iweb import IPolicyForHTTPS
from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred, succeed
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IReactorTime,
IDelayedCall,
)
from twisted.internet.ssl import CertificateOptions
from twisted.internet import reactor
from twisted.web.client import Agent, HTTPConnectionPool
from zope.interface import implementer
from hyperlink import DecodedURL
@ -124,16 +129,20 @@ class _LengthLimitedCollector:
"""
remaining_length: int
timeout_on_silence: IDelayedCall
f: BytesIO = field(factory=BytesIO)
def __call__(self, data: bytes):
    """
    Accept one chunk of response data.

    Every call restarts the silence timer at 60 seconds, deducts the chunk
    size from the remaining allowance and, if the allowance has not been
    exceeded, appends the chunk to the in-memory buffer ``f``.

    :param data: the bytes that have just been received.

    :raises ValueError: if this chunk pushes the total past the allowed
        length; the chunk is not written to the buffer in that case.
    """
    # Data arrived, so push the "no data received" deadline out again.
    self.timeout_on_silence.reset(60)

    # The allowance is charged before the check, so it is left negative
    # once the limit has been exceeded.
    budget = self.remaining_length - len(data)
    self.remaining_length = budget
    if budget < 0:
        raise ValueError("Response length was too long")

    self.f.write(data)
def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[BinaryIO]:
def limited_content(
response, max_length: int = 30 * 1024 * 1024, clock: IReactorTime = reactor
) -> Deferred[BinaryIO]:
"""
Like ``treq.content()``, but limit data read from the response to a set
length. If the response is longer than the max allowed length, the result
@ -142,11 +151,16 @@ def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[Bi
A potentially useful future improvement would be using a temporary file to
store the content, since filesystem buffering means that would use memory
for small responses and disk for large responses.
This will time out if no data is received for 60 seconds; so long as a
trickle of data continues to arrive, it will continue to run.
"""
collector = _LengthLimitedCollector(max_length)
d = succeed(None)
timeout = clock.callLater(60, d.cancel)
collector = _LengthLimitedCollector(max_length, timeout)
# Make really sure everything gets called in Deferred context, treq might
# call collector directly...
d = succeed(None)
d.addCallback(lambda _: treq.collect(response, collector))
def done(_):
@ -307,6 +321,8 @@ class StorageClient(object):
reactor,
_StorageClientHTTPSPolicy(expected_spki_hash=certificate_hash),
pool=HTTPConnectionPool(reactor, persistent=persistent),
# TCP-level connection timeout
connectTimeout=5,
)
)
@ -337,6 +353,7 @@ class StorageClient(object):
write_enabler_secret=None,
headers=None,
message_to_serialize=None,
timeout: Union[int, float] = 60,
**kwargs,
):
"""
@ -376,7 +393,9 @@ class StorageClient(object):
kwargs["data"] = dumps(message_to_serialize)
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
return self._treq.request(method, url, headers=headers, **kwargs)
return self._treq.request(
method, url, headers=headers, timeout=timeout, **kwargs
)
class StorageClientGeneral(object):
@ -461,6 +480,9 @@ def read_share_chunk(
share_type, _encode_si(storage_index), share_number
)
)
# The default timeout is for getting the response, so it doesn't include
# the time it takes to download the body... so we will deal with that
# later.
response = yield client.request(
"GET",
url,
@ -469,6 +491,7 @@ def read_share_chunk(
# but the Range constructor does that conversion for us.
{"range": [Range("bytes", [(offset, offset + length)]).to_header()]}
),
unbuffered=True, # Don't buffer the response in memory.
)
if response.code == http.NO_CONTENT:

View File

@ -31,6 +31,8 @@ from klein import Klein
from hyperlink import DecodedURL
from collections_extended import RangeMap
from twisted.internet.task import Clock, Cooperator
from twisted.internet.interfaces import IReactorTime
from twisted.internet.defer import CancelledError, Deferred
from twisted.web import http
from twisted.web.http_headers import Headers
from werkzeug import routing
@ -245,6 +247,7 @@ def gen_bytes(length: int) -> bytes:
class TestApp(object):
"""HTTP API for testing purposes."""
clock: IReactorTime
_app = Klein()
_swissnum = SWISSNUM_FOR_TEST # Match what the test client is using
@ -266,6 +269,17 @@ class TestApp(object):
"""Return bytes to the given length using ``gen_bytes()``."""
return gen_bytes(length)
@_authorized_route(_app, set(), "/slowly_never_finish_result", methods=["GET"])
def slowly_never_finish_result(self, request, authorization):
    """
    Write one byte right away, a second byte 59 seconds later and a third
    byte a further 59 seconds after that; then stay silent forever, never
    finishing the response.
    """
    request.write(b"a")

    # Delays are relative to now, so the third byte is due at 59 + 59
    # seconds on ``self.clock``.
    schedule = ((59, b"b"), (59 + 59, b"c"))
    for delay, payload in schedule:
        self.clock.callLater(delay, request.write, payload)

    # This Deferred is never fired, which is what leaves the response
    # unfinished.
    return Deferred()
def result_of(d):
"""
@ -299,6 +313,10 @@ class CustomHTTPServerTests(SyncTestCase):
SWISSNUM_FOR_TEST,
treq=StubTreq(self._http_server._app.resource()),
)
# We're using a Treq private API to get the reactor, alas, but only in
# a test, so not going to worry about it too much. This would be fixed
# if https://github.com/twisted/treq/issues/226 were ever fixed.
self._http_server.clock = self.client._treq._agent._memoryReactor
def test_authorization_enforcement(self):
"""
@ -367,6 +385,35 @@ class CustomHTTPServerTests(SyncTestCase):
with self.assertRaises(ValueError):
result_of(limited_content(response, too_short))
def test_limited_content_silence_causes_timeout(self):
    """
    ``http_client.limited_content()`` times out if it receives no data for
    60 seconds.
    """
    request_d = self.client.request(
        "GET",
        "http://127.0.0.1/slowly_never_finish_result",
    )
    response = result_of(request_d)

    clock = self._http_server.clock
    body_deferred = limited_content(response, 4, clock)
    result = []
    error = []
    body_deferred.addCallbacks(result.append, error.append)

    # 59 seconds until the second write, 59 more until the third write, and
    # then 60 seconds of silence before the timeout is expected.
    total_seconds = 59 + 59 + 60
    for _ in range(total_seconds):
        # Nothing may have fired before the final second has elapsed.
        self.assertEqual((result, error), ([], []))
        clock.advance(1)
        # Push data between in-memory client and in-memory server:
        self.client._treq._agent.flush()

    # Once the quiescent period is over, the limited_content() Deferred
    # must have failed, and the failure must be a cancellation.
    self.assertTrue(error)
    with self.assertRaises(CancelledError):
        error[0].raiseException()
class HttpTestFixture(Fixture):
"""
@ -1441,7 +1488,9 @@ class SharedImmutableMutableTestsMixin:
self.http.client.request(
"GET",
self.http.client.relative_url(
"/storage/v1/{}/{}/1".format(self.KIND, _encode_si(storage_index))
"/storage/v1/{}/{}/1".format(
self.KIND, _encode_si(storage_index)
)
),
headers=headers,
)