Fix type annotations.

This commit is contained in:
Itamar Turner-Trauring 2023-04-12 17:00:31 -04:00
parent 2a7616e0be
commit 3997eaaf90

View File

@ -5,7 +5,7 @@ HTTP client that talks to the HTTP storage server.
from __future__ import annotations from __future__ import annotations
from eliot import start_action, register_exception_extractor from eliot import start_action, register_exception_extractor
from typing import Union, Optional, Sequence, Mapping, BinaryIO from typing import Union, Optional, Sequence, Mapping, BinaryIO, cast, TypedDict
from base64 import b64encode from base64 import b64encode
from io import BytesIO from io import BytesIO
from os import SEEK_END from os import SEEK_END
@ -486,13 +486,17 @@ class StorageClientGeneral(object):
""" """
url = self._client.relative_url("/storage/v1/version") url = self._client.relative_url("/storage/v1/version")
response = await self._client.request("GET", url) response = await self._client.request("GET", url)
decoded_response = await self._client.decode_cbor( decoded_response = cast(
response, _SCHEMAS["get_version"] dict[bytes, object],
await self._client.decode_cbor(response, _SCHEMAS["get_version"]),
) )
# Add some features we know are true because the HTTP API # Add some features we know are true because the HTTP API
# specification requires them and because other parts of the storage # specification requires them and because other parts of the storage
# client implementation assumes they will be present. # client implementation assumes they will be present.
decoded_response[b"http://allmydata.org/tahoe/protocols/storage/v1"].update( cast(
dict[bytes, object],
decoded_response[b"http://allmydata.org/tahoe/protocols/storage/v1"],
).update(
{ {
b"tolerates-immutable-read-overrun": True, b"tolerates-immutable-read-overrun": True,
b"delete-mutable-shares-with-zero-length-writev": True, b"delete-mutable-shares-with-zero-length-writev": True,
@ -687,8 +691,9 @@ class StorageClientImmutables(object):
upload_secret=upload_secret, upload_secret=upload_secret,
message_to_serialize=message, message_to_serialize=message,
) )
decoded_response = await self._client.decode_cbor( decoded_response = cast(
response, _SCHEMAS["allocate_buckets"] dict[str, set[int]],
await self._client.decode_cbor(response, _SCHEMAS["allocate_buckets"]),
) )
return ImmutableCreateResult( return ImmutableCreateResult(
already_have=decoded_response["already-have"], already_have=decoded_response["already-have"],
@ -763,8 +768,11 @@ class StorageClientImmutables(object):
raise ClientException( raise ClientException(
response.code, response.code,
) )
body = await self._client.decode_cbor( body = cast(
dict[str, list[dict[str, int]]],
await self._client.decode_cbor(
response, _SCHEMAS["immutable_write_share_chunk"] response, _SCHEMAS["immutable_write_share_chunk"]
),
) )
remaining = RangeMap() remaining = RangeMap()
for chunk in body["required"]: for chunk in body["required"]:
@ -794,7 +802,10 @@ class StorageClientImmutables(object):
url, url,
) )
if response.code == http.OK: if response.code == http.OK:
body = await self._client.decode_cbor(response, _SCHEMAS["list_shares"]) body = cast(
set[int],
await self._client.decode_cbor(response, _SCHEMAS["list_shares"]),
)
return set(body) return set(body)
else: else:
raise ClientException(response.code) raise ClientException(response.code)
@ -865,6 +876,12 @@ class ReadTestWriteResult:
reads: Mapping[int, Sequence[bytes]] reads: Mapping[int, Sequence[bytes]]
# Result type for mutable read/test/write HTTP response.
MUTABLE_RTW = TypedDict(
"MUTABLE_RTW", {"success": bool, "data": dict[int, list[bytes]]}
)
@frozen @frozen
class StorageClientMutables: class StorageClientMutables:
""" """
@ -911,8 +928,11 @@ class StorageClientMutables:
message_to_serialize=message, message_to_serialize=message,
) )
if response.code == http.OK: if response.code == http.OK:
result = await self._client.decode_cbor( result = cast(
MUTABLE_RTW,
await self._client.decode_cbor(
response, _SCHEMAS["mutable_read_test_write"] response, _SCHEMAS["mutable_read_test_write"]
),
) )
return ReadTestWriteResult(success=result["success"], reads=result["data"]) return ReadTestWriteResult(success=result["success"], reads=result["data"])
else: else:
@ -942,8 +962,11 @@ class StorageClientMutables:
) )
response = await self._client.request("GET", url) response = await self._client.request("GET", url)
if response.code == http.OK: if response.code == http.OK:
return await self._client.decode_cbor( return cast(
set[int],
await self._client.decode_cbor(
response, _SCHEMAS["mutable_list_shares"] response, _SCHEMAS["mutable_list_shares"]
),
) )
else: else:
raise ClientException(response.code) raise ClientException(response.code)