fix some errors reported by mypy

This commit is contained in:
Jean-Paul Calderone 2023-01-06 17:12:59 -05:00
parent e829b891b3
commit 2dc6466ef5
6 changed files with 20 additions and 14 deletions

View File

@ -13,7 +13,7 @@ on any of their methods.
from __future__ import annotations from __future__ import annotations
from typing import TypeVar from typing_extensions import TypeAlias
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
@ -24,8 +24,8 @@ from cryptography.hazmat.primitives.serialization import load_der_private_key, l
from allmydata.crypto.error import BadSignature from allmydata.crypto.error import BadSignature
PublicKey = TypeVar("PublicKey", bound=rsa.RSAPublicKey) PublicKey: TypeAlias = rsa.RSAPublicKey
PrivateKey = TypeVar("PrivateKey", bound=rsa.RSAPrivateKey) PrivateKey: TypeAlias = rsa.RSAPrivateKey
# This is the value that was used by `pycryptopp`, and we must continue to use it for # This is the value that was used by `pycryptopp`, and we must continue to use it for
# both backwards compatibility and interoperability. # both backwards compatibility and interoperability.

View File

@ -66,7 +66,7 @@ class UnknownVersionError(BadShareError):
"""The share we received was of a version we don't recognize.""" """The share we received was of a version we don't recognize."""
def encrypt_privkey(writekey: bytes, privkey: rsa.PrivateKey) -> bytes: def encrypt_privkey(writekey: bytes, privkey: bytes) -> bytes:
""" """
For SSK, encrypt a private ("signature") key using the writekey. For SSK, encrypt a private ("signature") key using the writekey.
""" """

View File

@ -3,7 +3,7 @@ Tests for the ``tahoe put`` CLI tool.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Callable, Awaitable, TypeVar from typing import Callable, Awaitable, TypeVar, Any
import os.path import os.path
from twisted.trial import unittest from twisted.trial import unittest
from twisted.python import usage from twisted.python import usage
@ -242,7 +242,7 @@ class Put(GridTestMixin, CLITestMixin, unittest.TestCase):
async def _test_mutable_specified_key( async def _test_mutable_specified_key(
self, self,
run: Callable[[Callable[..., T], FilePath, FilePath], Awaitable[T]], run: Callable[[Any, FilePath, FilePath], Awaitable[tuple[int, bytes, bytes]]],
) -> None: ) -> None:
""" """
A helper for testing mutable creation. A helper for testing mutable creation.

View File

@ -635,7 +635,7 @@ class FakeMutableFileNode(object): # type: ignore # incomplete implementation
keypair: tuple[rsa.PublicKey, rsa.PrivateKey] | None keypair: tuple[rsa.PublicKey, rsa.PrivateKey] | None
): ):
self.all_contents = all_contents self.all_contents = all_contents
self.file_types = {} # storage index => MDMF_VERSION or SDMF_VERSION self.file_types: dict[bytes, int] = {} # storage index => MDMF_VERSION or SDMF_VERSION
self.init_from_cap(make_mutable_file_cap(keypair)) self.init_from_cap(make_mutable_file_cap(keypair))
self._k = default_encoding_parameters['k'] self._k = default_encoding_parameters['k']
self._segsize = default_encoding_parameters['max_segment_size'] self._segsize = default_encoding_parameters['max_segment_size']

View File

@ -90,6 +90,7 @@ class FakeNodeMaker(NodeMaker):
'happy': 7, 'happy': 7,
'max_segment_size':128*1024 # 1024=KiB 'max_segment_size':128*1024 # 1024=KiB
} }
all_contents: dict[bytes, object]
def _create_lit(self, cap): def _create_lit(self, cap):
return FakeCHKFileNode(cap, self.all_contents) return FakeCHKFileNode(cap, self.all_contents)
def _create_immutable(self, cap): def _create_immutable(self, cap):

View File

@ -6,7 +6,7 @@ from __future__ import annotations
from six import ensure_str from six import ensure_str
try: try:
from typing import Optional, Union, Tuple, Any from typing import Optional, Union, Tuple, Any, TypeVar
except ImportError: except ImportError:
pass pass
@ -706,8 +706,9 @@ def url_for_string(req, url_string):
) )
return url return url
T = TypeVar("T")
def get_arg(req, argname, default=None, multiple=False): # type: (IRequest, Union[bytes,str], Any, bool) -> Union[bytes,Tuple[bytes],Any] def get_arg(req: IRequest, argname: str | bytes, default: T = None, multiple: bool = False) -> Union[bytes, tuple[bytes, ...], T]:
"""Extract an argument from either the query args (req.args) or the form """Extract an argument from either the query args (req.args) or the form
body fields (req.fields). If multiple=False, this returns a single value body fields (req.fields). If multiple=False, this returns a single value
(or the default, which defaults to None), and the query args take (or the default, which defaults to None), and the query args take
@ -719,13 +720,17 @@ def get_arg(req, argname, default=None, multiple=False): # type: (IRequest, Uni
:return: Either bytes or tuple of bytes. :return: Either bytes or tuple of bytes.
""" """
if isinstance(argname, str): if isinstance(argname, str):
argname = argname.encode("utf-8") argname_bytes = argname.encode("utf-8")
else:
argname_bytes = argname
if isinstance(default, str): if isinstance(default, str):
default = default.encode("utf-8") default = default.encode("utf-8")
results = [] results = []
if argname in req.args: if argname_bytes in req.args:
results.extend(req.args[argname]) results.extend(req.args[argname_bytes])
argname_unicode = str(argname, "utf-8") argname_unicode = str(argname_bytes, "utf-8")
if req.fields and argname_unicode in req.fields: if req.fields and argname_unicode in req.fields:
value = req.fields[argname_unicode].value value = req.fields[argname_unicode].value
if isinstance(value, str): if isinstance(value, str):
@ -832,7 +837,7 @@ def get_keypair(request: IRequest) -> tuple[PublicKey, PrivateKey] | None:
Load a keypair from a urlsafe-base64-encoded RSA private key in the Load a keypair from a urlsafe-base64-encoded RSA private key in the
**private-key** argument of the given request, if there is one. **private-key** argument of the given request, if there is one.
""" """
privkey_der = get_arg(request, "private-key", None) privkey_der = get_arg(request, "private-key", default=None, multiple=False)
if privkey_der is None: if privkey_der is None:
return None return None
privkey, pubkey = create_signing_keypair_from_string(urlsafe_b64decode(privkey_der)) privkey, pubkey = create_signing_keypair_from_string(urlsafe_b64decode(privkey_der))