mirror of
https://github.com/tahoe-lafs/tahoe-lafs.git
synced 2024-12-19 04:57:54 +00:00
Use guard and add some tests (integration failing)
This commit is contained in:
parent
7c79f69d03
commit
9de97dbdd5
135
integration/test_streaming_logs.py
Normal file
135
integration/test_streaming_logs.py
Normal file
@ -0,0 +1,135 @@
|
||||
from __future__ import (
|
||||
print_function,
|
||||
unicode_literals,
|
||||
absolute_import,
|
||||
division,
|
||||
)
|
||||
|
||||
from os.path import (
|
||||
join,
|
||||
)
|
||||
from urlparse import (
|
||||
urlsplit,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet.defer import (
|
||||
Deferred,
|
||||
)
|
||||
from twisted.internet.endpoints import (
|
||||
HostnameEndpoint,
|
||||
)
|
||||
|
||||
from autobahn.twisted.websocket import (
|
||||
WebSocketClientFactory,
|
||||
WebSocketClientProtocol,
|
||||
)
|
||||
|
||||
from allmydata.client import (
|
||||
read_config,
|
||||
)
|
||||
from allmydata.web.private import (
|
||||
SCHEME,
|
||||
)
|
||||
from allmydata.util.eliotutil import (
|
||||
inline_callbacks,
|
||||
)
|
||||
|
||||
import pytest_twisted
|
||||
|
||||
def _url_to_endpoint(reactor, url):
|
||||
netloc = urlsplit(url).netloc
|
||||
host, port = netloc.split(":")
|
||||
return HostnameEndpoint(reactor, host, int(port))
|
||||
|
||||
|
||||
class _StreamingLogClientProtocol(WebSocketClientProtocol):
|
||||
def onOpen(self):
|
||||
self.factory.on_open.callback(self)
|
||||
|
||||
def onMessage(self, payload, isBinary):
|
||||
self.on_message.callback(payload)
|
||||
|
||||
def onClose(self, wasClean, code, reason):
|
||||
self.on_close.callback(reason)
|
||||
|
||||
|
||||
def _connect_client(reactor, api_auth_token, ws_url):
|
||||
factory = WebSocketClientFactory(
|
||||
url=ws_url,
|
||||
headers={
|
||||
"Authorization": "{} {}".format(SCHEME, api_auth_token),
|
||||
}
|
||||
)
|
||||
factory.protocol = _StreamingLogClientProtocol
|
||||
factory.on_open = Deferred()
|
||||
|
||||
endpoint = _url_to_endpoint(reactor, ws_url)
|
||||
return endpoint.connect(factory)
|
||||
|
||||
|
||||
def _race(left, right):
|
||||
"""
|
||||
Wait for the first result from either of two Deferreds.
|
||||
|
||||
Any result, success or failure, causes the return Deferred to fire. It
|
||||
fires with either a Left or a Right instance depending on whether the left
|
||||
or right argument fired first.
|
||||
|
||||
The Deferred that loses the race is cancelled and any result it eventually
|
||||
produces is discarded.
|
||||
"""
|
||||
racing = [True]
|
||||
def got_result(result, which):
|
||||
if racing:
|
||||
racing.pop()
|
||||
loser = which.pick(left, right)
|
||||
loser.cancel()
|
||||
finished.callback(which(result))
|
||||
|
||||
finished = Deferred()
|
||||
left.addBoth(got_result, Left)
|
||||
right.addBoth(got_result, Right)
|
||||
return finished
|
||||
|
||||
|
||||
@attr.s
|
||||
class Left(object):
|
||||
value = attr.ib()
|
||||
|
||||
@classmethod
|
||||
def pick(cls, left, right):
|
||||
return left
|
||||
|
||||
|
||||
@attr.s
|
||||
class Right(object):
|
||||
value = attr.ib()
|
||||
|
||||
@classmethod
|
||||
def pick(cls, left, right):
|
||||
return right
|
||||
|
||||
|
||||
@inline_callbacks
|
||||
def _test_streaming_logs(reactor, temp_dir, alice):
|
||||
cfg = read_config(join(temp_dir, "alice"), "portnum")
|
||||
node_url = cfg.get_config_from_file("node.url")
|
||||
api_auth_token = cfg.get_private_config("api_auth_token")
|
||||
|
||||
ws_url = node_url.replace("http://", "ws://")
|
||||
log_url = ws_url + "private/logs/v1"
|
||||
|
||||
client = yield _connect_client(reactor, api_auth_token, log_url)
|
||||
client.on_close = Deferred()
|
||||
client.on_message = Deferred()
|
||||
|
||||
result = yield _race(client.on_close, client.on_message)
|
||||
|
||||
assert result == Right("some payload")
|
||||
|
||||
|
||||
@pytest_twisted.inlineCallbacks
|
||||
def test_streaming_logs(reactor, temp_dir, alice):
|
||||
yield _test_streaming_logs(reactor, temp_dir, alice)
|
28
src/allmydata/test/web/matchers.py
Normal file
28
src/allmydata/test/web/matchers.py
Normal file
@ -0,0 +1,28 @@
|
||||
import attr
|
||||
|
||||
from testtools.matchers import Mismatch
|
||||
|
||||
@attr.s
|
||||
class _HasResponseCode(object):
|
||||
match_expected_code = attr.ib()
|
||||
|
||||
def match(self, response):
|
||||
actual_code = response.code
|
||||
mismatch = self.match_expected_code.match(actual_code)
|
||||
if mismatch is None:
|
||||
return None
|
||||
return Mismatch(
|
||||
u"Response {} code: {}".format(
|
||||
response,
|
||||
mismatch.describe(),
|
||||
),
|
||||
mismatch.get_details(),
|
||||
)
|
||||
|
||||
def has_response_code(match_expected_code):
|
||||
"""
|
||||
Match a Treq response with the given code.
|
||||
|
||||
:param int expected_code: The HTTP response code expected of the response.
|
||||
"""
|
||||
return _HasResponseCode(match_expected_code)
|
51
src/allmydata/test/web/test_logs.py
Normal file
51
src/allmydata/test/web/test_logs.py
Normal file
@ -0,0 +1,51 @@
|
||||
"""
|
||||
Tests for ``allmydata.web.logs``.
|
||||
"""
|
||||
|
||||
from __future__ import (
|
||||
print_function,
|
||||
unicode_literals,
|
||||
absolute_import,
|
||||
division,
|
||||
)
|
||||
|
||||
from testtools.matchers import (
|
||||
Always,
|
||||
)
|
||||
from testtools.twistedsupport import (
|
||||
succeeded,
|
||||
)
|
||||
|
||||
from treq.client import (
|
||||
HTTPClient,
|
||||
)
|
||||
from treq.testing import (
|
||||
RequestTraversalAgent,
|
||||
)
|
||||
|
||||
from ..common import (
|
||||
SyncTestCase,
|
||||
)
|
||||
|
||||
from ...web.logs import (
|
||||
create_log_resources,
|
||||
)
|
||||
|
||||
class StreamingEliotLogsTests(SyncTestCase):
|
||||
"""
|
||||
Tests for the log streaming resources created by ``create_log_resources``.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.resource = create_log_resources()
|
||||
self.agent = RequestTraversalAgent(self.resource)
|
||||
self.client = HTTPClient(self.agent)
|
||||
return super(StreamingEliotLogsTests, self).setUp()
|
||||
|
||||
def test_v1(self):
|
||||
"""
|
||||
There is a resource at *logs/v1*.
|
||||
"""
|
||||
self.assertThat(
|
||||
self.client.head(b"http:///logs/v1"),
|
||||
succeeded(Always()),
|
||||
)
|
110
src/allmydata/test/web/test_private.py
Normal file
110
src/allmydata/test/web/test_private.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""
|
||||
Tests for ``allmydata.web.private``.
|
||||
"""
|
||||
|
||||
from __future__ import (
|
||||
print_function,
|
||||
unicode_literals,
|
||||
absolute_import,
|
||||
division,
|
||||
)
|
||||
|
||||
from testtools.matchers import (
|
||||
Equals,
|
||||
)
|
||||
from testtools.twistedsupport import (
|
||||
succeeded,
|
||||
)
|
||||
|
||||
from twisted.web.http import (
|
||||
UNAUTHORIZED,
|
||||
NOT_FOUND,
|
||||
)
|
||||
from twisted.web.http_headers import (
|
||||
Headers,
|
||||
)
|
||||
|
||||
from treq.client import (
|
||||
HTTPClient,
|
||||
)
|
||||
from treq.testing import (
|
||||
RequestTraversalAgent,
|
||||
)
|
||||
|
||||
from ..common import (
|
||||
SyncTestCase,
|
||||
)
|
||||
|
||||
from ...web.private import (
|
||||
SCHEME,
|
||||
create_private_tree,
|
||||
)
|
||||
|
||||
from .matchers import (
|
||||
has_response_code,
|
||||
)
|
||||
|
||||
class PrivacyTests(SyncTestCase):
|
||||
"""
|
||||
Tests for the privacy features of the resources created by ``create_private_tree``.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.token = u"abcdef"
|
||||
self.resource = create_private_tree(lambda: self.token)
|
||||
self.agent = RequestTraversalAgent(self.resource)
|
||||
self.client = HTTPClient(self.agent)
|
||||
return super(PrivacyTests, self).setUp()
|
||||
|
||||
def _authorization(self, scheme, value):
|
||||
return Headers({
|
||||
u"authorization": [u"{} {}".format(scheme, value)],
|
||||
})
|
||||
|
||||
def test_unauthorized(self):
|
||||
"""
|
||||
A request without an *Authorization* header receives an *Unauthorized* response.
|
||||
"""
|
||||
self.assertThat(
|
||||
self.client.head(b"http:///foo/bar"),
|
||||
succeeded(has_response_code(Equals(UNAUTHORIZED))),
|
||||
)
|
||||
|
||||
def test_wrong_scheme(self):
|
||||
"""
|
||||
A request with an *Authorization* header not containing the Tahoe-LAFS
|
||||
scheme receives an *Unauthorized* response.
|
||||
"""
|
||||
self.assertThat(
|
||||
self.client.head(
|
||||
b"http:///foo/bar",
|
||||
headers=self._authorization(u"basic", self.token),
|
||||
),
|
||||
succeeded(has_response_code(Equals(UNAUTHORIZED))),
|
||||
)
|
||||
|
||||
def test_wrong_token(self):
|
||||
"""
|
||||
A request with an *Authorization* header not containing the expected token
|
||||
receives an *Unauthorized* response.
|
||||
"""
|
||||
self.assertThat(
|
||||
self.client.head(
|
||||
b"http:///foo/bar",
|
||||
headers=self._authorization(SCHEME, u"foo bar"),
|
||||
),
|
||||
succeeded(has_response_code(Equals(UNAUTHORIZED))),
|
||||
)
|
||||
|
||||
def test_authorized(self):
|
||||
"""
|
||||
A request with an *Authorization* header containing the expected scheme
|
||||
and token does not receive an *Unauthorized* response.
|
||||
"""
|
||||
self.assertThat(
|
||||
self.client.head(
|
||||
b"http:///foo/bar",
|
||||
headers=self._authorization(SCHEME, self.token),
|
||||
),
|
||||
# It's a made up URL so we don't get a 200, either, but a 404.
|
||||
succeeded(has_response_code(Equals(NOT_FOUND))),
|
||||
)
|
@ -12,18 +12,12 @@ from autobahn.twisted.websocket import (
|
||||
WebSocketServerFactory,
|
||||
WebSocketServerProtocol,
|
||||
)
|
||||
from autobahn.websocket.types import ConnectionDeny
|
||||
|
||||
import eliot
|
||||
|
||||
from twisted.web.resource import (
|
||||
Resource,
|
||||
)
|
||||
|
||||
from allmydata.util.hashutil import (
|
||||
timing_safe_compare,
|
||||
)
|
||||
|
||||
|
||||
class TokenAuthenticatedWebSocketServerProtocol(WebSocketServerProtocol):
|
||||
"""
|
||||
@ -36,25 +30,12 @@ class TokenAuthenticatedWebSocketServerProtocol(WebSocketServerProtocol):
|
||||
"""
|
||||
WebSocket callback
|
||||
"""
|
||||
if b'authorization' in req.headers:
|
||||
auth = req.headers[b'authorization'].encode('ascii').split(b' ', 1)
|
||||
if len(auth) == 2:
|
||||
tag, token = auth
|
||||
if tag == b"tahoe-lafs":
|
||||
if timing_safe_compare(token, self.factory.tahoe_client.get_auth_token()):
|
||||
# we don't care what WebSocket sub-protocol is
|
||||
# negotiated, nor do we need to send headers to the
|
||||
# client, so we ask Autobahn to just allow this
|
||||
# connection with the defaults. We could return a
|
||||
# (headers, protocol) pair here instead if required.
|
||||
return None
|
||||
|
||||
# everything else -- i.e. no Authorization header, or it's
|
||||
# wrong -- means we deny the websocket connection
|
||||
raise ConnectionDeny(
|
||||
code=ConnectionDeny.NOT_ACCEPTABLE,
|
||||
reason=u"Invalid or missing token"
|
||||
)
|
||||
# we don't care what WebSocket sub-protocol is
|
||||
# negotiated, nor do we need to send headers to the
|
||||
# client, so we ask Autobahn to just allow this
|
||||
# connection with the defaults. We could return a
|
||||
# (headers, protocol) pair here instead if required.
|
||||
return None
|
||||
|
||||
def _received_eliot_log(self, message):
|
||||
"""
|
||||
@ -81,20 +62,13 @@ class TokenAuthenticatedWebSocketServerProtocol(WebSocketServerProtocol):
|
||||
pass
|
||||
|
||||
|
||||
def create_log_streaming_resource(client):
|
||||
"""
|
||||
Create a new resource that accepts WebSocket connections if they
|
||||
include a correct `Authorization: tahoe-lafs <api_auth_token>`
|
||||
header (where `api_auth_token` matches the private configuration
|
||||
value).
|
||||
"""
|
||||
def create_log_streaming_resource():
|
||||
factory = WebSocketServerFactory()
|
||||
factory.tahoe_client = client
|
||||
factory.protocol = TokenAuthenticatedWebSocketServerProtocol
|
||||
return WebSocketResource(factory)
|
||||
|
||||
|
||||
def create_log_resources(client):
|
||||
def create_log_resources():
|
||||
logs = Resource()
|
||||
logs.putChild(b"v1", create_log_streaming_resource(client))
|
||||
logs.putChild(b"v1", create_log_streaming_resource())
|
||||
return logs
|
||||
|
@ -6,15 +6,130 @@ from __future__ import (
|
||||
division,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
||||
from zope.interface import (
|
||||
implementer,
|
||||
)
|
||||
|
||||
from twisted.python.failure import (
|
||||
Failure,
|
||||
)
|
||||
from twisted.internet.defer import (
|
||||
succeed,
|
||||
fail,
|
||||
)
|
||||
from twisted.cred.credentials import (
|
||||
ICredentials,
|
||||
)
|
||||
from twisted.cred.portal import (
|
||||
IRealm,
|
||||
Portal,
|
||||
)
|
||||
from twisted.cred.checkers import (
|
||||
ANONYMOUS,
|
||||
)
|
||||
from twisted.cred.error import (
|
||||
UnauthorizedLogin,
|
||||
)
|
||||
from twisted.web.iweb import (
|
||||
ICredentialFactory,
|
||||
)
|
||||
from twisted.web.resource import (
|
||||
IResource,
|
||||
Resource,
|
||||
)
|
||||
from twisted.web.guard import (
|
||||
HTTPAuthSessionWrapper,
|
||||
)
|
||||
|
||||
from ..util.hashutil import (
|
||||
timing_safe_compare,
|
||||
)
|
||||
|
||||
from .logs import (
|
||||
create_log_resources,
|
||||
)
|
||||
|
||||
def create_private_tree(client):
|
||||
SCHEME = b"tahoe-lafs"
|
||||
|
||||
class IToken(ICredentials):
|
||||
def check(auth_token):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(IToken)
|
||||
@attr.s
|
||||
class Token(object):
|
||||
proposed_token = attr.ib(type=bytes)
|
||||
|
||||
def equals(self, valid_token):
|
||||
return timing_safe_compare(
|
||||
valid_token.encode("ascii"),
|
||||
self.proposed_token,
|
||||
)
|
||||
|
||||
|
||||
@attr.s
|
||||
class TokenChecker(object):
|
||||
get_auth_token = attr.ib()
|
||||
|
||||
credentialInterfaces = [IToken]
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
if credentials.equals(self.get_auth_token()):
|
||||
return succeed(ANONYMOUS)
|
||||
return fail(Failure(UnauthorizedLogin()))
|
||||
|
||||
|
||||
@implementer(ICredentialFactory)
|
||||
@attr.s
|
||||
class TokenCredentialFactory(object):
|
||||
scheme = SCHEME
|
||||
authentication_realm = b"tahoe-lafs"
|
||||
|
||||
def getChallenge(self, request):
|
||||
return {b"realm": self.authentication_realm}
|
||||
|
||||
def decode(self, response, request):
|
||||
return Token(response)
|
||||
|
||||
|
||||
@implementer(IRealm)
|
||||
@attr.s
|
||||
class PrivateRealm(object):
|
||||
_root = attr.ib()
|
||||
|
||||
def _logout(self):
|
||||
pass
|
||||
|
||||
def requestAvatar(self, avatarId, mind, *interfaces):
|
||||
if IResource in interfaces:
|
||||
return (IResource, self._root, self._logout)
|
||||
raise NotImplementedError(
|
||||
"PrivateRealm supports IResource not {}".format(interfaces),
|
||||
)
|
||||
|
||||
|
||||
def _create_vulnerable_tree():
|
||||
private = Resource()
|
||||
private.putChild(b"logs", create_log_resources(client))
|
||||
private.putChild(b"logs", create_log_resources())
|
||||
return private
|
||||
|
||||
|
||||
def _create_private_tree(get_auth_token, vulnerable):
|
||||
realm = PrivateRealm(vulnerable)
|
||||
portal = Portal(realm, [TokenChecker(get_auth_token)])
|
||||
return HTTPAuthSessionWrapper(portal, [TokenCredentialFactory()])
|
||||
|
||||
|
||||
def create_private_tree(get_auth_token):
|
||||
"""
|
||||
Create a new resource tree that only allows requests if they include a
|
||||
correct `Authorization: tahoe-lafs <api_auth_token>` header (where
|
||||
`api_auth_token` matches the private configuration value).
|
||||
"""
|
||||
return _create_private_tree(
|
||||
get_auth_token,
|
||||
_create_vulnerable_tree(),
|
||||
)
|
||||
|
@ -173,7 +173,7 @@ class Root(MultiFormatPage):
|
||||
# Handler for everything beneath "/private", an area of the resource
|
||||
# hierarchy which is only accessible with the private per-node API
|
||||
# auth token.
|
||||
self.child_private = create_private_tree(client)
|
||||
self.child_private = create_private_tree(client.get_auth_token)
|
||||
|
||||
self.child_file = FileHandler(client)
|
||||
self.child_named = FileHandler(client)
|
||||
|
Loading…
Reference in New Issue
Block a user