diff --git a/src/allmydata/storage/http_server.py b/src/allmydata/storage/http_server.py index 6297ef484..47722180d 100644 --- a/src/allmydata/storage/http_server.py +++ b/src/allmydata/storage/http_server.py @@ -13,8 +13,12 @@ if PY2: # fmt: off from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401 # fmt: on +else: + from typing import Dict, List, Set from functools import wraps +from enum import Enum +from base64 import b64decode from klein import Klein from twisted.web import http @@ -26,6 +30,40 @@ from .server import StorageServer from .http_client import swissnum_auth_header +class Secrets(Enum): + """Different kinds of secrets the client may send.""" + LEASE_RENEW = "lease-renew-secret" + LEASE_CANCEL = "lease-cancel-secret" + UPLOAD = "upload-secret" + + +class ClientSecretsException(Exception): + """The client did not send the appropriate secrets.""" + + +def _extract_secrets(header_values, required_secrets): # type: (List[str], Set[Secrets]) -> Dict[Secrets, bytes] + """ + Given list of values of ``X-Tahoe-Authorization`` headers, and required + secrets, return dictionary mapping secrets to decoded values. + + If too few secrets were given, or too many, a ``ClientSecretsException`` is + raised. + """ + key_to_enum = {e.value: e for e in Secrets} + result = {} + try: + for header_value in header_values: + key, value = header_value.strip().split(" ", 1) + result[key_to_enum[key]] = b64decode(value) + except (ValueError, KeyError) as e: + raise ClientSecretsException("Bad header value(s): {}".format(header_values)) + if result.keys() != required_secrets: + raise ClientSecretsException( + "Expected {} secrets, got {}".format(required_secrets, result.keys()) + ) + return result + + def _authorization_decorator(f): """ Check the ``Authorization`` header, and (TODO: in later revision of code) diff --git a/src/allmydata/test/test_storage_http.py b/src/allmydata/test/test_storage_http.py index e413a0624..2a84d477f 100644 --- a/src/allmydata/test/test_storage_http.py +++ b/src/allmydata/test/test_storage_http.py @@ -15,6 +15,7 @@ if PY2: # fmt: on from unittest import SkipTest +from base64 import b64encode from twisted.trial.unittest import TestCase from twisted.internet.defer import inlineCallbacks @@ -23,10 +24,80 @@ from treq.testing import StubTreq from hyperlink import DecodedURL from ..storage.server import StorageServer -from ..storage.http_server import HTTPServer +from ..storage.http_server import HTTPServer, _extract_secrets, Secrets, ClientSecretsException from ..storage.http_client import StorageClient, ClientException +class ExtractSecretsTests(TestCase): + """ + Tests for ``_extract_secrets``. + """ + def test_extract_secrets(self): + """ + ``_extract_secrets()`` returns a dictionary with the extracted secrets + if the input secrets match the required secrets. + """ + secret1 = b"\xFF\x11ZEBRa" + secret2 = b"\x34\xF2lalalalalala" + lease_secret = "lease-renew-secret " + str(b64encode(secret1), "ascii").strip() + upload_secret = "upload-secret " + str(b64encode(secret2), "ascii").strip() + + # No secrets needed, none given: + self.assertEqual(_extract_secrets([], set()), {}) + + # One secret: + self.assertEqual( + _extract_secrets([lease_secret], + {Secrets.LEASE_RENEW}), + {Secrets.LEASE_RENEW: secret1} + ) + + # Two secrets: + self.assertEqual( + _extract_secrets([upload_secret, lease_secret], + {Secrets.LEASE_RENEW, Secrets.UPLOAD}), + {Secrets.LEASE_RENEW: secret1, Secrets.UPLOAD: secret2} + ) + + def test_wrong_number_of_secrets(self): + """ + If the wrong number of secrets are passed to ``_extract_secrets``, a + ``ClientSecretsException`` is raised. + """ + secret1 = b"\xFF\x11ZEBRa" + lease_secret = "lease-renew-secret " + str(b64encode(secret1), "ascii").strip() + + # Missing secret: + with self.assertRaises(ClientSecretsException): + _extract_secrets([], {Secrets.LEASE_RENEW}) + + # Wrong secret: + with self.assertRaises(ClientSecretsException): + _extract_secrets([lease_secret], {Secrets.UPLOAD}) + + # Extra secret: + with self.assertRaises(ClientSecretsException): + _extract_secrets([lease_secret], {}) + + def test_bad_secrets(self): + """ + Bad inputs to ``_extract_secrets`` result in + ``ClientSecretsException``. + """ + + # Missing value. + with self.assertRaises(ClientSecretsException): + _extract_secrets(["lease-renew-secret"], {Secrets.LEASE_RENEW}) + + # Garbage prefix + with self.assertRaises(ClientSecretsException): + _extract_secrets(["FOO eA=="], {}) + + # Not base64. + with self.assertRaises(ClientSecretsException): + _extract_secrets(["lease-renew-secret x"], {Secrets.LEASE_RENEW}) + + class HTTPTests(TestCase): """ Tests of HTTP client talking to the HTTP server.