Fix some more mypy errors

This commit is contained in:
Jean-Paul Calderone 2023-01-06 18:11:47 -05:00
parent 01b14fe05c
commit a806b2faba
2 changed files with 26 additions and 12 deletions

View File

@ -14,6 +14,7 @@ on any of their methods.
from __future__ import annotations
from typing_extensions import TypeAlias
from typing import Callable
from functools import partial
@ -70,22 +71,29 @@ def create_signing_keypair_from_string(private_key_der: bytes) -> tuple[PrivateK
:returns: 2-tuple of (private_key, public_key)
"""
load = partial(
load_der_private_key,
_load = partial(
load_der_public_key,
private_key_der,
password=None,
backend=default_backend(),
)
def load_with_validation() -> PrivateKey:
return _load()
def load_without_validation() -> PrivateKey:
return _load(unsafe_skip_rsa_key_validation=True)
# Load it once without the potentially expensive OpenSSL validation
# checks. These have superlinear complexity. We *will* run them just
# below - but first we'll apply our own constant-time checks.
load: Callable[[], PrivateKey] = load_without_validation
try:
# Load it once without the potentially expensive OpenSSL validation
# checks. These have superlinear complexity. We *will* run them just
# below - but first we'll apply our own constant-time checks.
unsafe_priv_key = load(unsafe_skip_rsa_key_validation=True)
unsafe_priv_key = load()
except TypeError:
# cryptography<39 does not support this parameter, so just load the
# key with validation...
unsafe_priv_key = load()
unsafe_priv_key = load_without_validation()
# But avoid *reloading* it since that will run the expensive
# validation *again*.
load = lambda: unsafe_priv_key

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from six import ensure_str
try:
from typing import Optional, Union, Tuple, Any, TypeVar
from typing import Optional, Union, Tuple, Any, TypeVar, Literal, overload
except ImportError:
pass
@ -708,7 +708,13 @@ def url_for_string(req, url_string):
T = TypeVar("T")
def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bool = False) -> Union[bytes, tuple[bytes, ...], T]:
@overload
def get_arg(req: IRequest, argname: str | bytes, default: T = None, *, multiple: Literal[False] = False) -> bytes: ...
@overload
def get_arg(req: IRequest, argname: str | bytes, default: T = None, *, multiple: Literal[True]) -> tuple[bytes, ...]: ...
def get_arg(req: IRequest, argname: str | bytes, default: T | None = None, *, multiple: bool = False) -> None | T | bytes | tuple[bytes, ...]:
"""Extract an argument from either the query args (req.args) or the form
body fields (req.fields). If multiple=False, this returns a single value
(or the default, which defaults to None), and the query args take
@ -724,9 +730,6 @@ def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bo
else:
argname_bytes = argname
if isinstance(default, str):
default = default.encode("utf-8")
results = []
if argname_bytes in req.args:
results.extend(req.args[argname_bytes])
@ -740,6 +743,9 @@ def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bo
return tuple(results)
if results:
return results[0]
if isinstance(default, str):
return default.encode("utf-8")
return default