Merge remote-tracking branch 'origin/master' into 3939-faster-http-protocol

This commit is contained in:
Itamar Turner-Trauring 2022-11-28 11:07:01 -05:00
commit 06b57cd835
6 changed files with 196 additions and 52 deletions

View File

@ -163,7 +163,9 @@ jobs:
matrix:
os:
- windows-latest
- ubuntu-latest
# 22.04 has some issue with Tor at the moment:
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3943
- ubuntu-20.04
python-version:
- 3.7
- 3.9
@ -175,7 +177,7 @@ jobs:
steps:
- name: Install Tor [Ubuntu]
if: matrix.os == 'ubuntu-latest'
if: ${{ contains(matrix.os, 'ubuntu') }}
run: sudo apt install tor
# TODO: See https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3744.

0
newsfragments/3940.minor Normal file
View File

View File

@ -20,7 +20,11 @@ from twisted.web.http_headers import Headers
from twisted.web import http
from twisted.web.iweb import IPolicyForHTTPS
from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred, succeed
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator, IReactorTime
from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IReactorTime,
IDelayedCall,
)
from twisted.internet.ssl import CertificateOptions
from twisted.web.client import Agent, HTTPConnectionPool
from zope.interface import implementer
@ -124,16 +128,22 @@ class _LengthLimitedCollector:
"""
remaining_length: int
timeout_on_silence: IDelayedCall
f: BytesIO = field(factory=BytesIO)
def __call__(self, data: bytes):
self.timeout_on_silence.reset(60)
self.remaining_length -= len(data)
if self.remaining_length < 0:
raise ValueError("Response length was too long")
self.f.write(data)
def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[BinaryIO]:
def limited_content(
response,
clock: IReactorTime,
max_length: int = 30 * 1024 * 1024,
) -> Deferred[BinaryIO]:
"""
Like ``treq.content()``, but limit data read from the response to a set
length. If the response is longer than the max allowed length, the result
@ -142,39 +152,29 @@ def limited_content(response, max_length: int = 30 * 1024 * 1024) -> Deferred[Bi
A potentially useful future improvement would be using a temporary file to
store the content; since filesystem buffering means that would use memory
for small responses and disk for large responses.
This will time out if no data is received for 60 seconds; so long as a
trickle of data continues to arrive, it will continue to run.
"""
collector = _LengthLimitedCollector(max_length)
d = succeed(None)
timeout = clock.callLater(60, d.cancel)
collector = _LengthLimitedCollector(max_length, timeout)
# Make really sure everything gets called in Deferred context, treq might
# call collector directly...
d = succeed(None)
d.addCallback(lambda _: treq.collect(response, collector))
def done(_):
timeout.cancel()
collector.f.seek(0)
return collector.f
d.addCallback(done)
return d
def failed(f):
if timeout.active():
timeout.cancel()
return f
def _decode_cbor(response, schema: Schema):
"""Given HTTP response, return decoded CBOR body."""
def got_content(f: BinaryIO):
data = f.read()
schema.validate_cbor(data)
return loads(data)
if response.code > 199 and response.code < 300:
content_type = get_content_type(response.headers)
if content_type == CBOR_MIME_TYPE:
return limited_content(response).addCallback(got_content)
else:
raise ClientException(-1, "Server didn't send CBOR")
else:
return treq.content(response).addCallback(
lambda data: fail(ClientException(response.code, response.phrase, data))
)
return d.addCallbacks(done, failed)
@define
@ -362,6 +362,7 @@ class StorageClient(object):
write_enabler_secret=None,
headers=None,
message_to_serialize=None,
timeout: float = 60,
**kwargs,
):
"""
@ -370,6 +371,8 @@ class StorageClient(object):
If ``message_to_serialize`` is set, it will be serialized (by default
with CBOR) and set as the request body.
Default timeout is 60 seconds.
"""
headers = self._get_headers(headers)
@ -401,7 +404,28 @@ class StorageClient(object):
kwargs["data"] = dumps(message_to_serialize)
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
return self._treq.request(method, url, headers=headers, **kwargs)
return self._treq.request(
method, url, headers=headers, timeout=timeout, **kwargs
)
def decode_cbor(self, response, schema: Schema):
"""Given HTTP response, return decoded CBOR body."""
def got_content(f: BinaryIO):
data = f.read()
schema.validate_cbor(data)
return loads(data)
if response.code > 199 and response.code < 300:
content_type = get_content_type(response.headers)
if content_type == CBOR_MIME_TYPE:
return limited_content(response, self._clock).addCallback(got_content)
else:
raise ClientException(-1, "Server didn't send CBOR")
else:
return treq.content(response).addCallback(
lambda data: fail(ClientException(response.code, response.phrase, data))
)
@define(hash=True)
@ -419,7 +443,9 @@ class StorageClientGeneral(object):
"""
url = self._client.relative_url("/storage/v1/version")
response = yield self._client.request("GET", url)
decoded_response = yield _decode_cbor(response, _SCHEMAS["get_version"])
decoded_response = yield self._client.decode_cbor(
response, _SCHEMAS["get_version"]
)
returnValue(decoded_response)
@inlineCallbacks
@ -486,6 +512,9 @@ def read_share_chunk(
share_type, _encode_si(storage_index), share_number
)
)
# The default 60 second timeout is for getting the response, so it doesn't
# include the time it takes to download the body... so we will will deal
# with that later, via limited_content().
response = yield client.request(
"GET",
url,
@ -494,6 +523,7 @@ def read_share_chunk(
# but Range constructor does that the conversion for us.
{"range": [Range("bytes", [(offset, offset + length)]).to_header()]}
),
unbuffered=True, # Don't buffer the response in memory.
)
if response.code == http.NO_CONTENT:
@ -516,7 +546,7 @@ def read_share_chunk(
raise ValueError("Server sent more than we asked for?!")
# It might also send less than we asked for. That's (probably) OK, e.g.
# if we went past the end of the file.
body = yield limited_content(response, supposed_length)
body = yield limited_content(response, client._clock, supposed_length)
body.seek(0, SEEK_END)
actual_length = body.tell()
if actual_length != supposed_length:
@ -603,7 +633,9 @@ class StorageClientImmutables(object):
upload_secret=upload_secret,
message_to_serialize=message,
)
decoded_response = yield _decode_cbor(response, _SCHEMAS["allocate_buckets"])
decoded_response = yield self._client.decode_cbor(
response, _SCHEMAS["allocate_buckets"]
)
returnValue(
ImmutableCreateResult(
already_have=decoded_response["already-have"],
@ -679,7 +711,9 @@ class StorageClientImmutables(object):
raise ClientException(
response.code,
)
body = yield _decode_cbor(response, _SCHEMAS["immutable_write_share_chunk"])
body = yield self._client.decode_cbor(
response, _SCHEMAS["immutable_write_share_chunk"]
)
remaining = RangeMap()
for chunk in body["required"]:
remaining.set(True, chunk["begin"], chunk["end"])
@ -708,7 +742,7 @@ class StorageClientImmutables(object):
url,
)
if response.code == http.OK:
body = yield _decode_cbor(response, _SCHEMAS["list_shares"])
body = yield self._client.decode_cbor(response, _SCHEMAS["list_shares"])
returnValue(set(body))
else:
raise ClientException(response.code)
@ -825,7 +859,9 @@ class StorageClientMutables:
message_to_serialize=message,
)
if response.code == http.OK:
result = await _decode_cbor(response, _SCHEMAS["mutable_read_test_write"])
result = await self._client.decode_cbor(
response, _SCHEMAS["mutable_read_test_write"]
)
return ReadTestWriteResult(success=result["success"], reads=result["data"])
else:
raise ClientException(response.code, (await response.content()))
@ -854,7 +890,9 @@ class StorageClientMutables:
)
response = await self._client.request("GET", url)
if response.code == http.OK:
return await _decode_cbor(response, _SCHEMAS["mutable_list_shares"])
return await self._client.decode_cbor(
response, _SCHEMAS["mutable_list_shares"]
)
else:
raise ClientException(response.code)

View File

@ -20,6 +20,7 @@ from foolscap.api import flushEventualQueue
from allmydata import client
from allmydata.introducer.server import create_introducer
from allmydata.util import fileutil, log, pollmixin
from allmydata.util.deferredutil import async_to_deferred
from allmydata.storage import http_client
from allmydata.storage_client import (
NativeStorageServer,
@ -639,6 +640,40 @@ def _render_section_values(values):
))
@async_to_deferred
async def spin_until_cleanup_done(value=None, timeout=10):
"""
At the end of the test, spin until the reactor has no more DelayedCalls
and file descriptors (or equivalents) registered. This prevents dirty
reactor errors, while also not hard-coding a fixed amount of time, so it
can finish faster on faster computers.
There is also a timeout: if it takes more than 10 seconds (by default) for
the remaining reactor state to clean itself up, the presumption is that it
will never get cleaned up and the spinning stops.
Make sure to run as last thing in tearDown.
"""
def num_fds():
if hasattr(reactor, "handles"):
# IOCP!
return len(reactor.handles)
else:
# Normal reactor; having internal readers still registered is fine,
# that's not our code.
return len(
set(reactor.getReaders()) - set(reactor._internalReaders)
) + len(reactor.getWriters())
for i in range(timeout * 1000):
# There's a single DelayedCall for AsynchronousDeferredRunTest's
# timeout...
if (len(reactor.getDelayedCalls()) < 2 and num_fds() == 0):
break
await deferLater(reactor, 0.001)
return value
class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
# If set to True, use Foolscap for storage protocol. If set to False, HTTP
@ -685,7 +720,7 @@ class SystemTestMixin(pollmixin.PollMixin, testutil.StallMixin):
d = self.sparent.stopService()
d.addBoth(flush_but_dont_ignore)
d.addBoth(lambda x: self.close_idle_http_connections().addCallback(lambda _: x))
d.addBoth(lambda x: deferLater(reactor, 2, lambda: x))
d.addBoth(spin_until_cleanup_done)
return d
def getdir(self, subdir):

View File

@ -31,6 +31,8 @@ from klein import Klein
from hyperlink import DecodedURL
from collections_extended import RangeMap
from twisted.internet.task import Clock, Cooperator
from twisted.internet.interfaces import IReactorTime
from twisted.internet.defer import CancelledError, Deferred
from twisted.web import http
from twisted.web.http_headers import Headers
from werkzeug import routing
@ -245,6 +247,7 @@ def gen_bytes(length: int) -> bytes:
class TestApp(object):
"""HTTP API for testing purposes."""
clock: IReactorTime
_app = Klein()
_swissnum = SWISSNUM_FOR_TEST # Match what the test client is using
@ -266,6 +269,25 @@ class TestApp(object):
"""Return bytes to the given length using ``gen_bytes()``."""
return gen_bytes(length)
@_authorized_route(_app, set(), "/slowly_never_finish_result", methods=["GET"])
def slowly_never_finish_result(self, request, authorization):
"""
Send data immediately, after 59 seconds, after another 59 seconds, and then
never again, without finishing the response.
"""
request.write(b"a")
self.clock.callLater(59, request.write, b"b")
self.clock.callLater(59 + 59, request.write, b"c")
return Deferred()
@_authorized_route(_app, set(), "/die_unfinished", methods=["GET"])
def die(self, request, authorization):
"""
Dies half-way.
"""
request.transport.loseConnection()
return Deferred()
def result_of(d):
"""
@ -298,12 +320,18 @@ class CustomHTTPServerTests(SyncTestCase):
# Could be a fixture, but will only be used in this test class so not
# going to bother:
self._http_server = TestApp()
treq = StubTreq(self._http_server._app.resource())
self.client = StorageClient(
DecodedURL.from_text("http://127.0.0.1"),
SWISSNUM_FOR_TEST,
treq=StubTreq(self._http_server._app.resource()),
clock=Clock(),
treq=treq,
# We're using a Treq private API to get the reactor, alas, but only
# in a test, so not going to worry about it too much. This would be
# fixed if https://github.com/twisted/treq/issues/226 were ever
# fixed.
clock=treq._agent._memoryReactor,
)
self._http_server.clock = self.client._clock
def test_authorization_enforcement(self):
"""
@ -351,7 +379,9 @@ class CustomHTTPServerTests(SyncTestCase):
)
self.assertEqual(
result_of(limited_content(response, at_least_length)).read(),
result_of(
limited_content(response, self._http_server.clock, at_least_length)
).read(),
gen_bytes(length),
)
@ -370,7 +400,52 @@ class CustomHTTPServerTests(SyncTestCase):
)
with self.assertRaises(ValueError):
result_of(limited_content(response, too_short))
result_of(limited_content(response, self._http_server.clock, too_short))
def test_limited_content_silence_causes_timeout(self):
"""
``http_client.limited_content() times out if it receives no data for 60
seconds.
"""
response = result_of(
self.client.request(
"GET",
"http://127.0.0.1/slowly_never_finish_result",
)
)
body_deferred = limited_content(response, self._http_server.clock, 4)
result = []
error = []
body_deferred.addCallbacks(result.append, error.append)
for i in range(59 + 59 + 60):
self.assertEqual((result, error), ([], []))
self._http_server.clock.advance(1)
# Push data between in-memory client and in-memory server:
self.client._treq._agent.flush()
# After 59 (second write) + 59 (third write) + 60 seconds (quiescent
# timeout) the limited_content() response times out.
self.assertTrue(error)
with self.assertRaises(CancelledError):
error[0].raiseException()
def test_limited_content_cancels_timeout_on_failed_response(self):
"""
If the response fails somehow, the timeout is still cancelled.
"""
response = result_of(
self.client.request(
"GET",
"http://127.0.0.1/die",
)
)
d = limited_content(response, self._http_server.clock, 4)
with self.assertRaises(ValueError):
result_of(d)
self.assertEqual(len(self._http_server.clock.getDelayedCalls()), 0)
class HttpTestFixture(Fixture):

View File

@ -12,7 +12,7 @@ from cryptography import x509
from twisted.internet.endpoints import serverFromString
from twisted.internet import reactor
from twisted.internet.task import deferLater
from twisted.internet.defer import maybeDeferred
from twisted.web.server import Site
from twisted.web.static import Data
from twisted.web.client import Agent, HTTPConnectionPool, ResponseNeverReceived
@ -30,6 +30,7 @@ from ..storage.http_common import get_spki_hash
from ..storage.http_client import _StorageClientHTTPSPolicy
from ..storage.http_server import _TLSEndpointWrapper
from ..util.deferredutil import async_to_deferred
from .common_system import spin_until_cleanup_done
class HTTPSNurlTests(SyncTestCase):
@ -87,6 +88,10 @@ class PinningHTTPSValidation(AsyncTestCase):
self.addCleanup(self._port_assigner.tearDown)
return AsyncTestCase.setUp(self)
def tearDown(self):
d = maybeDeferred(AsyncTestCase.tearDown, self)
return d.addCallback(lambda _: spin_until_cleanup_done())
@asynccontextmanager
async def listen(self, private_key_path: FilePath, cert_path: FilePath):
"""
@ -107,9 +112,6 @@ class PinningHTTPSValidation(AsyncTestCase):
yield f"https://127.0.0.1:{listening_port.getHost().port}/"
finally:
await listening_port.stopListening()
# Make sure all server connections are closed :( No idea why this
# is necessary when it's not for IStorageServer HTTPS tests.
await deferLater(reactor, 0.01)
def request(self, url: str, expected_certificate: x509.Certificate):
"""
@ -144,10 +146,6 @@ class PinningHTTPSValidation(AsyncTestCase):
response = await self.request(url, certificate)
self.assertEqual(await response.content(), b"YOYODYNE")
# We keep getting TLSMemoryBIOProtocol being left around, so try harder
# to wait for it to finish.
await deferLater(reactor, 0.01)
@async_to_deferred
async def test_server_certificate_has_wrong_hash(self):
"""
@ -202,10 +200,6 @@ class PinningHTTPSValidation(AsyncTestCase):
response = await self.request(url, certificate)
self.assertEqual(await response.content(), b"YOYODYNE")
# We keep getting TLSMemoryBIOProtocol being left around, so try harder
# to wait for it to finish.
await deferLater(reactor, 0.001)
# A potential attack to test is a private key that doesn't match the
# certificate... but OpenSSL (quite rightly) won't let you listen with that
# so I don't know how to test that! See