diff --git a/src/allmydata/test/cli/wormholetesting.py b/src/allmydata/test/cli/wormholetesting.py index b60980bff..715e82236 100644 --- a/src/allmydata/test/cli/wormholetesting.py +++ b/src/allmydata/test/cli/wormholetesting.py @@ -32,7 +32,7 @@ For example:: from __future__ import annotations -from typing import Iterator +from typing import Iterator, Optional, Sequence from collections.abc import Awaitable from inspect import getargspec from itertools import count @@ -44,6 +44,11 @@ from wormhole._interfaces import IWormhole from wormhole.wormhole import create from zope.interface import implementer +WormholeCode = str +WormholeMessage = bytes +AppId = str +RelayURL = str +ApplicationKey = tuple[RelayURL, AppId] @define class MemoryWormholeServer(object): @@ -56,8 +61,8 @@ class MemoryWormholeServer(object): :ivar _waiters: Observers waiting for a wormhole to be created for a specific application id and relay URL combination. """ - _apps: dict[tuple[str, str], _WormholeApp] = field(default=Factory(dict)) - _waiters: dict[tuple[str, str], Deferred] = field(default=Factory(dict)) + _apps: dict[ApplicationKey, _WormholeApp] = field(default=Factory(dict)) + _waiters: dict[ApplicationKey, Deferred] = field(default=Factory(dict)) def create( self, @@ -89,7 +94,7 @@ class MemoryWormholeServer(object): self._waiters.pop(key).callback(wormhole) return wormhole - def _view(self, key: tuple[str, str]) -> _WormholeServerView: + def _view(self, key: ApplicationKey) -> _WormholeServerView: """ Created a view onto this server's state that is limited by a certain appid/relay_url pair. @@ -108,7 +113,7 @@ class TestingHelper(object): """ _server: MemoryWormholeServer - async def wait_for_wormhole(self, appid: str, relay_url: str) -> IWormhole: + async def wait_for_wormhole(self, appid: AppId, relay_url: RelayURL) -> IWormhole: """ Wait for a wormhole to appear at a specific location. @@ -120,7 +125,7 @@ class TestingHelper(object): :return: The first wormhole to be created which matches the given parameters. """ - key = relay_url, appid + key = (relay_url, appid) if key in self._server._waiters: raise ValueError(f"There is already a waiter for {key}") d = Deferred() @@ -152,11 +157,11 @@ class _WormholeApp(object): Represent a collection of wormholes that belong to the same appid/relay_url scope. """ - wormholes: dict = field(default=Factory(dict)) - _waiting: dict = field(default=Factory(dict)) + wormholes: dict[WormholeCode, IWormhole] = field(default=Factory(dict)) + _waiting: dict[WormholeCode, Sequence[Deferred]] = field(default=Factory(dict)) _counter: Iterator[int] = field(default=Factory(count)) - def allocate_code(self, wormhole, code): + def allocate_code(self, wormhole: IWormhole, code: Optional[WormholeCode]) -> WormholeCode: """ Allocate a new code for the given wormhole. @@ -179,7 +184,7 @@ class _WormholeApp(object): return code - def wait_for_wormhole(self, code: str) -> Awaitable[_MemoryWormhole]: + def wait_for_wormhole(self, code: WormholeCode) -> Awaitable[_MemoryWormhole]: """ Return a ``Deferred`` which fires with the next wormhole to be associated with the given code. This is used to let the first end of a wormhole @@ -197,9 +202,9 @@ class _WormholeServerView(object): wormholes. """ _server: MemoryWormholeServer - _key: tuple[str, str] + _key: ApplicationKey - def allocate_code(self, wormhole: _MemoryWormhole, code: str) -> str: + def allocate_code(self, wormhole: _MemoryWormhole, code: Optional[WormholeCode]) -> WormholeCode: """ Allocate a new code for the given wormhole in the scope associated with this view. @@ -207,7 +212,7 @@ class _WormholeServerView(object): app = self._server._apps.setdefault(self._key, _WormholeApp()) return app.allocate_code(wormhole, code) - def wormhole_by_code(self, code, exclude): + def wormhole_by_code(self, code: WormholeCode, exclude: object) -> Deferred[IWormhole]: """ Retrieve all wormholes previously associated with a code. """ @@ -228,46 +233,42 @@ class _MemoryWormhole(object): """ _view: _WormholeServerView - _code: str = None + _code: Optional[WormholeCode] = None _payload: DeferredQueue = field(default=Factory(DeferredQueue)) _waiting_for_code: list[Deferred] = field(default=Factory(list)) - _allocated: bool = False - def allocate_code(self): + def allocate_code(self) -> None: if self._code is not None: raise ValueError( "allocate_code used with a wormhole which already has a code" ) - self._allocated = True self._code = self._view.allocate_code(self, None) waiters = self._waiting_for_code self._waiting_for_code = None for d in waiters: d.callback(self._code) - def set_code(self, code): + def set_code(self, code: WormholeCode) -> None: if self._code is None: self._code = code self._view.allocate_code(self, code) else: raise ValueError("set_code used with a wormhole which already has a code") - def when_code(self): + def when_code(self) -> Deferred[WormholeCode]: if self._code is None: d = Deferred() self._waiting_for_code.append(d) return d return succeed(self._code) - get_code = when_code - def get_welcome(self): return succeed("welcome") - def send_message(self, payload): + def send_message(self, payload: WormholeMessage) -> None: self._payload.put(payload) - def when_received(self): + def when_received(self) -> Deferred[WormholeMessage]: if self._code is None: raise ValueError( "This implementation requires set_code or allocate_code " @@ -284,11 +285,11 @@ class _MemoryWormhole(object): get_message = when_received - def close(self): + def close(self) -> None: pass # 0.9.2 compatibility - def get_code(self): + def get_code(self) -> Deferred[WormholeCode]: if self._code is None: self.allocate_code() return self.when_code()