Merge pull request #1185 from tahoe-lafs/3881-cbor-refactor

CBOR refactor for HTTP storage protocol

Fixes ticket:3881
This commit is contained in:
Itamar Turner-Trauring 2022-03-18 10:36:29 -04:00 committed by GitHub
commit c632aa1de1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 131 additions and 58 deletions

0
newsfragments/3881.minor Normal file
View File

View File

@ -2,27 +2,8 @@
HTTP client that talks to the HTTP storage server. HTTP client that talks to the HTTP storage server.
""" """
from __future__ import absolute_import from typing import Union, Set, Optional
from __future__ import division from treq.testing import StubTreq
from __future__ import print_function
from __future__ import unicode_literals
from future.utils import PY2
if PY2:
# fmt: off
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
# fmt: on
from collections import defaultdict
Optional = Set = defaultdict(
lambda: None
) # some garbage to just make this module import
else:
# typing module not available in Python 2, and we only do type checking in
# Python 3 anyway.
from typing import Union, Set, Optional
from treq.testing import StubTreq
from base64 import b64encode from base64 import b64encode
@ -38,7 +19,7 @@ from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred
from hyperlink import DecodedURL from hyperlink import DecodedURL
import treq import treq
from .http_common import swissnum_auth_header, Secrets from .http_common import swissnum_auth_header, Secrets, get_content_type, CBOR_MIME_TYPE
from .common import si_b2a from .common import si_b2a
@ -58,8 +39,15 @@ class ClientException(Exception):
def _decode_cbor(response): def _decode_cbor(response):
"""Given HTTP response, return decoded CBOR body.""" """Given HTTP response, return decoded CBOR body."""
if response.code > 199 and response.code < 300: if response.code > 199 and response.code < 300:
return treq.content(response).addCallback(loads) content_type = get_content_type(response.headers)
return fail(ClientException(response.code, response.phrase)) if content_type == CBOR_MIME_TYPE:
# TODO limit memory usage
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872
return treq.content(response).addCallback(loads)
else:
raise ClientException(-1, "Server didn't send CBOR")
else:
return fail(ClientException(response.code, response.phrase))
@attr.s @attr.s
@ -104,13 +92,19 @@ class StorageClient(object):
lease_cancel_secret=None, lease_cancel_secret=None,
upload_secret=None, upload_secret=None,
headers=None, headers=None,
message_to_serialize=None,
**kwargs **kwargs
): ):
""" """
Like ``treq.request()``, but with optional secrets that get translated Like ``treq.request()``, but with optional secrets that get translated
into corresponding HTTP headers. into corresponding HTTP headers.
If ``message_to_serialize`` is set, it will be serialized (by default
with CBOR) and set as the request body.
""" """
headers = self._get_headers(headers) headers = self._get_headers(headers)
# Add secrets:
for secret, value in [ for secret, value in [
(Secrets.LEASE_RENEW, lease_renew_secret), (Secrets.LEASE_RENEW, lease_renew_secret),
(Secrets.LEASE_CANCEL, lease_cancel_secret), (Secrets.LEASE_CANCEL, lease_cancel_secret),
@ -122,6 +116,21 @@ class StorageClient(object):
"X-Tahoe-Authorization", "X-Tahoe-Authorization",
b"%s %s" % (secret.value.encode("ascii"), b64encode(value).strip()), b"%s %s" % (secret.value.encode("ascii"), b64encode(value).strip()),
) )
# Note we can accept CBOR:
headers.addRawHeader("Accept", CBOR_MIME_TYPE)
# If there's a request message, serialize it and set the Content-Type
# header:
if message_to_serialize is not None:
if "data" in kwargs:
raise TypeError(
"Can't use both `message_to_serialize` and `data` "
"as keyword arguments at the same time"
)
kwargs["data"] = dumps(message_to_serialize)
headers.addRawHeader("Content-Type", CBOR_MIME_TYPE)
return self._treq.request(method, url, headers=headers, **kwargs) return self._treq.request(method, url, headers=headers, **kwargs)
@ -188,17 +197,15 @@ class StorageClientImmutables(object):
storage index failed the result will fire with an exception. storage index failed the result will fire with an exception.
""" """
url = self._client.relative_url("/v1/immutable/" + _encode_si(storage_index)) url = self._client.relative_url("/v1/immutable/" + _encode_si(storage_index))
message = dumps( message = {"share-numbers": share_numbers, "allocated-size": allocated_size}
{"share-numbers": share_numbers, "allocated-size": allocated_size}
)
response = yield self._client.request( response = yield self._client.request(
"POST", "POST",
url, url,
lease_renew_secret=lease_renew_secret, lease_renew_secret=lease_renew_secret,
lease_cancel_secret=lease_cancel_secret, lease_cancel_secret=lease_cancel_secret,
upload_secret=upload_secret, upload_secret=upload_secret,
data=message, message_to_serialize=message,
headers=Headers({"content-type": ["application/cbor"]}),
) )
decoded_response = yield _decode_cbor(response) decoded_response = yield _decode_cbor(response)
returnValue( returnValue(
@ -369,13 +376,8 @@ class StorageClientImmutables(object):
_encode_si(storage_index), share_number _encode_si(storage_index), share_number
) )
) )
message = dumps({"reason": reason}) message = {"reason": reason}
response = yield self._client.request( response = yield self._client.request("POST", url, message_to_serialize=message)
"POST",
url,
data=message,
headers=Headers({"content-type": ["application/cbor"]}),
)
if response.code == http.OK: if response.code == http.OK:
return return
else: else:

View File

@ -1,15 +1,26 @@
""" """
Common HTTP infrastructure for the storage server. Common HTTP infrastructure for the storage server.
""" """
from future.utils import PY2
if PY2:
# fmt: off
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
# fmt: on
from enum import Enum from enum import Enum
from base64 import b64encode from base64 import b64encode
from typing import Optional
from werkzeug.http import parse_options_header
from twisted.web.http_headers import Headers
CBOR_MIME_TYPE = "application/cbor"
def get_content_type(headers: Headers) -> Optional[str]:
    """
    Extract the media type from an HTTP ``Content-Type`` header.

    Only the first ``Content-Type`` value is consulted, and any options
    following the media type (e.g. ``;encoding=utf-8``) are discarded.

    :param headers: the headers to inspect.
    :return: the media type, or ``None`` if no content-type was set or the
        parsed media type is empty.
    """
    raw_values = headers.getRawHeaders("content-type")
    # A missing header is handed to the parser as ``None``.
    first_value = raw_values[0] if raw_values else None
    media_type = parse_options_header(first_value)[0]
    if not media_type:
        return None
    return media_type
def swissnum_auth_header(swissnum): # type: (bytes) -> bytes def swissnum_auth_header(swissnum): # type: (bytes) -> bytes

View File

@ -2,7 +2,7 @@
HTTP server for storage. HTTP server for storage.
""" """
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple, Any
from functools import wraps from functools import wraps
from base64 import b64decode from base64 import b64decode
@ -11,7 +11,11 @@ import binascii
from klein import Klein from klein import Klein
from twisted.web import http from twisted.web import http
import attr import attr
from werkzeug.http import parse_range_header, parse_content_range_header from werkzeug.http import (
parse_range_header,
parse_content_range_header,
parse_accept_header,
)
from werkzeug.routing import BaseConverter, ValidationError from werkzeug.routing import BaseConverter, ValidationError
from werkzeug.datastructures import ContentRange from werkzeug.datastructures import ContentRange
@ -19,7 +23,7 @@ from werkzeug.datastructures import ContentRange
from cbor2 import dumps, loads from cbor2 import dumps, loads
from .server import StorageServer from .server import StorageServer
from .http_common import swissnum_auth_header, Secrets from .http_common import swissnum_auth_header, Secrets, get_content_type, CBOR_MIME_TYPE
from .common import si_a2b from .common import si_a2b
from .immutable import BucketWriter, ConflictingWriteError from .immutable import BucketWriter, ConflictingWriteError
from ..util.hashutil import timing_safe_compare from ..util.hashutil import timing_safe_compare
@ -243,20 +247,45 @@ class HTTPServer(object):
"""Return twisted.web ``Resource`` for this object.""" """Return twisted.web ``Resource`` for this object."""
return self._app.resource() return self._app.resource()
def _cbor(self, request, data): def _send_encoded(self, request, data):
"""Return CBOR-encoded data.""" """
# TODO Might want to optionally send JSON someday, based on Accept Return encoded data suitable for writing as the HTTP body response, by
# headers, see https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861 default using CBOR.
request.setHeader("Content-Type", "application/cbor")
# TODO if data is big, maybe want to use a temporary file eventually... Also sets the appropriate ``Content-Type`` header on the response.
return dumps(data) """
accept_headers = request.requestHeaders.getRawHeaders("accept") or [
CBOR_MIME_TYPE
]
accept = parse_accept_header(accept_headers[0])
if accept.best == CBOR_MIME_TYPE:
request.setHeader("Content-Type", CBOR_MIME_TYPE)
# TODO if data is big, maybe want to use a temporary file eventually...
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872
return dumps(data)
else:
# TODO Might want to optionally send JSON someday:
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861
raise _HTTPError(http.NOT_ACCEPTABLE)
def _read_encoded(self, request) -> Any:
    """
    Decode the body of ``request`` and return the resulting object.

    Only CBOR request bodies are understood.

    :raises _HTTPError: with ``UNSUPPORTED_MEDIA_TYPE`` when the request's
        ``Content-Type`` is anything other than ``CBOR_MIME_TYPE``.
    """
    if get_content_type(request.requestHeaders) != CBOR_MIME_TYPE:
        raise _HTTPError(http.UNSUPPORTED_MEDIA_TYPE)

    # TODO limit memory usage, client could send arbitrarily large data...
    # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872
    encoded_body = request.content.read()
    return loads(encoded_body)
##### Generic APIs ##### ##### Generic APIs #####
@_authorized_route(_app, set(), "/v1/version", methods=["GET"]) @_authorized_route(_app, set(), "/v1/version", methods=["GET"])
def version(self, request, authorization): def version(self, request, authorization):
"""Return version information.""" """Return version information."""
return self._cbor(request, self._storage_server.get_version()) return self._send_encoded(request, self._storage_server.get_version())
##### Immutable APIs ##### ##### Immutable APIs #####
@ -269,7 +298,7 @@ class HTTPServer(object):
def allocate_buckets(self, request, authorization, storage_index): def allocate_buckets(self, request, authorization, storage_index):
"""Allocate buckets.""" """Allocate buckets."""
upload_secret = authorization[Secrets.UPLOAD] upload_secret = authorization[Secrets.UPLOAD]
info = loads(request.content.read()) info = self._read_encoded(request)
# We do NOT validate the upload secret for existing bucket uploads. # We do NOT validate the upload secret for existing bucket uploads.
# Another upload may be happening in parallel, with a different upload # Another upload may be happening in parallel, with a different upload
@ -291,7 +320,7 @@ class HTTPServer(object):
storage_index, share_number, upload_secret, bucket storage_index, share_number, upload_secret, bucket
) )
return self._cbor( return self._send_encoded(
request, request,
{ {
"already-have": set(already_got), "already-have": set(already_got),
@ -367,7 +396,7 @@ class HTTPServer(object):
required = [] required = []
for start, end, _ in bucket.required_ranges().ranges(): for start, end, _ in bucket.required_ranges().ranges():
required.append({"begin": start, "end": end}) required.append({"begin": start, "end": end})
return self._cbor(request, {"required": required}) return self._send_encoded(request, {"required": required})
@_authorized_route( @_authorized_route(
_app, _app,
@ -380,7 +409,7 @@ class HTTPServer(object):
List shares for the given storage index. List shares for the given storage index.
""" """
share_numbers = list(self._storage_server.get_buckets(storage_index).keys()) share_numbers = list(self._storage_server.get_buckets(storage_index).keys())
return self._cbor(request, share_numbers) return self._send_encoded(request, share_numbers)
@_authorized_route( @_authorized_route(
_app, _app,
@ -469,6 +498,6 @@ class HTTPServer(object):
except KeyError: except KeyError:
raise _HTTPError(http.NOT_FOUND) raise _HTTPError(http.NOT_FOUND)
info = loads(request.content.read()) info = self._read_encoded(request)
bucket.advise_corrupt_share(info["reason"].encode("utf-8")) bucket.advise_corrupt_share(info["reason"].encode("utf-8"))
return b"" return b""

View File

@ -49,9 +49,29 @@ from ..storage.http_client import (
StorageClientGeneral, StorageClientGeneral,
_encode_si, _encode_si,
) )
from ..storage.http_common import get_content_type
from ..storage.common import si_b2a from ..storage.common import si_b2a
class HTTPUtilities(SyncTestCase):
    """Tests for HTTP common utilities."""

    def test_get_content_type(self):
        """``get_content_type()`` extracts the content-type from the header."""
        # (raw ``Content-Type`` header values, expected result) pairs.
        cases = [
            (["text/html"], "text/html"),
            ([], None),
            (["text/plain", "application/json"], "text/plain"),
            (["text/html;encoding=utf-8"], "text/html"),
        ]
        for raw_values, expected_content_type in cases:
            headers = Headers()
            # An empty list means "no Content-Type header at all".
            if raw_values:
                headers.setRawHeaders("Content-Type", raw_values)
            self.assertEqual(get_content_type(headers), expected_content_type)
def _post_process(params): def _post_process(params):
secret_types, secrets = params secret_types, secrets = params
secrets = {t: s for (t, s) in zip(secret_types, secrets)} secrets = {t: s for (t, s) in zip(secret_types, secrets)}
@ -358,6 +378,17 @@ class GenericHTTPAPITests(SyncTestCase):
with assert_fails_with_http_code(self, http.UNAUTHORIZED): with assert_fails_with_http_code(self, http.UNAUTHORIZED):
result_of(client.get_version()) result_of(client.get_version())
def test_unsupported_mime_type(self):
    """
    A client may ask for a mime type other than CBOR; when the requested
    type is unsupported, a NOT ACCEPTABLE (406) error is returned.
    """
    # Wrap the real client so requests go out with ``accept: image/gif``.
    gif_only_client = StorageClientWithHeadersOverride(
        self.http.client, {"accept": "image/gif"}
    )
    client = StorageClientGeneral(gif_only_client)
    with assert_fails_with_http_code(self, http.NOT_ACCEPTABLE):
        result_of(client.get_version())
def test_version(self): def test_version(self):
""" """
The client can return the version. The client can return the version.