Refactor to unify data structure logic.

This commit is contained in:
Itamar Turner-Trauring 2022-03-08 10:13:37 -05:00
parent 1d007cc573
commit 5203873995

View File

@ -138,9 +138,69 @@ class StorageIndexUploads(object):
# different upload secrets).
upload_secrets = attr.ib(factory=dict) # type: Dict[int,bytes]
def add_upload(self, share_number, upload_secret, bucket):
self.shares[share_number] = bucket
self.upload_secrets[share_number] = upload_secret
@attr.s
class UploadsInProgress(object):
"""
Keep track of uploads for storage indexes.
"""
# Map storage index to corresponding uploads-in-progress
_uploads = attr.ib(type=Dict[bytes, StorageIndexUploads], factory=dict)
def add_write_bucket(
self,
storage_index: bytes,
share_number: int,
upload_secret: bytes,
bucket: BucketWriter,
):
"""Add a new ``BucketWriter`` to be tracked.
TODO 3877 how does a timed-out BucketWriter get removed?!
"""
si_uploads = self._uploads.setdefault(storage_index, StorageIndexUploads())
si_uploads.shares[share_number] = bucket
si_uploads.upload_secrets[share_number] = upload_secret
def get_write_bucket(
self, storage_index: bytes, share_number: int, upload_secret: bytes
) -> BucketWriter:
"""Get the given in-progress immutable share upload."""
try:
# TODO 3877 check the upload secret matches given one
return self._uploads[storage_index].shares[share_number]
except (KeyError, IndexError):
raise _HTTPError(http.NOT_FOUND)
def remove_write_bucket(self, storage_index: bytes, share_number: int):
"""Stop tracking the given ``BucketWriter``."""
uploads_index = self._uploads[storage_index]
uploads_index.shares.pop(share_number)
uploads_index.upload_secrets.pop(share_number)
if not uploads_index.shares:
self._uploads.pop(storage_index)
def validate_upload_secret(
self, storage_index: bytes, share_number: int, upload_secret: bytes
):
"""
Raise an unauthorized-HTTP-response exception if the given
storage_index+share_number have a different upload secret than the
given one.
If the given upload doesn't exist at all, nothing happens.
"""
if storage_index in self._uploads:
try:
in_progress = self._uploads[storage_index]
except KeyError:
return
# For pre-existing upload, make sure password matches.
if share_number in in_progress.upload_secrets and not timing_safe_compare(
in_progress.upload_secrets[share_number], upload_secret
):
raise _HTTPError(http.UNAUTHORIZED)
class StorageIndexConverter(BaseConverter):
@ -155,6 +215,15 @@ class StorageIndexConverter(BaseConverter):
raise ValidationError("Invalid storage index")
class _HTTPError(Exception):
"""
Raise from ``HTTPServer`` endpoint to return the given HTTP response code.
"""
def __init__(self, code: int):
self.code = code
class HTTPServer(object):
"""
A HTTP interface to the storage server.
@ -163,13 +232,19 @@ class HTTPServer(object):
_app = Klein()
_app.url_map.converters["storage_index"] = StorageIndexConverter
@_app.handle_errors(_HTTPError)
def _http_error(self, request, failure):
"""Handle ``_HTTPError`` exceptions."""
request.setResponseCode(failure.value.code)
return b""
def __init__(
self, storage_server, swissnum
): # type: (StorageServer, bytes) -> None
self._storage_server = storage_server
self._swissnum = swissnum
# Maps storage index to StorageIndexUploads:
self._uploads = {} # type: Dict[bytes,StorageIndexUploads]
self._uploads = UploadsInProgress()
def get_resource(self):
"""Return twisted.web ``Resource`` for this object."""
@ -203,18 +278,10 @@ class HTTPServer(object):
upload_secret = authorization[Secrets.UPLOAD]
info = loads(request.content.read())
if storage_index in self._uploads:
for share_number in info["share-numbers"]:
in_progress = self._uploads[storage_index]
# For pre-existing upload, make sure password matches.
if (
share_number in in_progress.upload_secrets
and not timing_safe_compare(
in_progress.upload_secrets[share_number], upload_secret
)
):
request.setResponseCode(http.UNAUTHORIZED)
return b""
for share_number in info["share-numbers"]:
self._uploads.validate_upload_secret(
storage_index, share_number, upload_secret
)
already_got, sharenum_to_bucket = self._storage_server.allocate_buckets(
storage_index,
@ -223,9 +290,10 @@ class HTTPServer(object):
sharenums=info["share-numbers"],
allocated_size=info["allocated-size"],
)
uploads = self._uploads.setdefault(storage_index, StorageIndexUploads())
for share_number, bucket in sharenum_to_bucket.items():
uploads.add_upload(share_number, upload_secret, bucket)
self._uploads.add_write_bucket(
storage_index, share_number, upload_secret, bucket
)
return self._cbor(
request,
@ -250,14 +318,14 @@ class HTTPServer(object):
offset = content_range.start
# TODO 3877 test for checking upload secret
# TODO limit memory usage
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872
data = request.content.read(content_range.stop - content_range.start + 1)
try:
bucket = self._uploads[storage_index].shares[share_number]
except (KeyError, IndexError):
request.setResponseCode(http.NOT_FOUND)
return b""
bucket = self._uploads.get_write_bucket(
storage_index, share_number, authorization[Secrets.UPLOAD]
)
try:
finished = bucket.write(offset, data)