[wip] test form posts

This commit is contained in:
Jean-Paul Calderone 2020-10-21 06:59:59 -04:00
parent 0dcc3e13c0
commit 7f02128973
2 changed files with 158 additions and 4 deletions

View File

@ -0,0 +1,138 @@
"""
Tests for ``allmydata.webish``.
"""
from io import (
BytesIO,
)
from uuid import (
uuid4,
)
from testtools.matchers import (
AfterPreprocessing,
Equals,
)
from twisted.python.filepath import (
FilePath,
)
from twisted.web.test.requesthelper import (
DummyChannel,
)
from twisted.web.client import (
FileBodyProducer,
)
from twisted.internet.task import (
Cooperator,
)
from treq.multipart import (
MultiPartProducer,
)
from ..common import (
SyncTestCase,
)
from ...webish import (
TahoeLAFSRequest,
)
class TahoeLAFSRequestTests(SyncTestCase):
"""
Tests for ``TahoeLAFSRequest``.
"""
def _fields_test(self, method, request_headers, request_body, match_fields):
channel = DummyChannel()
request = TahoeLAFSRequest(
channel,
)
for (k, v) in request_headers.items():
request.requestHeaders.setRawHeaders(k, [v])
request.gotLength(len(request_body))
request.handleContentChunk(request_body)
request.requestReceived(method, b"/", b"HTTP/1.1")
# We don't really care what happened to the request. What we do care
# about is what the `fields` attribute is set to.
self.assertThat(
request.fields,
match_fields,
)
def test_no_form_fields(self):
"""
When a ``GET`` request is received, ``TahoeLAFSRequest.fields`` is None.
"""
self._fields_test(b"GET", {}, b"", Equals(None))
def test_form_fields(self):
"""
When a ``POST`` request is received, form fields are parsed into
``TahoeLAFSRequest.fields``.
"""
form_data, boundary = multipart_formdata([
[param(u"name", u"foo"),
body(u"bar"),
],
[param(u"name", u"baz"),
param(u"filename", u"quux"),
body(u"some file contents"),
],
])
self._fields_test(
b"POST",
{b"content-type": b"multipart/form-data; boundary={}".format(boundary)},
form_data.encode("ascii"),
AfterPreprocessing(
lambda fs: {
k: fs.getvalue(k)
for k
in fs.keys()
},
Equals({
b"foo": b"bar",
b"baz": b"some file contents",
}),
),
)
def param(name, value):
return u"; {}={}".format(name, value)
def body(value):
return u"\r\n\r\n{}".format(value)
def _field(field):
yield u"Content-Disposition: form-data"
for param in field:
yield param
def _multipart_formdata(fields):
for field in fields:
yield u"".join(_field(field)) + u"\r\n"
def multipart_formdata(fields):
"""
Serialize some simple fields into a multipart/form-data string.
:param fields: A list of lists of unicode strings to assemble into the
result. See ``param`` and ``body``.
:return unicode: The given fields combined into a multipart/form-data
string.
"""
boundary = str(uuid4())
parts = list(_multipart_formdata(fields))
parts.insert(0, u"")
return (
(u"--" + boundary + u"\r\n").join(parts),
boundary,
)

View File

@ -1,6 +1,11 @@
import re, time
from cgi import (
FieldStorage,
)
from twisted.application import service, strports, internet
from twisted.web import http, static
from twisted.web import http, server, static
from twisted.internet import defer
from twisted.internet.address import (
IPv4Address,
@ -25,7 +30,7 @@ from .web.storage_plugins import (
# surgery may induce a dependency upon a particular version of twisted.web
parse_qs = http.parse_qs
class MyRequest(appserver.NevowRequest, object):
class TahoeLAFSRequest(server.Request, object):
fields = None
_tahoe_request_had_error = None
@ -34,7 +39,7 @@ class MyRequest(appserver.NevowRequest, object):
This method is not intended for users.
"""
self.content.seek(0,0)
self.content.seek(0)
self.args = {}
self.stack = []
self.setHeader("Referrer-Policy", "no-referrer")
@ -53,6 +58,17 @@ class MyRequest(appserver.NevowRequest, object):
# See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Frame-Options
self.responseHeaders.setRawHeaders("X-Frame-Options", ["DENY"])
if self.method == 'POST':
self.fields = FieldStorage(
self.content,
{
name.lower(): value[-1]
for (name, value)
in self.requestHeaders.getAllRawHeaders()
},
environ={'REQUEST_METHOD': 'POST'})
self.content.seek(0)
# Argument processing.
## The original twisted.web.http.Request.requestReceived code parsed the
@ -176,7 +192,7 @@ class WebishServer(service.MultiService):
def buildServer(self, webport, nodeurl_path, staticdir):
self.webport = webport
self.site = site = appserver.NevowSite(self.root)
self.site.requestFactory = MyRequest
self.site.requestFactory = TahoeLAFSRequest
self.site.remember(MyExceptionHandler(), inevow.ICanHandleException)
self.staticdir = staticdir # so tests can check
if staticdir: