mirror of
https://github.com/tahoe-lafs/tahoe-lafs.git
synced 2025-01-30 08:04:08 +00:00
Secret header parsing.
This commit is contained in:
parent
bc8889f32f
commit
b32374c8bc
@ -13,8 +13,12 @@ if PY2:
|
||||
# fmt: off
|
||||
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
|
||||
# fmt: on
|
||||
else:
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from functools import wraps
|
||||
from enum import Enum
|
||||
from base64 import b64decode
|
||||
|
||||
from klein import Klein
|
||||
from twisted.web import http
|
||||
@ -26,6 +30,40 @@ from .server import StorageServer
|
||||
from .http_client import swissnum_auth_header
|
||||
|
||||
|
||||
class Secrets(Enum):
|
||||
"""Different kinds of secrets the client may send."""
|
||||
LEASE_RENEW = "lease-renew-secret"
|
||||
LEASE_CANCEL = "lease-cancel-secret"
|
||||
UPLOAD = "upload-secret"
|
||||
|
||||
|
||||
class ClientSecretsException(Exception):
|
||||
"""The client did not send the appropriate secrets."""
|
||||
|
||||
|
||||
def _extract_secrets(header_values, required_secrets): # type: (List[str], Set[Secrets]) -> Dict[Secrets, bytes]
|
||||
"""
|
||||
Given list of values of ``X-Tahoe-Authorization`` headers, and required
|
||||
secrets, return dictionary mapping secrets to decoded values.
|
||||
|
||||
If too few secrets were given, or too many, a ``ClientSecretsException`` is
|
||||
raised.
|
||||
"""
|
||||
key_to_enum = {e.value: e for e in Secrets}
|
||||
result = {}
|
||||
try:
|
||||
for header_value in header_values:
|
||||
key, value = header_value.strip().split(" ", 1)
|
||||
result[key_to_enum[key]] = b64decode(value)
|
||||
except (ValueError, KeyError) as e:
|
||||
raise ClientSecretsException("Bad header value(s): {}".format(header_values))
|
||||
if result.keys() != required_secrets:
|
||||
raise ClientSecretsException(
|
||||
"Expected {} secrets, got {}".format(required_secrets, result.keys())
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _authorization_decorator(f):
|
||||
"""
|
||||
Check the ``Authorization`` header, and (TODO: in later revision of code)
|
||||
|
@ -15,6 +15,7 @@ if PY2:
|
||||
# fmt: on
|
||||
|
||||
from unittest import SkipTest
|
||||
from base64 import b64encode
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.internet.defer import inlineCallbacks
|
||||
@ -23,10 +24,80 @@ from treq.testing import StubTreq
|
||||
from hyperlink import DecodedURL
|
||||
|
||||
from ..storage.server import StorageServer
|
||||
from ..storage.http_server import HTTPServer
|
||||
from ..storage.http_server import HTTPServer, _extract_secrets, Secrets, ClientSecretsException
|
||||
from ..storage.http_client import StorageClient, ClientException
|
||||
|
||||
|
||||
class ExtractSecretsTests(TestCase):
|
||||
"""
|
||||
Tests for ``_extract_secrets``.
|
||||
"""
|
||||
def test_extract_secrets(self):
|
||||
"""
|
||||
``_extract_secrets()`` returns a dictionary with the extracted secrets
|
||||
if the input secrets match the required secrets.
|
||||
"""
|
||||
secret1 = b"\xFF\x11ZEBRa"
|
||||
secret2 = b"\x34\xF2lalalalalala"
|
||||
lease_secret = "lease-renew-secret " + str(b64encode(secret1), "ascii").strip()
|
||||
upload_secret = "upload-secret " + str(b64encode(secret2), "ascii").strip()
|
||||
|
||||
# No secrets needed, none given:
|
||||
self.assertEqual(_extract_secrets([], set()), {})
|
||||
|
||||
# One secret:
|
||||
self.assertEqual(
|
||||
_extract_secrets([lease_secret],
|
||||
{Secrets.LEASE_RENEW}),
|
||||
{Secrets.LEASE_RENEW: secret1}
|
||||
)
|
||||
|
||||
# Two secrets:
|
||||
self.assertEqual(
|
||||
_extract_secrets([upload_secret, lease_secret],
|
||||
{Secrets.LEASE_RENEW, Secrets.UPLOAD}),
|
||||
{Secrets.LEASE_RENEW: secret1, Secrets.UPLOAD: secret2}
|
||||
)
|
||||
|
||||
def test_wrong_number_of_secrets(self):
|
||||
"""
|
||||
If the wrong number of secrets are passed to ``_extract_secrets``, a
|
||||
``ClientSecretsException`` is raised.
|
||||
"""
|
||||
secret1 = b"\xFF\x11ZEBRa"
|
||||
lease_secret = "lease-renew-secret " + str(b64encode(secret1), "ascii").strip()
|
||||
|
||||
# Missing secret:
|
||||
with self.assertRaises(ClientSecretsException):
|
||||
_extract_secrets([], {Secrets.LEASE_RENEW})
|
||||
|
||||
# Wrong secret:
|
||||
with self.assertRaises(ClientSecretsException):
|
||||
_extract_secrets([lease_secret], {Secrets.UPLOAD})
|
||||
|
||||
# Extra secret:
|
||||
with self.assertRaises(ClientSecretsException):
|
||||
_extract_secrets([lease_secret], {})
|
||||
|
||||
def test_bad_secrets(self):
|
||||
"""
|
||||
Bad inputs to ``_extract_secrets`` result in
|
||||
``ClientSecretsException``.
|
||||
"""
|
||||
|
||||
# Missing value.
|
||||
with self.assertRaises(ClientSecretsException):
|
||||
_extract_secrets(["lease-renew-secret"], {Secrets.LEASE_RENEW})
|
||||
|
||||
# Garbage prefix
|
||||
with self.assertRaises(ClientSecretsException):
|
||||
_extract_secrets(["FOO eA=="], {})
|
||||
|
||||
# Not base64.
|
||||
with self.assertRaises(ClientSecretsException):
|
||||
_extract_secrets(["lease-renew-secret x"], {Secrets.LEASE_RENEW})
|
||||
|
||||
|
||||
class HTTPTests(TestCase):
|
||||
"""
|
||||
Tests of HTTP client talking to the HTTP server.
|
||||
|
Loading…
x
Reference in New Issue
Block a user