Add streaming to CBOR results.

This commit is contained in:
Itamar Turner-Trauring 2022-06-29 11:26:25 -04:00
parent efe9575d28
commit 520456bdc0

View File

@ -3,16 +3,16 @@ HTTP server for storage.
"""
from __future__ import annotations
from typing import Dict, List, Set, Tuple, Any, Callable
from typing import Dict, List, Set, Tuple, Any, Callable
from functools import wraps
from base64 import b64decode
import binascii
from tempfile import TemporaryFile
from zope.interface import implementer
from klein import Klein
from twisted.web import http
from twisted.web.server import NOT_DONE_YET
from twisted.internet.interfaces import (
IListeningPort,
IStreamServerEndpoint,
@ -37,7 +37,7 @@ from cryptography.x509 import load_pem_x509_certificate
# TODO Make sure to use pure Python versions?
from cbor2 import dumps, loads
from cbor2 import dump, loads
from pycddl import Schema, ValidationError as CDDLValidationError
from .server import StorageServer
from .http_common import (
@ -279,6 +279,10 @@ _SCHEMAS = {
}
# Callabale that takes offset and length, returns the data at that range.
ReadData = Callable[[int, int], bytes]
@implementer(IPullProducer)
@define
class _ReadAllProducer:
@ -288,10 +292,20 @@ class _ReadAllProducer:
"""
request: Request
read_data: Callable[[int, int], bytes]
result: Deferred
read_data: ReadData
result: Deferred = field(factory=Deferred)
start: int = field(default=0)
@classmethod
def produce_to(cls, request: Request, read_data: ReadData) -> Deferred:
"""
Create and register the producer, returning ``Deferred`` that should be
returned from a HTTP server endpoint.
"""
producer = cls(request, read_data)
request.registerProducer(producer, False)
return producer.result
def resumeProducing(self):
data = self.read_data(self.start, 65536)
if not data:
@ -319,7 +333,7 @@ class _ReadRangeProducer:
"""
request: Request
read_data: Callable[[int, int], bytes]
read_data: ReadData
result: Deferred
start: int
remaining: int
@ -356,7 +370,7 @@ class _ReadRangeProducer:
pass
def read_range(request: Request, read_data: Callable[[int, int], bytes]) -> None:
def read_range(request: Request, read_data: ReadData) -> None:
"""
Read an optional ``Range`` header, reads data appropriately via the given
callable, writes the data to the request.
@ -381,11 +395,7 @@ def read_range(request: Request, read_data: Callable[[int, int], bytes]) -> None
return b""
if request.getHeader("range") is None:
d = Deferred()
request.registerProducer(
_ReadAllProducer(request, read_data_with_error_handling, d), False
)
return d
return _ReadAllProducer.produce_to(request, read_data_with_error_handling)
range_header = parse_range_header(request.getHeader("range"))
if (
@ -459,9 +469,14 @@ class HTTPServer(object):
accept = parse_accept_header(accept_headers[0])
if accept.best == CBOR_MIME_TYPE:
request.setHeader("Content-Type", CBOR_MIME_TYPE)
# TODO if data is big, maybe want to use a temporary file eventually...
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872
return dumps(data)
f = TemporaryFile()
dump(data, f)
def read_data(offset: int, length: int) -> bytes:
f.seek(offset)
return f.read(length)
return _ReadAllProducer.produce_to(request, read_data)
else:
# TODO Might want to optionally send JSON someday:
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861