From fef332754b28672659cd42ce1b5a47d677b3ef41 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 14 Mar 2022 11:09:40 -0400 Subject: [PATCH] Switch to shared utility so server can use it too. --- src/allmydata/storage/http_client.py | 7 ++----- src/allmydata/storage/http_common.py | 21 +++++++++++++++------ src/allmydata/test/test_storage_http.py | 20 ++++++++++++++++++++ 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/src/allmydata/storage/http_client.py b/src/allmydata/storage/http_client.py index 2790f1f7e..ace97508c 100644 --- a/src/allmydata/storage/http_client.py +++ b/src/allmydata/storage/http_client.py @@ -13,14 +13,13 @@ import attr from cbor2 import loads, dumps from collections_extended import RangeMap from werkzeug.datastructures import Range, ContentRange -from werkzeug.http import parse_options_header from twisted.web.http_headers import Headers from twisted.web import http from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred from hyperlink import DecodedURL import treq -from .http_common import swissnum_auth_header, Secrets +from .http_common import swissnum_auth_header, Secrets, get_content_type from .common import si_b2a @@ -40,9 +39,7 @@ class ClientException(Exception): def _decode_cbor(response): """Given HTTP response, return decoded CBOR body.""" if response.code > 199 and response.code < 300: - content_type = parse_options_header( - (response.headers.getRawHeaders("content-type") or [None])[0] - )[0] + content_type = get_content_type(response.headers) if content_type == "application/cbor": # TODO limit memory usage # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872 diff --git a/src/allmydata/storage/http_common.py b/src/allmydata/storage/http_common.py index af4224bd0..fdf637180 100644 --- a/src/allmydata/storage/http_common.py +++ b/src/allmydata/storage/http_common.py @@ -1,15 +1,24 @@ """ Common HTTP infrastructure for the storge server. """ -from future.utils import PY2 - -if PY2: - # fmt: off - from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401 - # fmt: on from enum import Enum from base64 import b64encode +from typing import Optional + +from werkzeug.http import parse_options_header +from twisted.web.http_headers import Headers + + +def get_content_type(headers: Headers) -> Optional[str]: + """ + Get the content type from the HTTP ``Content-Type`` header. + + Returns ``None`` if no content-type was set. + """ + values = headers.getRawHeaders("content-type") or [None] + content_type = parse_options_header(values[0])[0] or None + return content_type def swissnum_auth_header(swissnum): # type: (bytes) -> bytes diff --git a/src/allmydata/test/test_storage_http.py b/src/allmydata/test/test_storage_http.py index c20012a9b..af90d58a9 100644 --- a/src/allmydata/test/test_storage_http.py +++ b/src/allmydata/test/test_storage_http.py @@ -49,9 +49,29 @@ from ..storage.http_client import ( StorageClientGeneral, _encode_si, ) +from ..storage.http_common import get_content_type from ..storage.common import si_b2a +class HTTPUtilities(SyncTestCase): + """Tests for HTTP common utilities.""" + + def test_get_content_type(self): + """``get_content_type()`` extracts the content-type from the header.""" + + def assert_header_values_result(values, expected_content_type): + headers = Headers() + if values: + headers.setRawHeaders("Content-Type", values) + content_type = get_content_type(headers) + self.assertEqual(content_type, expected_content_type) + + assert_header_values_result(["text/html"], "text/html") + assert_header_values_result([], None) + assert_header_values_result(["text/plain", "application/json"], "text/plain") + assert_header_values_result(["text/html;encoding=utf-8"], "text/html") + + def _post_process(params): secret_types, secrets = params secrets = {t: s for (t, s) in zip(secret_types, secrets)}