Merge pull request #1323 from tahoe-lafs/4052.more-type-checking-gbs

More type check for http storage server

Fixes ticket:4052
This commit is contained in:
Itamar Turner-Trauring 2023-08-02 14:34:08 -04:00 committed by GitHub
commit e67ef7ad16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 192 additions and 65 deletions

0
newsfragments/4052.minor Normal file
View File

View File

@ -41,6 +41,7 @@ from twisted.internet.interfaces import (
IDelayedCall,
)
from twisted.internet.ssl import CertificateOptions
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web.client import Agent, HTTPConnectionPool
from zope.interface import implementer
from hyperlink import DecodedURL
@ -72,7 +73,7 @@ except ImportError:
pass
def _encode_si(si): # type: (bytes) -> str
def _encode_si(si: bytes) -> str:
"""Encode the storage index into Unicode string."""
return str(si_b2a(si), "ascii")
@ -80,9 +81,13 @@ def _encode_si(si): # type: (bytes) -> str
class ClientException(Exception):
"""An unexpected response code from the server."""
def __init__(self, code, *additional_args):
Exception.__init__(self, code, *additional_args)
def __init__(
self, code: int, message: Optional[str] = None, body: Optional[bytes] = None
):
Exception.__init__(self, code, message, body)
self.code = code
self.message = message
self.body = body
register_exception_extractor(ClientException, lambda e: {"response_code": e.code})
@ -93,7 +98,7 @@ register_exception_extractor(ClientException, lambda e: {"response_code": e.code
# Tags are of the form #6.nnn, where the number is documented at
# https://www.iana.org/assignments/cbor-tags/cbor-tags.xhtml. Notably, #6.258
# indicates a set.
_SCHEMAS = {
_SCHEMAS: Mapping[str, Schema] = {
"get_version": Schema(
# Note that the single-quoted (`'`) string keys in this schema
# represent *byte* strings - per the CDDL specification. Text strings
@ -155,7 +160,7 @@ class _LengthLimitedCollector:
timeout_on_silence: IDelayedCall
f: BytesIO = field(factory=BytesIO)
def __call__(self, data: bytes):
def __call__(self, data: bytes) -> None:
self.timeout_on_silence.reset(60)
self.remaining_length -= len(data)
if self.remaining_length < 0:
@ -164,7 +169,7 @@ class _LengthLimitedCollector:
def limited_content(
response,
response: IResponse,
clock: IReactorTime,
max_length: int = 30 * 1024 * 1024,
) -> Deferred[BinaryIO]:
@ -300,11 +305,13 @@ class _StorageClientHTTPSPolicy:
expected_spki_hash: bytes
# IPolicyForHTTPS
def creatorForNetloc(self, hostname, port):
def creatorForNetloc(self, hostname: str, port: int) -> _StorageClientHTTPSPolicy:
return self
# IOpenSSLClientConnectionCreator
def clientConnectionForTLS(self, tlsProtocol):
def clientConnectionForTLS(
self, tlsProtocol: TLSMemoryBIOProtocol
) -> SSL.Connection:
return SSL.Connection(
_TLSContextFactory(self.expected_spki_hash).getContext(), None
)
@ -344,7 +351,7 @@ class StorageClientFactory:
cls.TEST_MODE_REGISTER_HTTP_POOL = callback
@classmethod
def stop_test_mode(cls):
def stop_test_mode(cls) -> None:
"""Stop testing mode."""
cls.TEST_MODE_REGISTER_HTTP_POOL = None
@ -437,7 +444,7 @@ class StorageClient(object):
"""Get a URL relative to the base URL."""
return self._base_url.click(path)
def _get_headers(self, headers): # type: (Optional[Headers]) -> Headers
def _get_headers(self, headers: Optional[Headers]) -> Headers:
"""Return the basic headers to be used by default."""
if headers is None:
headers = Headers()
@ -565,7 +572,7 @@ class StorageClient(object):
).read()
raise ClientException(response.code, response.phrase, data)
def shutdown(self) -> Deferred:
def shutdown(self) -> Deferred[object]:
"""Shutdown any connections."""
return self._pool.closeCachedConnections()

View File

@ -4,7 +4,18 @@ HTTP server for storage.
from __future__ import annotations
from typing import Any, Callable, Union, cast, Optional
from typing import (
Any,
Callable,
Union,
cast,
Optional,
TypeVar,
Sequence,
Protocol,
Dict,
)
from typing_extensions import ParamSpec, Concatenate
from functools import wraps
from base64 import b64decode
import binascii
@ -15,20 +26,24 @@ import mmap
from eliot import start_action
from cryptography.x509 import Certificate as CryptoCertificate
from zope.interface import implementer
from klein import Klein
from klein import Klein, KleinRenderable
from klein.resource import KleinResource
from twisted.web import http
from twisted.internet.interfaces import (
IListeningPort,
IStreamServerEndpoint,
IPullProducer,
IProtocolFactory,
)
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import Deferred
from twisted.internet.ssl import CertificateOptions, Certificate, PrivateCertificate
from twisted.internet.interfaces import IReactorFromThreads
from twisted.web.server import Site, Request
from twisted.web.iweb import IRequest
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.python.filepath import FilePath
from twisted.python.failure import Failure
from attrs import define, field, Factory
from werkzeug.http import (
@ -68,7 +83,7 @@ class ClientSecretsException(Exception):
def _extract_secrets(
header_values: list[str], required_secrets: set[Secrets]
header_values: Sequence[str], required_secrets: set[Secrets]
) -> dict[Secrets, bytes]:
"""
Given list of values of ``X-Tahoe-Authorization`` headers, and required
@ -102,18 +117,43 @@ def _extract_secrets(
return result
def _authorization_decorator(required_secrets):
class BaseApp(Protocol):
"""Protocol for ``HTTPServer`` and testing equivalent."""
_swissnum: bytes
P = ParamSpec("P")
T = TypeVar("T")
SecretsDict = Dict[Secrets, bytes]
App = TypeVar("App", bound=BaseApp)
def _authorization_decorator(
required_secrets: set[Secrets],
) -> Callable[
[Callable[Concatenate[App, Request, SecretsDict, P], T]],
Callable[Concatenate[App, Request, P], T],
]:
"""
1. Check the ``Authorization`` header matches server swissnum.
2. Extract ``X-Tahoe-Authorization`` headers and pass them in.
3. Log the request and response.
"""
def decorator(f):
def decorator(
f: Callable[Concatenate[App, Request, SecretsDict, P], T]
) -> Callable[Concatenate[App, Request, P], T]:
@wraps(f)
def route(self, request, *args, **kwargs):
# Don't set text/html content type by default:
request.defaultContentType = None
def route(
self: App,
request: Request,
*args: P.args,
**kwargs: P.kwargs,
) -> T:
# Don't set text/html content type by default.
# None is actually supported, see https://github.com/twisted/twisted/issues/11902
request.defaultContentType = None # type: ignore[assignment]
with start_action(
action_type="allmydata:storage:http-server:handle-request",
@ -163,7 +203,22 @@ def _authorization_decorator(required_secrets):
return decorator
def _authorized_route(app, required_secrets, *route_args, **route_kwargs):
def _authorized_route(
klein_app: Klein,
required_secrets: set[Secrets],
url: str,
*route_args: Any,
branch: bool = False,
**route_kwargs: Any,
) -> Callable[
[
Callable[
Concatenate[App, Request, SecretsDict, P],
KleinRenderable,
]
],
Callable[..., KleinRenderable],
]:
"""
Like Klein's @route, but with additional support for checking the
``Authorization`` header as well as ``X-Tahoe-Authorization`` headers. The
@ -173,12 +228,23 @@ def _authorized_route(app, required_secrets, *route_args, **route_kwargs):
:param required_secrets: Set of required ``Secret`` types.
"""
def decorator(f):
@app.route(*route_args, **route_kwargs)
def decorator(
f: Callable[
Concatenate[App, Request, SecretsDict, P],
KleinRenderable,
]
) -> Callable[..., KleinRenderable]:
@klein_app.route(url, *route_args, branch=branch, **route_kwargs) # type: ignore[arg-type]
@_authorization_decorator(required_secrets)
@wraps(f)
def handle_route(*args, **kwargs):
return f(*args, **kwargs)
def handle_route(
app: App,
request: Request,
secrets: SecretsDict,
*args: P.args,
**kwargs: P.kwargs,
) -> KleinRenderable:
return f(app, request, secrets, *args, **kwargs)
return handle_route
@ -234,7 +300,7 @@ class UploadsInProgress(object):
except (KeyError, IndexError):
raise _HTTPError(http.NOT_FOUND)
def remove_write_bucket(self, bucket: BucketWriter):
def remove_write_bucket(self, bucket: BucketWriter) -> None:
"""Stop tracking the given ``BucketWriter``."""
try:
storage_index, share_number = self._bucketwriters.pop(bucket)
@ -250,7 +316,7 @@ class UploadsInProgress(object):
def validate_upload_secret(
self, storage_index: bytes, share_number: int, upload_secret: bytes
):
) -> None:
"""
Raise an unauthorized-HTTP-response exception if the given
storage_index+share_number have a different upload secret than the
@ -272,7 +338,7 @@ class StorageIndexConverter(BaseConverter):
regex = "[" + str(rfc3548_alphabet, "ascii") + "]{26}"
def to_python(self, value):
def to_python(self, value: str) -> bytes:
try:
return si_a2b(value.encode("ascii"))
except (AssertionError, binascii.Error, ValueError):
@ -351,7 +417,7 @@ class _ReadAllProducer:
start: int = field(default=0)
@classmethod
def produce_to(cls, request: Request, read_data: ReadData) -> Deferred:
def produce_to(cls, request: Request, read_data: ReadData) -> Deferred[bytes]:
"""
Create and register the producer, returning ``Deferred`` that should be
returned from a HTTP server endpoint.
@ -360,7 +426,7 @@ class _ReadAllProducer:
request.registerProducer(producer, False)
return producer.result
def resumeProducing(self):
def resumeProducing(self) -> None:
data = self.read_data(self.start, 65536)
if not data:
self.request.unregisterProducer()
@ -371,10 +437,10 @@ class _ReadAllProducer:
self.request.write(data)
self.start += len(data)
def pauseProducing(self):
def pauseProducing(self) -> None:
pass
def stopProducing(self):
def stopProducing(self) -> None:
pass
@ -392,7 +458,7 @@ class _ReadRangeProducer:
start: int
remaining: int
def resumeProducing(self):
def resumeProducing(self) -> None:
if self.result is None or self.request is None:
return
@ -429,10 +495,10 @@ class _ReadRangeProducer:
if self.remaining == 0:
self.stopProducing()
def pauseProducing(self):
def pauseProducing(self) -> None:
pass
def stopProducing(self):
def stopProducing(self) -> None:
if self.request is not None:
self.request.unregisterProducer()
self.request = None
@ -511,12 +577,13 @@ def read_range(
return d
def _add_error_handling(app: Klein):
def _add_error_handling(app: Klein) -> None:
"""Add exception handlers to a Klein app."""
@app.handle_errors(_HTTPError)
def _http_error(_, request, failure):
def _http_error(self: Any, request: IRequest, failure: Failure) -> KleinRenderable:
"""Handle ``_HTTPError`` exceptions."""
assert isinstance(failure.value, _HTTPError)
request.setResponseCode(failure.value.code)
if failure.value.body is not None:
return failure.value.body
@ -524,7 +591,9 @@ def _add_error_handling(app: Klein):
return b""
@app.handle_errors(CDDLValidationError)
def _cddl_validation_error(_, request, failure):
def _cddl_validation_error(
self: Any, request: IRequest, failure: Failure
) -> KleinRenderable:
"""Handle CDDL validation errors."""
request.setResponseCode(http.BAD_REQUEST)
return str(failure.value).encode("utf-8")
@ -584,7 +653,7 @@ async def read_encoded(
return cbor2.load(request.content)
class HTTPServer(object):
class HTTPServer(BaseApp):
"""
A HTTP interface to the storage server.
"""
@ -611,11 +680,11 @@ class HTTPServer(object):
self._uploads.remove_write_bucket
)
def get_resource(self):
def get_resource(self) -> KleinResource:
"""Return twisted.web ``Resource`` for this object."""
return self._app.resource()
def _send_encoded(self, request, data):
def _send_encoded(self, request: Request, data: object) -> Deferred[bytes]:
"""
Return encoded data suitable for writing as the HTTP body response, by
default using CBOR.
@ -641,11 +710,10 @@ class HTTPServer(object):
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861
raise _HTTPError(http.NOT_ACCEPTABLE)
##### Generic APIs #####
@_authorized_route(_app, set(), "/storage/v1/version", methods=["GET"])
def version(self, request, authorization):
def version(self, request: Request, authorization: SecretsDict) -> KleinRenderable:
"""Return version information."""
return self._send_encoded(request, self._get_version())
@ -677,7 +745,9 @@ class HTTPServer(object):
methods=["POST"],
)
@async_to_deferred
async def allocate_buckets(self, request, authorization, storage_index):
async def allocate_buckets(
self, request: Request, authorization: SecretsDict, storage_index: bytes
) -> KleinRenderable:
"""Allocate buckets."""
upload_secret = authorization[Secrets.UPLOAD]
# It's just a list of up to ~256 shares, shouldn't use many bytes.
@ -716,7 +786,13 @@ class HTTPServer(object):
"/storage/v1/immutable/<storage_index:storage_index>/<int(signed=False):share_number>/abort",
methods=["PUT"],
)
def abort_share_upload(self, request, authorization, storage_index, share_number):
def abort_share_upload(
self,
request: Request,
authorization: SecretsDict,
storage_index: bytes,
share_number: int,
) -> KleinRenderable:
"""Abort an in-progress immutable share upload."""
try:
bucket = self._uploads.get_write_bucket(
@ -747,7 +823,13 @@ class HTTPServer(object):
"/storage/v1/immutable/<storage_index:storage_index>/<int(signed=False):share_number>",
methods=["PATCH"],
)
def write_share_data(self, request, authorization, storage_index, share_number):
def write_share_data(
self,
request: Request,
authorization: SecretsDict,
storage_index: bytes,
share_number: int,
) -> KleinRenderable:
"""Write data to an in-progress immutable upload."""
content_range = parse_content_range_header(request.getHeader("content-range"))
if content_range is None or content_range.units != "bytes":
@ -757,14 +839,17 @@ class HTTPServer(object):
bucket = self._uploads.get_write_bucket(
storage_index, share_number, authorization[Secrets.UPLOAD]
)
offset = content_range.start
remaining = content_range.stop - content_range.start
offset = content_range.start or 0
# We don't support an unspecified stop for the range:
assert content_range.stop is not None
# Missing body makes no sense:
assert request.content is not None
remaining = content_range.stop - offset
finished = False
while remaining > 0:
data = request.content.read(min(remaining, 65536))
assert data, "uploaded data length doesn't match range"
try:
finished = bucket.write(offset, data)
except ConflictingWriteError:
@ -790,7 +875,9 @@ class HTTPServer(object):
"/storage/v1/immutable/<storage_index:storage_index>/shares",
methods=["GET"],
)
def list_shares(self, request, authorization, storage_index):
def list_shares(
self, request: Request, authorization: SecretsDict, storage_index: bytes
) -> KleinRenderable:
"""
List shares for the given storage index.
"""
@ -803,7 +890,13 @@ class HTTPServer(object):
"/storage/v1/immutable/<storage_index:storage_index>/<int(signed=False):share_number>",
methods=["GET"],
)
def read_share_chunk(self, request, authorization, storage_index, share_number):
def read_share_chunk(
self,
request: Request,
authorization: SecretsDict,
storage_index: bytes,
share_number: int,
) -> KleinRenderable:
"""Read a chunk for an already uploaded immutable."""
request.setHeader("content-type", "application/octet-stream")
try:
@ -820,7 +913,9 @@ class HTTPServer(object):
"/storage/v1/lease/<storage_index:storage_index>",
methods=["PUT"],
)
def add_or_renew_lease(self, request, authorization, storage_index):
def add_or_renew_lease(
self, request: Request, authorization: SecretsDict, storage_index: bytes
) -> KleinRenderable:
"""Update the lease for an immutable or mutable share."""
if not list(self._storage_server.get_shares(storage_index)):
raise _HTTPError(http.NOT_FOUND)
@ -843,8 +938,12 @@ class HTTPServer(object):
)
@async_to_deferred
async def advise_corrupt_share_immutable(
self, request, authorization, storage_index, share_number
):
self,
request: Request,
authorization: SecretsDict,
storage_index: bytes,
share_number: int,
) -> KleinRenderable:
"""Indicate that given share is corrupt, with a text reason."""
try:
bucket = self._storage_server.get_buckets(storage_index)[share_number]
@ -871,10 +970,15 @@ class HTTPServer(object):
methods=["POST"],
)
@async_to_deferred
async def mutable_read_test_write(self, request, authorization, storage_index):
async def mutable_read_test_write(
self, request: Request, authorization: SecretsDict, storage_index: bytes
) -> KleinRenderable:
"""Read/test/write combined operation for mutables."""
rtw_request = await read_encoded(
self._reactor, request, _SCHEMAS["mutable_read_test_write"], max_size=2**48
self._reactor,
request,
_SCHEMAS["mutable_read_test_write"],
max_size=2**48,
)
secrets = (
authorization[Secrets.WRITE_ENABLER],
@ -910,7 +1014,13 @@ class HTTPServer(object):
"/storage/v1/mutable/<storage_index:storage_index>/<int(signed=False):share_number>",
methods=["GET"],
)
def read_mutable_chunk(self, request, authorization, storage_index, share_number):
def read_mutable_chunk(
self,
request: Request,
authorization: SecretsDict,
storage_index: bytes,
share_number: int,
) -> KleinRenderable:
"""Read a chunk from a mutable."""
request.setHeader("content-type", "application/octet-stream")
@ -950,8 +1060,12 @@ class HTTPServer(object):
)
@async_to_deferred
async def advise_corrupt_share_mutable(
self, request, authorization, storage_index, share_number
):
self,
request: Request,
authorization: SecretsDict,
storage_index: bytes,
share_number: int,
) -> KleinRenderable:
"""Indicate that given share is corrupt, with a text reason."""
if share_number not in {
shnum for (shnum, _) in self._storage_server.get_shares(storage_index)
@ -983,7 +1097,10 @@ class _TLSEndpointWrapper(object):
@classmethod
def from_paths(
cls, endpoint, private_key_path: FilePath, cert_path: FilePath
cls: type[_TLSEndpointWrapper],
endpoint: IStreamServerEndpoint,
private_key_path: FilePath,
cert_path: FilePath,
) -> "_TLSEndpointWrapper":
"""
Create an endpoint with the given private key and certificate paths on
@ -998,7 +1115,7 @@ class _TLSEndpointWrapper(object):
)
return cls(endpoint=endpoint, context_factory=certificate_options)
def listen(self, factory):
def listen(self, factory: IProtocolFactory) -> Deferred[IListeningPort]:
return self.endpoint.listen(
TLSMemoryBIOFactory(self.context_factory, False, factory)
)

View File

@ -1428,7 +1428,7 @@ class _FakeRemoteReference(object):
result = yield getattr(self.local_object, action)(*args, **kwargs)
defer.returnValue(result)
except HTTPClientException as e:
raise RemoteException(e.args)
raise RemoteException((e.code, e.message, e.body))
@attr.s

View File

@ -62,6 +62,7 @@ from ..storage.http_server import (
_add_error_handling,
read_encoded,
_SCHEMAS as SERVER_SCHEMAS,
BaseApp,
)
from ..storage.http_client import (
StorageClient,
@ -257,7 +258,7 @@ def gen_bytes(length: int) -> bytes:
return result
class TestApp(object):
class TestApp(BaseApp):
"""HTTP API for testing purposes."""
clock: IReactorTime
@ -265,7 +266,7 @@ class TestApp(object):
_add_error_handling(_app)
_swissnum = SWISSNUM_FOR_TEST # Match what the test client is using
@_authorized_route(_app, {}, "/noop", methods=["GET"])
@_authorized_route(_app, set(), "/noop", methods=["GET"])
def noop(self, request, authorization):
return "noop"

View File

@ -109,9 +109,11 @@ class PinningHTTPSValidation(AsyncTestCase):
root.isLeaf = True
listening_port = await endpoint.listen(Site(root))
try:
yield f"https://127.0.0.1:{listening_port.getHost().port}/"
yield f"https://127.0.0.1:{listening_port.getHost().port}/" # type: ignore[attr-defined]
finally:
await listening_port.stopListening()
result = listening_port.stopListening()
if result is not None:
await result
def request(self, url: str, expected_certificate: x509.Certificate):
"""