Add streaming to CBOR results.

This commit is contained in:
Itamar Turner-Trauring 2022-06-29 11:26:25 -04:00
parent efe9575d28
commit 520456bdc0

View File

@ -3,16 +3,16 @@ HTTP server for storage.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Set, Tuple, Any, Callable
from typing import Dict, List, Set, Tuple, Any, Callable
from functools import wraps from functools import wraps
from base64 import b64decode from base64 import b64decode
import binascii import binascii
from tempfile import TemporaryFile
from zope.interface import implementer from zope.interface import implementer
from klein import Klein from klein import Klein
from twisted.web import http from twisted.web import http
from twisted.web.server import NOT_DONE_YET
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
IListeningPort, IListeningPort,
IStreamServerEndpoint, IStreamServerEndpoint,
@ -37,7 +37,7 @@ from cryptography.x509 import load_pem_x509_certificate
# TODO Make sure to use pure Python versions? # TODO Make sure to use pure Python versions?
from cbor2 import dumps, loads from cbor2 import dump, loads
from pycddl import Schema, ValidationError as CDDLValidationError from pycddl import Schema, ValidationError as CDDLValidationError
from .server import StorageServer from .server import StorageServer
from .http_common import ( from .http_common import (
@ -279,6 +279,10 @@ _SCHEMAS = {
} }
# Callabale that takes offset and length, returns the data at that range.
ReadData = Callable[[int, int], bytes]
@implementer(IPullProducer) @implementer(IPullProducer)
@define @define
class _ReadAllProducer: class _ReadAllProducer:
@ -288,10 +292,20 @@ class _ReadAllProducer:
""" """
request: Request request: Request
read_data: Callable[[int, int], bytes] read_data: ReadData
result: Deferred result: Deferred = field(factory=Deferred)
start: int = field(default=0) start: int = field(default=0)
@classmethod
def produce_to(cls, request: Request, read_data: ReadData) -> Deferred:
"""
Create and register the producer, returning ``Deferred`` that should be
returned from a HTTP server endpoint.
"""
producer = cls(request, read_data)
request.registerProducer(producer, False)
return producer.result
def resumeProducing(self): def resumeProducing(self):
data = self.read_data(self.start, 65536) data = self.read_data(self.start, 65536)
if not data: if not data:
@ -319,7 +333,7 @@ class _ReadRangeProducer:
""" """
request: Request request: Request
read_data: Callable[[int, int], bytes] read_data: ReadData
result: Deferred result: Deferred
start: int start: int
remaining: int remaining: int
@ -356,7 +370,7 @@ class _ReadRangeProducer:
pass pass
def read_range(request: Request, read_data: Callable[[int, int], bytes]) -> None: def read_range(request: Request, read_data: ReadData) -> None:
""" """
Read an optional ``Range`` header, reads data appropriately via the given Read an optional ``Range`` header, reads data appropriately via the given
callable, writes the data to the request. callable, writes the data to the request.
@ -381,11 +395,7 @@ def read_range(request: Request, read_data: Callable[[int, int], bytes]) -> None
return b"" return b""
if request.getHeader("range") is None: if request.getHeader("range") is None:
d = Deferred() return _ReadAllProducer.produce_to(request, read_data_with_error_handling)
request.registerProducer(
_ReadAllProducer(request, read_data_with_error_handling, d), False
)
return d
range_header = parse_range_header(request.getHeader("range")) range_header = parse_range_header(request.getHeader("range"))
if ( if (
@ -459,9 +469,14 @@ class HTTPServer(object):
accept = parse_accept_header(accept_headers[0]) accept = parse_accept_header(accept_headers[0])
if accept.best == CBOR_MIME_TYPE: if accept.best == CBOR_MIME_TYPE:
request.setHeader("Content-Type", CBOR_MIME_TYPE) request.setHeader("Content-Type", CBOR_MIME_TYPE)
# TODO if data is big, maybe want to use a temporary file eventually... f = TemporaryFile()
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3872 dump(data, f)
return dumps(data)
def read_data(offset: int, length: int) -> bytes:
f.seek(offset)
return f.read(length)
return _ReadAllProducer.produce_to(request, read_data)
else: else:
# TODO Might want to optionally send JSON someday: # TODO Might want to optionally send JSON someday:
# https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861 # https://tahoe-lafs.org/trac/tahoe-lafs/ticket/3861