diff --git a/src/allmydata/storage/http_client.py b/src/allmydata/storage/http_client.py index daadebb28..b8ba1641a 100644 --- a/src/allmydata/storage/http_client.py +++ b/src/allmydata/storage/http_client.py @@ -4,7 +4,7 @@ HTTP client that talks to the HTTP storage server. from __future__ import annotations -from typing import Union, Optional, Sequence, Mapping +from typing import Union, Optional, Sequence, Mapping, BinaryIO from base64 import b64encode from io import BytesIO @@ -131,25 +131,35 @@ class _LengthLimitedCollector: self.f.write(data) -def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred: +def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[BinaryIO]: """ Like ``treq.content()``, but limit data read from the response to a set length. If the response is longer than the max allowed length, the result fails with a ``ValueError``. + + A potentially useful future improvement would be using a temporary file to + store the content; since filesystem buffering means that would use memory + for small responses and disk for large responses. """ collector = _LengthLimitedCollector(max_length) # Make really sure everything gets called in Deferred context, treq might # call collector directly... d = succeed(None) d.addCallback(lambda _: treq.collect(response, collector)) - d.addCallback(lambda _: collector.f.getvalue()) + + def done(_): + collector.f.seek(0) + return collector.f + + d.addCallback(done) return d def _decode_cbor(response, schema: Schema): """Given HTTP response, return decoded CBOR body.""" - def got_content(data): + def got_content(f: BinaryIO): + data = f.read() schema.validate_cbor(data) return loads(data) diff --git a/src/allmydata/test/test_storage_http.py b/src/allmydata/test/test_storage_http.py index 885750441..419052282 100644 --- a/src/allmydata/test/test_storage_http.py +++ b/src/allmydata/test/test_storage_http.py @@ -346,7 +346,7 @@ class CustomHTTPServerTests(SyncTestCase): ) self.assertEqual( - result_of(limited_content(response, at_least_length)), + result_of(limited_content(response, at_least_length)).read(), gen_bytes(length), )