This commit is contained in:
Itamar Turner-Trauring 2022-11-07 11:19:00 -05:00
parent 26a9377d4c
commit c4772482ef
2 changed files with 78 additions and 6 deletions

View File

@ -20,8 +20,13 @@ from twisted.web.http_headers import Headers
from twisted.web import http from twisted.web import http
from twisted.web.iweb import IPolicyForHTTPS from twisted.web.iweb import IPolicyForHTTPS
from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred, succeed from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred, succeed
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IReactorTime,
IDelayedCall,
)
from twisted.internet.ssl import CertificateOptions from twisted.internet.ssl import CertificateOptions
from twisted.internet import reactor
from twisted.web.client import Agent, HTTPConnectionPool from twisted.web.client import Agent, HTTPConnectionPool
from zope.interface import implementer from zope.interface import implementer
from hyperlink import DecodedURL from hyperlink import DecodedURL
@ -124,16 +129,20 @@ class _LengthLimitedCollector:
""" """
remaining_length: int remaining_length: int
timeout_on_silence: IDelayedCall
f: BytesIO = field(factory=BytesIO) f: BytesIO = field(factory=BytesIO)
def __call__(self, data: bytes): def __call__(self, data: bytes):
self.timeout_on_silence.reset(60)
self.remaining_length -= len(data) self.remaining_length -= len(data)
if self.remaining_length < 0: if self.remaining_length < 0:
raise ValueError("Response length was too long") raise ValueError("Response length was too long")
self.f.write(data) self.f.write(data)
def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[BinaryIO]: def limited_content(
response, max_length: int = 30 * 1024 * 1024, clock: IReactorTime = reactor
) -> Deferred[BinaryIO]:
""" """
Like ``treq.content()``, but limit data read from the response to a set Like ``treq.content()``, but limit data read from the response to a set
length. If the response is longer than the max allowed length, the result length. If the response is longer than the max allowed length, the result
@ -142,11 +151,16 @@ def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[Bi
A potentially useful future improvement would be using a temporary file to A potentially useful future improvement would be using a temporary file to
store the content; since filesystem buffering means that would use memory store the content; since filesystem buffering means that would use memory
for small responses and disk for large responses. for small responses and disk for large responses.
This will time out if no data is received for 60 seconds; so long as a
trickle of data continues to arrive, it will continue to run.
""" """
collector = _LengthLimitedCollector(max_length) d = succeed(None)
timeout = clock.callLater(60, d.cancel)
collector = _LengthLimitedCollector(max_length, timeout)
# Make really sure everything gets called in Deferred context, treq might # Make really sure everything gets called in Deferred context, treq might
# call collector directly... # call collector directly...
d = succeed(None)
d.addCallback(lambda _: treq.collect(response, collector)) d.addCallback(lambda _: treq.collect(response, collector))
def done(_): def done(_):
@ -307,6 +321,8 @@ class StorageClient(object):
reactor, reactor,
_StorageClientHTTPSPolicy(expected_spki_hash=certificate_hash), _StorageClientHTTPSPolicy(expected_spki_hash=certificate_hash),
pool=HTTPConnectionPool(reactor, persistent=persistent), pool=HTTPConnectionPool(reactor, persistent=persistent),
# TCP-level connection timeout
connectTimeout=5,
) )
) )
@ -337,6 +353,7 @@ class StorageClient(object):
write_enabler_secret=None, write_enabler_secret=None,
headers=None, headers=None,
message_to_serialize=None, message_to_serialize=None,
timeout: Union[int, float] = 60,
**kwargs, **kwargs,
): ):
""" """
@ -376,7 +393,9 @@ class StorageClient(object):
kwargs["data"] = dumps(message_to_serialize) kwargs["data"] = dumps(message_to_serialize)
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE) headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
return self._treq.request(method, url, headers=headers, **kwargs) return self._treq.request(
method, url, headers=headers, timeout=timeout, **kwargs
)
class StorageClientGeneral(object): class StorageClientGeneral(object):
@ -461,6 +480,9 @@ def read_share_chunk(
share_type, _encode_si(storage_index), share_number share_type, _encode_si(storage_index), share_number
) )
) )
# The default timeout is for getting the response, so it doesn't include
# the time it takes to download the body... so we will will deal with that
# later.
response = yield client.request( response = yield client.request(
"GET", "GET",
url, url,
@ -469,6 +491,7 @@ def read_share_chunk(
# but Range constructor does that the conversion for us. # but Range constructor does that the conversion for us.
{"range": [Range("bytes", [(offset, offset + length)]).to_header()]} {"range": [Range("bytes", [(offset, offset + length)]).to_header()]}
), ),
unbuffered=True, # Don't buffer the response in memory.
) )
if response.code == http.NO_CONTENT: if response.code == http.NO_CONTENT:

View File

@ -31,6 +31,8 @@ from klein import Klein
from hyperlink import DecodedURL from hyperlink import DecodedURL
from collections_extended import RangeMap from collections_extended import RangeMap
from twisted.internet.task import Clock, Cooperator from twisted.internet.task import Clock, Cooperator
from twisted.internet.interfaces import IReactorTime
from twisted.internet.defer import CancelledError, Deferred
from twisted.web import http from twisted.web import http
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from werkzeug import routing from werkzeug import routing
@ -245,6 +247,7 @@ def gen_bytes(length: int) -> bytes:
class TestApp(object): class TestApp(object):
"""HTTP API for testing purposes.""" """HTTP API for testing purposes."""
clock: IReactorTime
_app = Klein() _app = Klein()
_swissnum = SWISSNUM_FOR_TEST # Match what the test client is using _swissnum = SWISSNUM_FOR_TEST # Match what the test client is using
@ -266,6 +269,17 @@ class TestApp(object):
"""Return bytes to the given length using ``gen_bytes()``.""" """Return bytes to the given length using ``gen_bytes()``."""
return gen_bytes(length) return gen_bytes(length)
@_authorized_route(_app, set(), "/slowly_never_finish_result", methods=["GET"])
def slowly_never_finish_result(self, request, authorization):
"""
Send data immediately, after 59 seconds, after another 59 seconds, and then
never again, without finishing the response.
"""
request.write(b"a")
self.clock.callLater(59, request.write, b"b")
self.clock.callLater(59 + 59, request.write, b"c")
return Deferred()
def result_of(d): def result_of(d):
""" """
@ -299,6 +313,10 @@ class CustomHTTPServerTests(SyncTestCase):
SWISSNUM_FOR_TEST, SWISSNUM_FOR_TEST,
treq=StubTreq(self._http_server._app.resource()), treq=StubTreq(self._http_server._app.resource()),
) )
# We're using a Treq private API to get the reactor, alas, but only in
# a test, so not going to worry about it too much. This would be fixed
# if https://github.com/twisted/treq/issues/226 were ever fixed.
self._http_server.clock = self.client._treq._agent._memoryReactor
def test_authorization_enforcement(self): def test_authorization_enforcement(self):
""" """
@ -367,6 +385,35 @@ class CustomHTTPServerTests(SyncTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
result_of(limited_content(response, too_short)) result_of(limited_content(response, too_short))
def test_limited_content_silence_causes_timeout(self):
"""
``http_client.limited_content() times out if it receives no data for 60
seconds.
"""
response = result_of(
self.client.request(
"GET",
"http://127.0.0.1/slowly_never_finish_result",
)
)
body_deferred = limited_content(response, 4, self._http_server.clock)
result = []
error = []
body_deferred.addCallbacks(result.append, error.append)
for i in range(59 + 59 + 60):
self.assertEqual((result, error), ([], []))
self._http_server.clock.advance(1)
# Push data between in-memory client and in-memory server:
self.client._treq._agent.flush()
# After 59 (second write) + 59 (third write) + 60 seconds (quiescent
# timeout) the limited_content() response times out.
self.assertTrue(error)
with self.assertRaises(CancelledError):
error[0].raiseException()
class HttpTestFixture(Fixture): class HttpTestFixture(Fixture):
""" """
@ -1441,7 +1488,9 @@ class SharedImmutableMutableTestsMixin:
self.http.client.request( self.http.client.request(
"GET", "GET",
self.http.client.relative_url( self.http.client.relative_url(
"/storage/v1/{}/{}/1".format(self.KIND, _encode_si(storage_index)) "/storage/v1/{}/{}/1".format(
self.KIND, _encode_si(storage_index)
)
), ),
headers=headers, headers=headers,
) )