Add type annotations to _authorization_decorator.

This commit is contained in:
Itamar Turner-Trauring 2023-07-28 11:28:13 -04:00
parent 849f4ed2a5
commit 2b7f3d1707
2 changed files with 36 additions and 11 deletions

View File

@ -4,7 +4,8 @@ HTTP server for storage.
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable, Union, cast, Optional from typing import Any, Callable, Union, cast, Optional, TypeVar, Sequence
from typing_extensions import ParamSpec, Concatenate
from functools import wraps from functools import wraps
from base64 import b64decode from base64 import b64decode
import binascii import binascii
@ -27,6 +28,7 @@ from twisted.internet.defer import Deferred
from twisted.internet.ssl import CertificateOptions, Certificate, PrivateCertificate from twisted.internet.ssl import CertificateOptions, Certificate, PrivateCertificate
from twisted.internet.interfaces import IReactorFromThreads from twisted.internet.interfaces import IReactorFromThreads
from twisted.web.server import Site, Request from twisted.web.server import Site, Request
from twisted.web.iweb import IRequest
from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.python.filepath import FilePath from twisted.python.filepath import FilePath
@ -68,7 +70,7 @@ class ClientSecretsException(Exception):
def _extract_secrets( def _extract_secrets(
header_values: list[str], required_secrets: set[Secrets] header_values: Sequence[str], required_secrets: set[Secrets]
) -> dict[Secrets, bytes]: ) -> dict[Secrets, bytes]:
""" """
Given list of values of ``X-Tahoe-Authorization`` headers, and required Given list of values of ``X-Tahoe-Authorization`` headers, and required
@ -102,18 +104,32 @@ def _extract_secrets(
return result return result
def _authorization_decorator(required_secrets): P = ParamSpec("P")
T = TypeVar("T")
def _authorization_decorator(
required_secrets: set[Secrets],
) -> Callable[
[Callable[Concatenate[BaseApp, Request, dict[Secrets, bytes], P], T]],
Callable[Concatenate[BaseApp, Request, P], T],
]:
""" """
1. Check the ``Authorization`` header matches server swissnum. 1. Check the ``Authorization`` header matches server swissnum.
2. Extract ``X-Tahoe-Authorization`` headers and pass them in. 2. Extract ``X-Tahoe-Authorization`` headers and pass them in.
3. Log the request and response. 3. Log the request and response.
""" """
def decorator(f): def decorator(
f: Callable[Concatenate[BaseApp, Request, dict[Secrets, bytes], P], T]
) -> Callable[Concatenate[BaseApp, Request, P], T]:
@wraps(f) @wraps(f)
def route(self, request, *args, **kwargs): def route(
# Don't set text/html content type by default: self: BaseApp, request: Request, *args: P.args, **kwargs: P.kwargs
request.defaultContentType = None ) -> T:
# Don't set text/html content type by default.
# None is actually supported, see https://github.com/twisted/twisted/issues/11902
request.defaultContentType = None # type: ignore[assignment]
with start_action( with start_action(
action_type="allmydata:storage:http-server:handle-request", action_type="allmydata:storage:http-server:handle-request",
@ -584,7 +600,13 @@ async def read_encoded(
return cbor2.load(request.content) return cbor2.load(request.content)
class HTTPServer(object): class BaseApp:
"""Base class for ``HTTPServer`` and testing equivalent."""
_swissnum: bytes
class HTTPServer(BaseApp):
""" """
A HTTP interface to the storage server. A HTTP interface to the storage server.
""" """
@ -641,7 +663,6 @@ class HTTPServer(object):
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861 # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861
raise _HTTPError(http.NOT_ACCEPTABLE) raise _HTTPError(http.NOT_ACCEPTABLE)
##### Generic APIs ##### ##### Generic APIs #####
@_authorized_route(_app, set(), "/storage/v1/version", methods=["GET"]) @_authorized_route(_app, set(), "/storage/v1/version", methods=["GET"])
@ -874,7 +895,10 @@ class HTTPServer(object):
async def mutable_read_test_write(self, request, authorization, storage_index): async def mutable_read_test_write(self, request, authorization, storage_index):
"""Read/test/write combined operation for mutables.""" """Read/test/write combined operation for mutables."""
rtw_request = await read_encoded( rtw_request = await read_encoded(
self._reactor, request, _SCHEMAS["mutable_read_test_write"], max_size=2**48 self._reactor,
request,
_SCHEMAS["mutable_read_test_write"],
max_size=2**48,
) )
secrets = ( secrets = (
authorization[Secrets.WRITE_ENABLER], authorization[Secrets.WRITE_ENABLER],

View File

@ -62,6 +62,7 @@ from ..storage.http_server import (
_add_error_handling, _add_error_handling,
read_encoded, read_encoded,
_SCHEMAS as SERVER_SCHEMAS, _SCHEMAS as SERVER_SCHEMAS,
BaseApp,
) )
from ..storage.http_client import ( from ..storage.http_client import (
StorageClient, StorageClient,
@ -257,7 +258,7 @@ def gen_bytes(length: int) -> bytes:
return result return result
class TestApp(object): class TestApp(BaseApp):
"""HTTP API for testing purposes.""" """HTTP API for testing purposes."""
clock: IReactorTime clock: IReactorTime