diff --git a/src/allmydata/test/cli/test_invite.py b/src/allmydata/test/cli/test_invite.py index 50f446ae2..5b4871944 100644 --- a/src/allmydata/test/cli/test_invite.py +++ b/src/allmydata/test/cli/test_invite.py @@ -4,8 +4,9 @@ Tests for ``tahoe invite``. import json import os +from functools import partial from os.path import join -from typing import Optional, Sequence +from typing import Awaitable, Callable, Optional, Sequence, TypeVar from twisted.internet import defer from twisted.trial import unittest @@ -16,7 +17,58 @@ from ...util.jsonbytes import dumps_bytes from ..common_util import run_cli from ..no_network import GridTestMixin from .common import CLITestMixin -from .wormholetesting import MemoryWormholeServer, memory_server +from .wormholetesting import IWormhole, MemoryWormholeServer, memory_server + + +async def open_wormhole() -> tuple[Callable, IWormhole, str]: + """ + Create a new in-memory wormhole server, open one end of a wormhole, and + return it and related info. + + :return: A three-tuple allowing use of the wormhole. The first element is + a callable like ``run_cli`` but which will run commands so that they + use the in-memory wormhole server instead of a real one. The second + element is the open wormhole. The third element is the wormhole's + code. + """ + server = MemoryWormholeServer() + options = runner.Options() + options.wormhole = server + reactor = object() + + wormhole = server.create( + "tahoe-lafs.org/invite", + "ws://wormhole.tahoe-lafs.org:4000/v1", + reactor, + ) + code = await wormhole.get_code() + + return (partial(run_cli, options=options), wormhole, code) + + +def send_messages(wormhole: IWormhole, messages: list[dict]) -> None: + """ + Send a list of message through a wormhole. + """ + for msg in messages: + wormhole.send_message(dumps_bytes(msg)) + + +A = TypeVar("A") +B = TypeVar("B") + +def concurrently( + client: Callable[[], Awaitable[A]], + server: Callable[[], Awaitable[B]], +) -> defer.Deferred[tuple[A, B]]: + """ + Run two asynchronous functions concurrently and asynchronously return a + tuple of both their results. + """ + return defer.gatherResults([ + defer.Deferred.fromCoroutine(client()), + defer.Deferred.fromCoroutine(server()), + ]) class Join(GridTestMixin, CLITestMixin, unittest.TestCase): @@ -33,18 +85,8 @@ class Join(GridTestMixin, CLITestMixin, unittest.TestCase): successfully join after an invite """ node_dir = self.mktemp() - server = MemoryWormholeServer() - options = runner.Options() - options.wormhole = server - reactor = object() - - wormhole = server.create( - "tahoe-lafs.org/invite", - "ws://wormhole.tahoe-lafs.org:4000/v1", - reactor, - ) - code = yield wormhole.get_code() - messages = [ + run_cli, wormhole, code = yield defer.Deferred.fromCoroutine(open_wormhole()) + send_messages(wormhole, [ {u"abilities": {u"server-v1": {}}}, { u"shares-needed": 1, @@ -53,15 +95,12 @@ class Join(GridTestMixin, CLITestMixin, unittest.TestCase): u"nickname": u"somethinghopefullyunique", u"introducer": u"pb://foo", }, - ] - for msg in messages: - wormhole.send_message(dumps_bytes(msg)) + ]) rc, out, err = yield run_cli( "create-client", "--join", code, node_dir, - options=options, ) self.assertEqual(0, rc) @@ -86,18 +125,8 @@ class Join(GridTestMixin, CLITestMixin, unittest.TestCase): Server sends JSON with unknown/illegal key """ node_dir = self.mktemp() - server = MemoryWormholeServer() - options = runner.Options() - options.wormhole = server - reactor = object() - - wormhole = server.create( - "tahoe-lafs.org/invite", - "ws://wormhole.tahoe-lafs.org:4000/v1", - reactor, - ) - code = yield wormhole.get_code() - messages = [ + run_cli, wormhole, code = yield defer.Deferred.fromCoroutine(open_wormhole()) + send_messages(wormhole, [ {u"abilities": {u"server-v1": {}}}, { u"shares-needed": 1, @@ -107,15 +136,12 @@ class Join(GridTestMixin, CLITestMixin, unittest.TestCase): u"introducer": u"pb://foo", u"something-else": u"not allowed", }, - ] - for msg in messages: - wormhole.send_message(dumps_bytes(msg)) + ]) rc, out, err = yield run_cli( "create-client", "--join", code, node_dir, - options=options, ) # should still succeed -- just ignores the not-whitelisted @@ -137,7 +163,7 @@ class Invite(GridTestMixin, CLITestMixin, unittest.TestCase): intro_dir, ) - async def _invite_success(self, extra_args: Sequence[bytes] = (), tahoe_config: Optional[byte] = None) -> str: + async def _invite_success(self, extra_args: Sequence[bytes] = (), tahoe_config: Optional[bytes] = None) -> str: """ Exercise an expected-success case of ``tahoe invite``. @@ -217,10 +243,7 @@ class Invite(GridTestMixin, CLITestMixin, unittest.TestCase): ) return invite - invite, _ = await defer.gatherResults(map( - defer.Deferred.fromCoroutine, - [client(), server()], - )) + invite, _ = await concurrently(client, server) return invite @@ -340,10 +363,7 @@ shares.total = 6 # Send some surprising client abilities. other_end.send_message(dumps_bytes({u"abilities": {u"client-v9000": {}}})) - yield defer.gatherResults(map( - defer.Deferred.fromCoroutine, - [client(), server()], - )) + yield concurrently(client, server) @defer.inlineCallbacks @@ -398,10 +418,7 @@ shares.total = 6 # Send a no-abilities message through to the server. other_end.send_message(dumps_bytes({})) - yield defer.gatherResults(map( - defer.Deferred.fromCoroutine, - [client(), server()], - )) + yield concurrently(client, server) @defer.inlineCallbacks @@ -416,18 +433,8 @@ shares.total = 6 with open(join(priv_dir, "introducer.furl"), "w") as f: f.write("pb://fooblam\n") - wormhole_server = MemoryWormholeServer() - options = runner.Options() - options.wormhole = wormhole_server - reactor = object() - - wormhole = wormhole_server.create( - "tahoe-lafs.org/invite", - "ws://wormhole.tahoe-lafs.org:4000/v1", - reactor, - ) - code = yield wormhole.get_code() - messages = [ + run_cli, wormhole, code = yield defer.Deferred.fromCoroutine(open_wormhole()) + send_messages(wormhole, [ {u"abilities": {u"server-v9000": {}}}, { "shares-needed": "1", @@ -436,15 +443,12 @@ shares.total = 6 "nickname": "foo", "introducer": "pb://fooblam", }, - ] - for msg in messages: - wormhole.send_message(dumps_bytes(msg)) + ]) rc, out, err = yield run_cli( "create-client", "--join", code, "foo", - options=options, ) self.assertNotEqual(rc, 0) self.assertIn("Expected 'server-v1' in server abilities", out + err) @@ -461,18 +465,8 @@ shares.total = 6 with open(join(priv_dir, "introducer.furl"), "w") as f: f.write("pb://fooblam\n") - server = MemoryWormholeServer() - options = runner.Options() - options.wormhole = server - reactor = object() - - wormhole = server.create( - "tahoe-lafs.org/invite", - "ws://wormhole.tahoe-lafs.org:4000/v1", - reactor, - ) - code = yield wormhole.get_code() - messages = [ + run_cli, wormhole, code = yield defer.Deferred.fromCoroutine(open_wormhole()) + send_messages(wormhole, [ {}, { "shares-needed": "1", @@ -481,15 +475,12 @@ shares.total = 6 "nickname": "bar", "introducer": "pb://fooblam", }, - ] - for msg in messages: - wormhole.send_message(dumps_bytes(msg)) + ]) rc, out, err = yield run_cli( "create-client", "--join", code, "bar", - options=options, ) self.assertNotEqual(rc, 0) self.assertIn("Expected 'abilities' in server introduction", out + err)