diff --git a/newsfragments/4005.minor b/newsfragments/4005.minor new file mode 100644 index 000000000..e69de29bb diff --git a/src/allmydata/storage/http_client.py b/src/allmydata/storage/http_client.py index ea142ed85..f786b8f30 100644 --- a/src/allmydata/storage/http_client.py +++ b/src/allmydata/storage/http_client.py @@ -5,7 +5,7 @@ HTTP client that talks to the HTTP storage server. from __future__ import annotations from eliot import start_action, register_exception_extractor -from typing import Union, Optional, Sequence, Mapping, BinaryIO +from typing import Union, Optional, Sequence, Mapping, BinaryIO, cast, TypedDict, Set from base64 import b64encode from io import BytesIO from os import SEEK_END @@ -20,7 +20,7 @@ from werkzeug.datastructures import Range, ContentRange from twisted.web.http_headers import Headers from twisted.web import http from twisted.web.iweb import IPolicyForHTTPS, IResponse -from twisted.internet.defer import inlineCallbacks, returnValue, fail, Deferred, succeed +from twisted.internet.defer import inlineCallbacks, Deferred, succeed from twisted.internet.interfaces import ( IOpenSSLClientConnectionCreator, IReactorTime, @@ -447,24 +447,28 @@ class StorageClient(object): method, url, headers=headers, timeout=timeout, **kwargs ) - def decode_cbor(self, response, schema: Schema): + async def decode_cbor(self, response, schema: Schema) -> object: """Given HTTP response, return decoded CBOR body.""" - - def got_content(f: BinaryIO): - data = f.read() - schema.validate_cbor(data) - return loads(data) - - if response.code > 199 and response.code < 300: - content_type = get_content_type(response.headers) - if content_type == CBOR_MIME_TYPE: - return limited_content(response, self._clock).addCallback(got_content) + with start_action(action_type="allmydata:storage:http-client:decode-cbor"): + if response.code > 199 and response.code < 300: + content_type = get_content_type(response.headers) + if content_type == CBOR_MIME_TYPE: + f = await limited_content(response, self._clock) + data = f.read() + schema.validate_cbor(data) + return loads(data) + else: + raise ClientException( + -1, + "Server didn't send CBOR, content type is {}".format( + content_type + ), + ) else: - raise ClientException(-1, "Server didn't send CBOR") - else: - return treq.content(response).addCallback( - lambda data: fail(ClientException(response.code, response.phrase, data)) - ) + data = ( + await limited_content(response, self._clock, max_length=10_000) + ).read() + raise ClientException(response.code, response.phrase, data) @define(hash=True) @@ -475,20 +479,24 @@ class StorageClientGeneral(object): _client: StorageClient - @inlineCallbacks - def get_version(self): + @async_to_deferred + async def get_version(self): """ Return the version metadata for the server. """ url = self._client.relative_url("/storage/v1/version") - response = yield self._client.request("GET", url) - decoded_response = yield self._client.decode_cbor( - response, _SCHEMAS["get_version"] + response = await self._client.request("GET", url) + decoded_response = cast( + Mapping[bytes, object], + await self._client.decode_cbor(response, _SCHEMAS["get_version"]), ) # Add some features we know are true because the HTTP API # specification requires them and because other parts of the storage # client implementation assumes they will be present. - decoded_response[b"http://allmydata.org/tahoe/protocols/storage/v1"].update( + cast( + Mapping[bytes, object], + decoded_response[b"http://allmydata.org/tahoe/protocols/storage/v1"], + ).update( { b"tolerates-immutable-read-overrun": True, b"delete-mutable-shares-with-zero-length-writev": True, @@ -496,7 +504,7 @@ class StorageClientGeneral(object): b"prevents-read-past-end-of-share-data": True, } ) - returnValue(decoded_response) + return decoded_response @inlineCallbacks def add_or_renew_lease( @@ -647,16 +655,16 @@ class StorageClientImmutables(object): _client: StorageClient - @inlineCallbacks - def create( + @async_to_deferred + async def create( self, - storage_index, - share_numbers, - allocated_size, - upload_secret, - lease_renew_secret, - lease_cancel_secret, - ): # type: (bytes, set[int], int, bytes, bytes, bytes) -> Deferred[ImmutableCreateResult] + storage_index: bytes, + share_numbers: set[int], + allocated_size: int, + upload_secret: bytes, + lease_renew_secret: bytes, + lease_cancel_secret: bytes, + ) -> ImmutableCreateResult: """ Create a new storage index for an immutable. @@ -675,7 +683,7 @@ class StorageClientImmutables(object): ) message = {"share-numbers": share_numbers, "allocated-size": allocated_size} - response = yield self._client.request( + response = await self._client.request( "POST", url, lease_renew_secret=lease_renew_secret, @@ -683,14 +691,13 @@ class StorageClientImmutables(object): upload_secret=upload_secret, message_to_serialize=message, ) - decoded_response = yield self._client.decode_cbor( - response, _SCHEMAS["allocate_buckets"] + decoded_response = cast( + Mapping[str, Set[int]], + await self._client.decode_cbor(response, _SCHEMAS["allocate_buckets"]), ) - returnValue( - ImmutableCreateResult( - already_have=decoded_response["already-have"], - allocated=decoded_response["allocated"], - ) + return ImmutableCreateResult( + already_have=decoded_response["already-have"], + allocated=decoded_response["allocated"], ) @inlineCallbacks @@ -716,10 +723,15 @@ class StorageClientImmutables(object): response.code, ) - @inlineCallbacks - def write_share_chunk( - self, storage_index, share_number, upload_secret, offset, data - ): # type: (bytes, int, bytes, int, bytes) -> Deferred[UploadProgress] + @async_to_deferred + async def write_share_chunk( + self, + storage_index: bytes, + share_number: int, + upload_secret: bytes, + offset: int, + data: bytes, + ) -> UploadProgress: """ Upload a chunk of data for a specific share. @@ -737,7 +749,7 @@ class StorageClientImmutables(object): _encode_si(storage_index), share_number ) ) - response = yield self._client.request( + response = await self._client.request( "PATCH", url, upload_secret=upload_secret, @@ -761,13 +773,16 @@ class StorageClientImmutables(object): raise ClientException( response.code, ) - body = yield self._client.decode_cbor( - response, _SCHEMAS["immutable_write_share_chunk"] + body = cast( + Mapping[str, Sequence[Mapping[str, int]]], + await self._client.decode_cbor( + response, _SCHEMAS["immutable_write_share_chunk"] + ), ) remaining = RangeMap() for chunk in body["required"]: remaining.set(True, chunk["begin"], chunk["end"]) - returnValue(UploadProgress(finished=finished, required=remaining)) + return UploadProgress(finished=finished, required=remaining) def read_share_chunk( self, storage_index, share_number, offset, length @@ -779,21 +794,23 @@ class StorageClientImmutables(object): self._client, "immutable", storage_index, share_number, offset, length ) - @inlineCallbacks - def list_shares(self, storage_index: bytes) -> Deferred[set[int]]: + @async_to_deferred + async def list_shares(self, storage_index: bytes) -> Set[int]: """ Return the set of shares for a given storage index. """ url = self._client.relative_url( "/storage/v1/immutable/{}/shares".format(_encode_si(storage_index)) ) - response = yield self._client.request( + response = await self._client.request( "GET", url, ) if response.code == http.OK: - body = yield self._client.decode_cbor(response, _SCHEMAS["list_shares"]) - returnValue(set(body)) + return cast( + Set[int], + await self._client.decode_cbor(response, _SCHEMAS["list_shares"]), + ) else: raise ClientException(response.code) @@ -863,6 +880,13 @@ class ReadTestWriteResult: reads: Mapping[int, Sequence[bytes]] +# Result type for mutable read/test/write HTTP response. Can't just use +# dict[int,list[bytes]] because on Python 3.8 that will error out. +MUTABLE_RTW = TypedDict( + "MUTABLE_RTW", {"success": bool, "data": Mapping[int, Sequence[bytes]]} +) + + @frozen class StorageClientMutables: """ @@ -909,8 +933,11 @@ class StorageClientMutables: message_to_serialize=message, ) if response.code == http.OK: - result = await self._client.decode_cbor( - response, _SCHEMAS["mutable_read_test_write"] + result = cast( + MUTABLE_RTW, + await self._client.decode_cbor( + response, _SCHEMAS["mutable_read_test_write"] + ), ) return ReadTestWriteResult(success=result["success"], reads=result["data"]) else: @@ -931,7 +958,7 @@ class StorageClientMutables: ) @async_to_deferred - async def list_shares(self, storage_index: bytes) -> set[int]: + async def list_shares(self, storage_index: bytes) -> Set[int]: """ List the share numbers for a given storage index. """ @@ -940,8 +967,11 @@ class StorageClientMutables: ) response = await self._client.request("GET", url) if response.code == http.OK: - return await self._client.decode_cbor( - response, _SCHEMAS["mutable_list_shares"] + return cast( + Set[int], + await self._client.decode_cbor( + response, _SCHEMAS["mutable_list_shares"] + ), ) else: raise ClientException(response.code) diff --git a/src/allmydata/test/test_storage_http.py b/src/allmydata/test/test_storage_http.py index ea93ad360..eca2be1c1 100644 --- a/src/allmydata/test/test_storage_http.py +++ b/src/allmydata/test/test_storage_http.py @@ -34,7 +34,7 @@ from hyperlink import DecodedURL from collections_extended import RangeMap from twisted.internet.task import Clock, Cooperator from twisted.internet.interfaces import IReactorTime, IReactorFromThreads -from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.web import http from twisted.web.http_headers import Headers from werkzeug import routing @@ -520,6 +520,7 @@ class HttpTestFixture(Fixture): Like ``result_of``, but supports fake reactor and ``treq`` testing infrastructure necessary to support asynchronous HTTP server endpoints. """ + d = ensureDeferred(d) result = [] error = [] d.addCallbacks(result.append, error.append)