mirror of
https://github.com/tahoe-lafs/tahoe-lafs.git
synced 2024-12-22 22:32:23 +00:00
Fix some more mypy errors
This commit is contained in:
parent
01b14fe05c
commit
a806b2faba
@ -14,6 +14,7 @@ on any of their methods.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
from typing import Callable
|
||||
|
||||
from functools import partial
|
||||
|
||||
@ -70,22 +71,29 @@ def create_signing_keypair_from_string(private_key_der: bytes) -> tuple[PrivateK
|
||||
|
||||
:returns: 2-tuple of (private_key, public_key)
|
||||
"""
|
||||
load = partial(
|
||||
load_der_private_key,
|
||||
_load = partial(
|
||||
load_der_public_key,
|
||||
private_key_der,
|
||||
password=None,
|
||||
backend=default_backend(),
|
||||
)
|
||||
|
||||
try:
|
||||
def load_with_validation() -> PrivateKey:
|
||||
return _load()
|
||||
|
||||
def load_without_validation() -> PrivateKey:
|
||||
return _load(unsafe_skip_rsa_key_validation=True)
|
||||
|
||||
# Load it once without the potentially expensive OpenSSL validation
|
||||
# checks. These have superlinear complexity. We *will* run them just
|
||||
# below - but first we'll apply our own constant-time checks.
|
||||
unsafe_priv_key = load(unsafe_skip_rsa_key_validation=True)
|
||||
load: Callable[[], PrivateKey] = load_without_validation
|
||||
try:
|
||||
unsafe_priv_key = load()
|
||||
except TypeError:
|
||||
# cryptography<39 does not support this parameter, so just load the
|
||||
# key with validation...
|
||||
unsafe_priv_key = load()
|
||||
unsafe_priv_key = load_without_validation()
|
||||
# But avoid *reloading* it since that will run the expensive
|
||||
# validation *again*.
|
||||
load = lambda: unsafe_priv_key
|
||||
|
@ -6,7 +6,7 @@ from __future__ import annotations
|
||||
from six import ensure_str
|
||||
|
||||
try:
|
||||
from typing import Optional, Union, Tuple, Any, TypeVar
|
||||
from typing import Optional, Union, Tuple, Any, TypeVar, Literal, overload
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -708,7 +708,13 @@ def url_for_string(req, url_string):
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bool = False) -> Union[bytes, tuple[bytes, ...], T]:
|
||||
@overload
|
||||
def get_arg(req: IRequest, argname: str | bytes, default: T = None, *, multiple: Literal[False] = False) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def get_arg(req: IRequest, argname: str | bytes, default: T = None, *, multiple: Literal[True]) -> tuple[bytes, ...]: ...
|
||||
|
||||
def get_arg(req: IRequest, argname: str | bytes, default: T | None = None, *, multiple: bool = False) -> None | T | bytes | tuple[bytes, ...]:
|
||||
"""Extract an argument from either the query args (req.args) or the form
|
||||
body fields (req.fields). If multiple=False, this returns a single value
|
||||
(or the default, which defaults to None), and the query args take
|
||||
@ -724,9 +730,6 @@ def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bo
|
||||
else:
|
||||
argname_bytes = argname
|
||||
|
||||
if isinstance(default, str):
|
||||
default = default.encode("utf-8")
|
||||
|
||||
results = []
|
||||
if argname_bytes in req.args:
|
||||
results.extend(req.args[argname_bytes])
|
||||
@ -740,6 +743,9 @@ def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bo
|
||||
return tuple(results)
|
||||
if results:
|
||||
return results[0]
|
||||
|
||||
if isinstance(default, str):
|
||||
return default.encode("utf-8")
|
||||
return default
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user