Secret header parsing.

This commit is contained in:
Itamar Turner-Trauring 2021-12-16 10:39:58 -05:00
parent bc8889f32f
commit b32374c8bc
2 changed files with 110 additions and 1 deletions

View File

@ -13,8 +13,12 @@ if PY2:
# fmt: off
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
# fmt: on
else:
from typing import Dict, List, Set
from functools import wraps
from enum import Enum
from base64 import b64decode
from klein import Klein
from twisted.web import http
@ -26,6 +30,40 @@ from .server import StorageServer
from .http_client import swissnum_auth_header
class Secrets(Enum):
"""Different kinds of secrets the client may send."""
LEASE_RENEW = "lease-renew-secret"
LEASE_CANCEL = "lease-cancel-secret"
UPLOAD = "upload-secret"
class ClientSecretsException(Exception):
"""The client did not send the appropriate secrets."""
def _extract_secrets(header_values, required_secrets): # type: (List[str], Set[Secrets]) -> Dict[Secrets, bytes]
"""
Given list of values of ``X-Tahoe-Authorization`` headers, and required
secrets, return dictionary mapping secrets to decoded values.
If too few secrets were given, or too many, a ``ClientSecretsException`` is
raised.
"""
key_to_enum = {e.value: e for e in Secrets}
result = {}
try:
for header_value in header_values:
key, value = header_value.strip().split(" ", 1)
result[key_to_enum[key]] = b64decode(value)
except (ValueError, KeyError) as e:
raise ClientSecretsException("Bad header value(s): {}".format(header_values))
if result.keys() != required_secrets:
raise ClientSecretsException(
"Expected {} secrets, got {}".format(required_secrets, result.keys())
)
return result
def _authorization_decorator(f):
"""
Check the ``Authorization`` header, and (TODO: in later revision of code)

View File

@ -15,6 +15,7 @@ if PY2:
# fmt: on
from unittest import SkipTest
from base64 import b64encode
from twisted.trial.unittest import TestCase
from twisted.internet.defer import inlineCallbacks
@ -23,10 +24,80 @@ from treq.testing import StubTreq
from hyperlink import DecodedURL
from ..storage.server import StorageServer
from ..storage.http_server import HTTPServer
from ..storage.http_server import HTTPServer, _extract_secrets, Secrets, ClientSecretsException
from ..storage.http_client import StorageClient, ClientException
class ExtractSecretsTests(TestCase):
"""
Tests for ``_extract_secrets``.
"""
def test_extract_secrets(self):
"""
``_extract_secrets()`` returns a dictionary with the extracted secrets
if the input secrets match the required secrets.
"""
secret1 = b"\xFF\x11ZEBRa"
secret2 = b"\x34\xF2lalalalalala"
lease_secret = "lease-renew-secret " + str(b64encode(secret1), "ascii").strip()
upload_secret = "upload-secret " + str(b64encode(secret2), "ascii").strip()
# No secrets needed, none given:
self.assertEqual(_extract_secrets([], set()), {})
# One secret:
self.assertEqual(
_extract_secrets([lease_secret],
{Secrets.LEASE_RENEW}),
{Secrets.LEASE_RENEW: secret1}
)
# Two secrets:
self.assertEqual(
_extract_secrets([upload_secret, lease_secret],
{Secrets.LEASE_RENEW, Secrets.UPLOAD}),
{Secrets.LEASE_RENEW: secret1, Secrets.UPLOAD: secret2}
)
def test_wrong_number_of_secrets(self):
"""
If the wrong number of secrets are passed to ``_extract_secrets``, a
``ClientSecretsException`` is raised.
"""
secret1 = b"\xFF\x11ZEBRa"
lease_secret = "lease-renew-secret " + str(b64encode(secret1), "ascii").strip()
# Missing secret:
with self.assertRaises(ClientSecretsException):
_extract_secrets([], {Secrets.LEASE_RENEW})
# Wrong secret:
with self.assertRaises(ClientSecretsException):
_extract_secrets([lease_secret], {Secrets.UPLOAD})
# Extra secret:
with self.assertRaises(ClientSecretsException):
_extract_secrets([lease_secret], {})
def test_bad_secrets(self):
"""
Bad inputs to ``_extract_secrets`` result in
``ClientSecretsException``.
"""
# Missing value.
with self.assertRaises(ClientSecretsException):
_extract_secrets(["lease-renew-secret"], {Secrets.LEASE_RENEW})
# Garbage prefix
with self.assertRaises(ClientSecretsException):
_extract_secrets(["FOO eA=="], {})
# Not base64.
with self.assertRaises(ClientSecretsException):
_extract_secrets(["lease-renew-secret x"], {Secrets.LEASE_RENEW})
class HTTPTests(TestCase):
"""
Tests of HTTP client talking to the HTTP server.