fix some errors reported by mypy

This commit is contained in:
Jean-Paul Calderone 2023-01-06 17:12:59 -05:00
parent e829b891b3
commit 2dc6466ef5
6 changed files with 20 additions and 14 deletions

View File

@ -13,7 +13,7 @@ on any of their methods.
from __future__ import annotations
from typing import TypeVar
from typing_extensions import TypeAlias
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
@ -24,8 +24,8 @@ from cryptography.hazmat.primitives.serialization import load_der_private_key, l
from allmydata.crypto.error import BadSignature
PublicKey = TypeVar("PublicKey", bound=rsa.RSAPublicKey)
PrivateKey = TypeVar("PrivateKey", bound=rsa.RSAPrivateKey)
PublicKey: TypeAlias = rsa.RSAPublicKey
PrivateKey: TypeAlias = rsa.RSAPrivateKey
# This is the value that was used by `pycryptopp`, and we must continue to use it for
# both backwards compatibility and interoperability.

View File

@ -66,7 +66,7 @@ class UnknownVersionError(BadShareError):
"""The share we received was of a version we don't recognize."""
def encrypt_privkey(writekey: bytes, privkey: rsa.PrivateKey) -> bytes:
def encrypt_privkey(writekey: bytes, privkey: bytes) -> bytes:
"""
For SSK, encrypt a private ("signature") key using the writekey.
"""

View File

@ -3,7 +3,7 @@ Tests for the ``tahoe put`` CLI tool.
"""
from __future__ import annotations
from typing import Callable, Awaitable, TypeVar
from typing import Callable, Awaitable, TypeVar, Any
import os.path
from twisted.trial import unittest
from twisted.python import usage
@ -242,7 +242,7 @@ class Put(GridTestMixin, CLITestMixin, unittest.TestCase):
async def _test_mutable_specified_key(
self,
run: Callable[[Callable[..., T], FilePath, FilePath], Awaitable[T]],
run: Callable[[Any, FilePath, FilePath], Awaitable[tuple[int, bytes, bytes]]],
) -> None:
"""
A helper for testing mutable creation.

View File

@ -635,7 +635,7 @@ class FakeMutableFileNode(object): # type: ignore # incomplete implementation
keypair: tuple[rsa.PublicKey, rsa.PrivateKey] | None
):
self.all_contents = all_contents
self.file_types = {} # storage index => MDMF_VERSION or SDMF_VERSION
self.file_types: dict[bytes, int] = {} # storage index => MDMF_VERSION or SDMF_VERSION
self.init_from_cap(make_mutable_file_cap(keypair))
self._k = default_encoding_parameters['k']
self._segsize = default_encoding_parameters['max_segment_size']

View File

@ -90,6 +90,7 @@ class FakeNodeMaker(NodeMaker):
'happy': 7,
'max_segment_size':128*1024 # 1024=KiB
}
all_contents: dict[bytes, object]
def _create_lit(self, cap):
return FakeCHKFileNode(cap, self.all_contents)
def _create_immutable(self, cap):

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from six import ensure_str
try:
from typing import Optional, Union, Tuple, Any
from typing import Optional, Union, Tuple, Any, TypeVar
except ImportError:
pass
@ -706,8 +706,9 @@ def url_for_string(req, url_string):
)
return url
T = TypeVar("T")
def get_arg(req, argname, default=None, multiple=False): # type: (IRequest, Union[bytes,str], Any, bool) -> Union[bytes,Tuple[bytes],Any]
def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bool = False) -> Union[bytes, tuple[bytes, ...], T]:
"""Extract an argument from either the query args (req.args) or the form
body fields (req.fields). If multiple=False, this returns a single value
(or the default, which defaults to None), and the query args take
@ -719,13 +720,17 @@ def get_arg(req, argname, default=None, multiple=False): # type: (IRequest, Uni
:return: Either bytes or tuple of bytes.
"""
if isinstance(argname, str):
argname = argname.encode("utf-8")
argname_bytes = argname.encode("utf-8")
else:
argname_bytes = argname
if isinstance(default, str):
default = default.encode("utf-8")
results = []
if argname in req.args:
results.extend(req.args[argname])
argname_unicode = str(argname, "utf-8")
if argname_bytes in req.args:
results.extend(req.args[argname_bytes])
argname_unicode = str(argname_bytes, "utf-8")
if req.fields and argname_unicode in req.fields:
value = req.fields[argname_unicode].value
if isinstance(value, str):
@ -832,7 +837,7 @@ def get_keypair(request: IRequest) -> tuple[PublicKey, PrivateKey] | None:
Load a keypair from a urlsafe-base64-encoded RSA private key in the
**private-key** argument of the given request, if there is one.
"""
privkey_der = get_arg(request, "private-key", None)
privkey_der = get_arg(request, "private-key", default=None, multiple=False)
if privkey_der is None:
return None
privkey, pubkey = create_signing_keypair_from_string(urlsafe_b64decode(privkey_der))