Ensure and test (refactoring as necessary) that a missing content-type is treated the same as the CBOR content-type, as per spec.
This commit is contained in:
Itamar Turner-Trauring 2023-07-25 15:31:30 -04:00
parent bf2451bbcd
commit 46d10a6281
2 changed files with 108 additions and 57 deletions

View File

@ -530,6 +530,60 @@ def _add_error_handling(app: Klein):
return str(failure.value).encode("utf-8") return str(failure.value).encode("utf-8")
async def read_encoded(
    reactor, request, schema: Schema, max_size: int = 1024 * 1024
) -> Any:
    """
    Read encoded request body data, decoding it with CBOR by default.

    Somewhat arbitrarily, limit body size to 1MiB by default.

    :param reactor: Handed to ``defer_to_thread`` when validation of a
        larger body is moved off the calling thread.
    :param request: The request whose ``requestHeaders`` and seekable
        ``content`` are read.
    :param schema: Schema whose ``validate_cbor`` is run against the raw
        body before it is decoded.
    :param max_size: Largest accepted body size, in bytes.

    :raises _HTTPError: With ``UNSUPPORTED_MEDIA_TYPE`` if a content type
        other than CBOR was given, or ``REQUEST_ENTITY_TOO_LARGE`` if the
        body is bigger than ``max_size``.

    :return: Whatever ``cbor2.load`` decodes from the body.
    """
    content_type = get_content_type(request.requestHeaders)
    if content_type is None:
        # No Content-Type header at all is treated the same as CBOR.
        content_type = CBOR_MIME_TYPE
    if content_type != CBOR_MIME_TYPE:
        raise _HTTPError(http.UNSUPPORTED_MEDIA_TYPE)

    # Make sure it's not too large:
    body = request.content
    body.seek(0, SEEK_END)
    size = body.tell()
    if size > max_size:
        raise _HTTPError(http.REQUEST_ENTITY_TOO_LARGE)
    body.seek(0, SEEK_SET)

    # We don't want to load the whole message into memory, because it might
    # be quite large. The CDDL validator takes a read-only bytes-like
    # thing. Luckily, for large request bodies twisted.web will buffer the
    # data in a file, so we can use mmap() to get a memory view. The CDDL
    # validator will not make a copy, so it won't increase memory usage
    # beyond that.
    try:
        fd = body.fileno()
    except (ValueError, OSError):
        fd = -1
    if fd >= 0 and size > 0:
        # It's a non-empty file, so we can use mmap() to save memory. An
        # empty file cannot be mapped (mmap() raises ValueError), so that
        # case falls through to a plain read() of zero bytes instead.
        message = mmap.mmap(fd, 0, access=mmap.ACCESS_READ)
    else:
        message = body.read()

    # Pycddl will release the GIL when validating larger documents, so
    # let's take advantage of multiple CPUs:
    if size > 10_000:
        await defer_to_thread(reactor, schema.validate_cbor, message)
    else:
        schema.validate_cbor(message)

    # The CBOR parser will allocate more memory, but at least we can feed
    # it the file-like object, so that if it's large it won't make two
    # copies. Rewind first: seek() takes (offset, whence), in that order.
    body.seek(0, SEEK_SET)

    # Typically deserialization to Python will not release the GIL, and
    # indeed as of Jan 2023 cbor2 didn't have any code to release the GIL
    # in the decode path. As such, running it in a different thread has no
    # benefit.
    return cbor2.load(body)
class HTTPServer(object): class HTTPServer(object):
""" """
A HTTP interface to the storage server. A HTTP interface to the storage server.
@ -587,56 +641,6 @@ class HTTPServer(object):
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861 # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861
raise _HTTPError(http.NOT_ACCEPTABLE) raise _HTTPError(http.NOT_ACCEPTABLE)
async def _read_encoded(
    self, request, schema: Schema, max_size: int = 1024 * 1024
) -> Any:
    """
    Read encoded request body data, decoding it with CBOR by default.

    Somewhat arbitrarily, limit body size to 1MiB by default.

    This is a thin wrapper kept for existing callers: all the work is done
    by the module-level ``read_encoded``, called with this server's
    reactor. Sharing that one implementation means a request without a
    ``Content-Type`` header is handled as CBOR here too, rather than being
    rejected.

    :param request: The request whose body should be read.
    :param schema: Schema the raw body is validated against.
    :param max_size: Largest accepted body size, in bytes.

    :return: The decoded body, as returned by ``read_encoded``.
    """
    return await read_encoded(self._reactor, request, schema, max_size=max_size)
##### Generic APIs ##### ##### Generic APIs #####
@ -677,8 +681,8 @@ class HTTPServer(object):
"""Allocate buckets.""" """Allocate buckets."""
upload_secret = authorization[Secrets.UPLOAD] upload_secret = authorization[Secrets.UPLOAD]
# It's just a list of up to ~256 shares, shouldn't use many bytes. # It's just a list of up to ~256 shares, shouldn't use many bytes.
info = await self._read_encoded( info = await read_encoded(
request, _SCHEMAS["allocate_buckets"], max_size=8192 self._reactor, request, _SCHEMAS["allocate_buckets"], max_size=8192
) )
# We do NOT validate the upload secret for existing bucket uploads. # We do NOT validate the upload secret for existing bucket uploads.
@ -849,7 +853,8 @@ class HTTPServer(object):
# The reason can be a string with explanation, so in theory it could be # The reason can be a string with explanation, so in theory it could be
# longish? # longish?
info = await self._read_encoded( info = await read_encoded(
self._reactor,
request, request,
_SCHEMAS["advise_corrupt_share"], _SCHEMAS["advise_corrupt_share"],
max_size=32768, max_size=32768,
@ -868,8 +873,8 @@ class HTTPServer(object):
@async_to_deferred @async_to_deferred
async def mutable_read_test_write(self, request, authorization, storage_index): async def mutable_read_test_write(self, request, authorization, storage_index):
"""Read/test/write combined operation for mutables.""" """Read/test/write combined operation for mutables."""
rtw_request = await self._read_encoded( rtw_request = await read_encoded(
request, _SCHEMAS["mutable_read_test_write"], max_size=2**48 self._reactor, request, _SCHEMAS["mutable_read_test_write"], max_size=2**48
) )
secrets = ( secrets = (
authorization[Secrets.WRITE_ENABLER], authorization[Secrets.WRITE_ENABLER],
@ -955,8 +960,8 @@ class HTTPServer(object):
# The reason can be a string with explanation, so in theory it could be # The reason can be a string with explanation, so in theory it could be
# longish? # longish?
info = await self._read_encoded( info = await read_encoded(
request, _SCHEMAS["advise_corrupt_share"], max_size=32768 self._reactor, request, _SCHEMAS["advise_corrupt_share"], max_size=32768
) )
self._storage_server.advise_corrupt_share( self._storage_server.advise_corrupt_share(
b"mutable", storage_index, share_number, info["reason"].encode("utf-8") b"mutable", storage_index, share_number, info["reason"].encode("utf-8")

View File

@ -42,6 +42,7 @@ from werkzeug.exceptions import NotFound as WNotFound
from testtools.matchers import Equals from testtools.matchers import Equals
from zope.interface import implementer from zope.interface import implementer
from ..util.deferredutil import async_to_deferred
from .common import SyncTestCase from .common import SyncTestCase
from ..storage.http_common import ( from ..storage.http_common import (
get_content_type, get_content_type,
@ -59,6 +60,8 @@ from ..storage.http_server import (
_authorized_route, _authorized_route,
StorageIndexConverter, StorageIndexConverter,
_add_error_handling, _add_error_handling,
read_encoded,
_SCHEMAS as SERVER_SCHEMAS,
) )
from ..storage.http_client import ( from ..storage.http_client import (
StorageClient, StorageClient,
@ -303,6 +306,14 @@ class TestApp(object):
request.transport.loseConnection() request.transport.loseConnection()
return Deferred() return Deferred()
@_authorized_route(_app, set(), "/read_body", methods=["POST"])
@async_to_deferred
async def read_body(self, request, authorization):
    """
    Decode the request body with ``read_encoded`` and return the value
    stored under its ``"reason"`` key.
    """
    schema = SERVER_SCHEMAS["advise_corrupt_share"]
    decoded = await read_encoded(self.clock, request, schema)
    return decoded["reason"]
def result_of(d): def result_of(d):
""" """
@ -320,6 +331,7 @@ def result_of(d):
+ "This is probably a test design issue." + "This is probably a test design issue."
) )
class CustomHTTPServerTests(SyncTestCase): class CustomHTTPServerTests(SyncTestCase):
""" """
Tests that use a custom HTTP server. Tests that use a custom HTTP server.
@ -504,6 +516,40 @@ class CustomHTTPServerTests(SyncTestCase):
result_of(d) result_of(d)
self.assertEqual(len(self._http_server.clock.getDelayedCalls()), 0) self.assertEqual(len(self._http_server.clock.getDelayedCalls()), 0)
def test_request_with_no_content_type_same_as_cbor(self):
    """
    A body sent without any ``Content-Type`` header is handled as if it
    were CBOR.
    """
    url = DecodedURL.from_text("http://127.0.0.1/read_body")
    payload = dumps({"reason": "test"})
    # Deliberately pass no headers, so no Content-Type is set by this test.
    response = result_of(self.client.request("POST", url, data=payload))
    body = result_of(limited_content(response, self._http_server.clock, 100))
    self.assertEqual(body.read(), b"test")
def test_request_with_wrong_content(self):
    """
    If a non-CBOR ``Content-Type`` header is set when sending a body, the
    server complains appropriately.
    """
    url = DecodedURL.from_text("http://127.0.0.1/read_body")
    payload = dumps({"reason": "test"})
    bad_headers = Headers()
    bad_headers.setRawHeaders("content-type", ["some/value"])
    request_d = self.client.request(
        "POST", url, data=payload, headers=bad_headers
    )
    response = result_of(request_d)
    self.assertEqual(response.code, http.UNSUPPORTED_MEDIA_TYPE)
@implementer(IReactorFromThreads) @implementer(IReactorFromThreads)
class Reactor(Clock): class Reactor(Clock):