diff --git a/src/allmydata/crypto/rsa.py b/src/allmydata/crypto/rsa.py index 3554ec557..f1ba0d10a 100644 --- a/src/allmydata/crypto/rsa.py +++ b/src/allmydata/crypto/rsa.py @@ -14,6 +14,7 @@ on any of their methods. from __future__ import annotations from typing_extensions import TypeAlias +from typing import Callable from functools import partial @@ -70,22 +71,29 @@ def create_signing_keypair_from_string(private_key_der: bytes) -> tuple[PrivateK :returns: 2-tuple of (private_key, public_key) """ - load = partial( - load_der_private_key, + _load = partial( + load_der_public_key, private_key_der, password=None, backend=default_backend(), ) + def load_with_validation() -> PrivateKey: + return _load() + + def load_without_validation() -> PrivateKey: + return _load(unsafe_skip_rsa_key_validation=True) + + # Load it once without the potentially expensive OpenSSL validation + # checks. These have superlinear complexity. We *will* run them just + # below - but first we'll apply our own constant-time checks. + load: Callable[[], PrivateKey] = load_without_validation try: - # Load it once without the potentially expensive OpenSSL validation - # checks. These have superlinear complexity. We *will* run them just - # below - but first we'll apply our own constant-time checks. - unsafe_priv_key = load(unsafe_skip_rsa_key_validation=True) + unsafe_priv_key = load() except TypeError: # cryptography<39 does not support this parameter, so just load the # key with validation... - unsafe_priv_key = load() + unsafe_priv_key = load_without_validation() # But avoid *reloading* it since that will run the expensive # validation *again*. load = lambda: unsafe_priv_key diff --git a/src/allmydata/web/common.py b/src/allmydata/web/common.py index 25e9e51f3..8f81aec94 100644 --- a/src/allmydata/web/common.py +++ b/src/allmydata/web/common.py @@ -6,7 +6,7 @@ from __future__ import annotations from six import ensure_str try: - from typing import Optional, Union, Tuple, Any, TypeVar + from typing import Optional, Union, Tuple, Any, TypeVar, Literal, overload except ImportError: pass @@ -708,7 +708,13 @@ def url_for_string(req, url_string): T = TypeVar("T") -def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bool = False) -> Union[bytes, tuple[bytes, ...], T]: +@overload +def get_arg(req: IRequest, argname: str | bytes, default: T = None, *, multiple: Literal[False] = False) -> bytes: ... + +@overload +def get_arg(req: IRequest, argname: str | bytes, default: T = None, *, multiple: Literal[True]) -> tuple[bytes, ...]: ... + +def get_arg(req: IRequest, argname: str | bytes, default: T | None = None, *, multiple: bool = False) -> None | T | bytes | tuple[bytes, ...]: """Extract an argument from either the query args (req.args) or the form body fields (req.fields). If multiple=False, this returns a single value (or the default, which defaults to None), and the query args take @@ -724,9 +730,6 @@ def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bo else: argname_bytes = argname - if isinstance(default, str): - default = default.encode("utf-8") - results = [] if argname_bytes in req.args: results.extend(req.args[argname_bytes]) @@ -740,6 +743,9 @@ def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bo return tuple(results) if results: return results[0] + + if isinstance(default, str): + return default.encode("utf-8") return default