Use guard and add some tests (integration failing)

This commit is contained in:
Jean-Paul Calderone 2019-03-22 13:47:32 -04:00
parent 7c79f69d03
commit 9de97dbdd5
7 changed files with 451 additions and 38 deletions

View File

@ -0,0 +1,135 @@
from __future__ import (
print_function,
unicode_literals,
absolute_import,
division,
)
from os.path import (
join,
)
from urlparse import (
urlsplit,
)
import attr
from twisted.internet.defer import (
Deferred,
)
from twisted.internet.endpoints import (
HostnameEndpoint,
)
from autobahn.twisted.websocket import (
WebSocketClientFactory,
WebSocketClientProtocol,
)
from allmydata.client import (
read_config,
)
from allmydata.web.private import (
SCHEME,
)
from allmydata.util.eliotutil import (
inline_callbacks,
)
import pytest_twisted
def _url_to_endpoint(reactor, url):
netloc = urlsplit(url).netloc
host, port = netloc.split(":")
return HostnameEndpoint(reactor, host, int(port))
class _StreamingLogClientProtocol(WebSocketClientProtocol):
def onOpen(self):
self.factory.on_open.callback(self)
def onMessage(self, payload, isBinary):
self.on_message.callback(payload)
def onClose(self, wasClean, code, reason):
self.on_close.callback(reason)
def _connect_client(reactor, api_auth_token, ws_url):
factory = WebSocketClientFactory(
url=ws_url,
headers={
"Authorization": "{} {}".format(SCHEME, api_auth_token),
}
)
factory.protocol = _StreamingLogClientProtocol
factory.on_open = Deferred()
endpoint = _url_to_endpoint(reactor, ws_url)
return endpoint.connect(factory)
def _race(left, right):
"""
Wait for the first result from either of two Deferreds.
Any result, success or failure, causes the return Deferred to fire. It
fires with either a Left or a Right instance depending on whether the left
or right argument fired first.
The Deferred that loses the race is cancelled and any result it eventually
produces is discarded.
"""
racing = [True]
def got_result(result, which):
if racing:
racing.pop()
loser = which.pick(left, right)
loser.cancel()
finished.callback(which(result))
finished = Deferred()
left.addBoth(got_result, Left)
right.addBoth(got_result, Right)
return finished
@attr.s
class Left(object):
value = attr.ib()
@classmethod
def pick(cls, left, right):
return left
@attr.s
class Right(object):
value = attr.ib()
@classmethod
def pick(cls, left, right):
return right
@inline_callbacks
def _test_streaming_logs(reactor, temp_dir, alice):
cfg = read_config(join(temp_dir, "alice"), "portnum")
node_url = cfg.get_config_from_file("node.url")
api_auth_token = cfg.get_private_config("api_auth_token")
ws_url = node_url.replace("http://", "ws://")
log_url = ws_url + "private/logs/v1"
client = yield _connect_client(reactor, api_auth_token, log_url)
client.on_close = Deferred()
client.on_message = Deferred()
result = yield _race(client.on_close, client.on_message)
assert result == Right("some payload")
@pytest_twisted.inlineCallbacks
def test_streaming_logs(reactor, temp_dir, alice):
yield _test_streaming_logs(reactor, temp_dir, alice)

View File

@ -0,0 +1,28 @@
import attr
from testtools.matchers import Mismatch
@attr.s
class _HasResponseCode(object):
match_expected_code = attr.ib()
def match(self, response):
actual_code = response.code
mismatch = self.match_expected_code.match(actual_code)
if mismatch is None:
return None
return Mismatch(
u"Response {} code: {}".format(
response,
mismatch.describe(),
),
mismatch.get_details(),
)
def has_response_code(match_expected_code):
"""
Match a Treq response with the given code.
:param int expected_code: The HTTP response code expected of the response.
"""
return _HasResponseCode(match_expected_code)

View File

@ -0,0 +1,51 @@
"""
Tests for ``allmydata.web.logs``.
"""
from __future__ import (
print_function,
unicode_literals,
absolute_import,
division,
)
from testtools.matchers import (
Always,
)
from testtools.twistedsupport import (
succeeded,
)
from treq.client import (
HTTPClient,
)
from treq.testing import (
RequestTraversalAgent,
)
from ..common import (
SyncTestCase,
)
from ...web.logs import (
create_log_resources,
)
class StreamingEliotLogsTests(SyncTestCase):
"""
Tests for the log streaming resources created by ``create_log_resources``.
"""
def setUp(self):
self.resource = create_log_resources()
self.agent = RequestTraversalAgent(self.resource)
self.client = HTTPClient(self.agent)
return super(StreamingEliotLogsTests, self).setUp()
def test_v1(self):
"""
There is a resource at *logs/v1*.
"""
self.assertThat(
self.client.head(b"http:///logs/v1"),
succeeded(Always()),
)

View File

@ -0,0 +1,110 @@
"""
Tests for ``allmydata.web.private``.
"""
from __future__ import (
print_function,
unicode_literals,
absolute_import,
division,
)
from testtools.matchers import (
Equals,
)
from testtools.twistedsupport import (
succeeded,
)
from twisted.web.http import (
UNAUTHORIZED,
NOT_FOUND,
)
from twisted.web.http_headers import (
Headers,
)
from treq.client import (
HTTPClient,
)
from treq.testing import (
RequestTraversalAgent,
)
from ..common import (
SyncTestCase,
)
from ...web.private import (
SCHEME,
create_private_tree,
)
from .matchers import (
has_response_code,
)
class PrivacyTests(SyncTestCase):
"""
Tests for the privacy features of the resources created by ``create_private_tree``.
"""
def setUp(self):
self.token = u"abcdef"
self.resource = create_private_tree(lambda: self.token)
self.agent = RequestTraversalAgent(self.resource)
self.client = HTTPClient(self.agent)
return super(PrivacyTests, self).setUp()
def _authorization(self, scheme, value):
return Headers({
u"authorization": [u"{} {}".format(scheme, value)],
})
def test_unauthorized(self):
"""
A request without an *Authorization* header receives an *Unauthorized* response.
"""
self.assertThat(
self.client.head(b"http:///foo/bar"),
succeeded(has_response_code(Equals(UNAUTHORIZED))),
)
def test_wrong_scheme(self):
"""
A request with an *Authorization* header not containing the Tahoe-LAFS
scheme receives an *Unauthorized* response.
"""
self.assertThat(
self.client.head(
b"http:///foo/bar",
headers=self._authorization(u"basic", self.token),
),
succeeded(has_response_code(Equals(UNAUTHORIZED))),
)
def test_wrong_token(self):
"""
A request with an *Authorization* header not containing the expected token
receives an *Unauthorized* response.
"""
self.assertThat(
self.client.head(
b"http:///foo/bar",
headers=self._authorization(SCHEME, u"foo bar"),
),
succeeded(has_response_code(Equals(UNAUTHORIZED))),
)
def test_authorized(self):
"""
A request with an *Authorization* header containing the expected scheme
and token does not receive an *Unauthorized* response.
"""
self.assertThat(
self.client.head(
b"http:///foo/bar",
headers=self._authorization(SCHEME, self.token),
),
# It's a made up URL so we don't get a 200, either, but a 404.
succeeded(has_response_code(Equals(NOT_FOUND))),
)

View File

@ -12,18 +12,12 @@ from autobahn.twisted.websocket import (
WebSocketServerFactory, WebSocketServerFactory,
WebSocketServerProtocol, WebSocketServerProtocol,
) )
from autobahn.websocket.types import ConnectionDeny
import eliot import eliot
from twisted.web.resource import ( from twisted.web.resource import (
Resource, Resource,
) )
from allmydata.util.hashutil import (
timing_safe_compare,
)
class TokenAuthenticatedWebSocketServerProtocol(WebSocketServerProtocol): class TokenAuthenticatedWebSocketServerProtocol(WebSocketServerProtocol):
""" """
@ -36,12 +30,6 @@ class TokenAuthenticatedWebSocketServerProtocol(WebSocketServerProtocol):
""" """
WebSocket callback WebSocket callback
""" """
if b'authorization' in req.headers:
auth = req.headers[b'authorization'].encode('ascii').split(b' ', 1)
if len(auth) == 2:
tag, token = auth
if tag == b"tahoe-lafs":
if timing_safe_compare(token, self.factory.tahoe_client.get_auth_token()):
# we don't care what WebSocket sub-protocol is # we don't care what WebSocket sub-protocol is
# negotiated, nor do we need to send headers to the # negotiated, nor do we need to send headers to the
# client, so we ask Autobahn to just allow this # client, so we ask Autobahn to just allow this
@ -49,13 +37,6 @@ class TokenAuthenticatedWebSocketServerProtocol(WebSocketServerProtocol):
# (headers, protocol) pair here instead if required. # (headers, protocol) pair here instead if required.
return None return None
# everything else -- i.e. no Authorization header, or it's
# wrong -- means we deny the websocket connection
raise ConnectionDeny(
code=ConnectionDeny.NOT_ACCEPTABLE,
reason=u"Invalid or missing token"
)
def _received_eliot_log(self, message): def _received_eliot_log(self, message):
""" """
While this WebSocket connection is open, this function is While this WebSocket connection is open, this function is
@ -81,20 +62,13 @@ class TokenAuthenticatedWebSocketServerProtocol(WebSocketServerProtocol):
pass pass
def create_log_streaming_resource(client): def create_log_streaming_resource():
"""
Create a new resource that accepts WebSocket connections if they
include a correct `Authorization: tahoe-lafs <api_auth_token>`
header (where `api_auth_token` matches the private configuration
value).
"""
factory = WebSocketServerFactory() factory = WebSocketServerFactory()
factory.tahoe_client = client
factory.protocol = TokenAuthenticatedWebSocketServerProtocol factory.protocol = TokenAuthenticatedWebSocketServerProtocol
return WebSocketResource(factory) return WebSocketResource(factory)
def create_log_resources(client): def create_log_resources():
logs = Resource() logs = Resource()
logs.putChild(b"v1", create_log_streaming_resource(client)) logs.putChild(b"v1", create_log_streaming_resource())
return logs return logs

View File

@ -6,15 +6,130 @@ from __future__ import (
division, division,
) )
import attr
from zope.interface import (
implementer,
)
from twisted.python.failure import (
Failure,
)
from twisted.internet.defer import (
succeed,
fail,
)
from twisted.cred.credentials import (
ICredentials,
)
from twisted.cred.portal import (
IRealm,
Portal,
)
from twisted.cred.checkers import (
ANONYMOUS,
)
from twisted.cred.error import (
UnauthorizedLogin,
)
from twisted.web.iweb import (
ICredentialFactory,
)
from twisted.web.resource import ( from twisted.web.resource import (
IResource,
Resource, Resource,
) )
from twisted.web.guard import (
HTTPAuthSessionWrapper,
)
from ..util.hashutil import (
timing_safe_compare,
)
from .logs import ( from .logs import (
create_log_resources, create_log_resources,
) )
def create_private_tree(client): SCHEME = b"tahoe-lafs"
class IToken(ICredentials):
def check(auth_token):
pass
@implementer(IToken)
@attr.s
class Token(object):
proposed_token = attr.ib(type=bytes)
def equals(self, valid_token):
return timing_safe_compare(
valid_token.encode("ascii"),
self.proposed_token,
)
@attr.s
class TokenChecker(object):
get_auth_token = attr.ib()
credentialInterfaces = [IToken]
def requestAvatarId(self, credentials):
if credentials.equals(self.get_auth_token()):
return succeed(ANONYMOUS)
return fail(Failure(UnauthorizedLogin()))
@implementer(ICredentialFactory)
@attr.s
class TokenCredentialFactory(object):
scheme = SCHEME
authentication_realm = b"tahoe-lafs"
def getChallenge(self, request):
return {b"realm": self.authentication_realm}
def decode(self, response, request):
return Token(response)
@implementer(IRealm)
@attr.s
class PrivateRealm(object):
_root = attr.ib()
def _logout(self):
pass
def requestAvatar(self, avatarId, mind, *interfaces):
if IResource in interfaces:
return (IResource, self._root, self._logout)
raise NotImplementedError(
"PrivateRealm supports IResource not {}".format(interfaces),
)
def _create_vulnerable_tree():
private = Resource() private = Resource()
private.putChild(b"logs", create_log_resources(client)) private.putChild(b"logs", create_log_resources())
return private return private
def _create_private_tree(get_auth_token, vulnerable):
realm = PrivateRealm(vulnerable)
portal = Portal(realm, [TokenChecker(get_auth_token)])
return HTTPAuthSessionWrapper(portal, [TokenCredentialFactory()])
def create_private_tree(get_auth_token):
"""
Create a new resource tree that only allows requests if they include a
correct `Authorization: tahoe-lafs <api_auth_token>` header (where
`api_auth_token` matches the private configuration value).
"""
return _create_private_tree(
get_auth_token,
_create_vulnerable_tree(),
)

View File

@ -173,7 +173,7 @@ class Root(MultiFormatPage):
# Handler for everything beneath "/private", an area of the resource # Handler for everything beneath "/private", an area of the resource
# hierarchy which is only accessible with the private per-node API # hierarchy which is only accessible with the private per-node API
# auth token. # auth token.
self.child_private = create_private_tree(client) self.child_private = create_private_tree(client.get_auth_token)
self.child_file = FileHandler(client) self.child_file = FileHandler(client)
self.child_named = FileHandler(client) self.child_named = FileHandler(client)