diff --git a/src/allmydata/test/no_network.py b/src/allmydata/test/no_network.py index 495553a83..27c9e3cdb 100644 --- a/src/allmydata/test/no_network.py +++ b/src/allmydata/test/no_network.py @@ -26,13 +26,16 @@ if PY2: from past.builtins import unicode import os +from base64 import b32encode +from functools import ( + partial, +) from zope.interface import implementer from twisted.application import service from twisted.internet import defer from twisted.python.failure import Failure from twisted.web.error import Error from foolscap.api import Referenceable, fireEventually, RemoteException -from base64 import b32encode import treq from allmydata.util.assertutil import _assert @@ -59,14 +62,24 @@ class IntentionalError(Exception): class Marker(object): pass +fireNow = partial(defer.succeed, None) + class LocalWrapper(object): - def __init__(self, original): + def __init__(self, original, fireEventually=fireEventually): + """ + :param Callable[[], Deferred[None]] fireEventually: Get a Deferred + that will fire at some point. This is used to control when + ``callRemote`` calls the remote method. The default value allows + the reactor to iterate before the call happens. Use ``fireNow`` + to call the remote method synchronously. + """ self.original = original self.broken = False self.hung_until = None self.post_call_notifier = None self.disconnectors = {} self.counter_by_methname = {} + self._fireEventually = fireEventually def _clear_counters(self): self.counter_by_methname = {} @@ -82,7 +95,7 @@ class LocalWrapper(object): # selected return values. def wrap(a): if isinstance(a, Referenceable): - return LocalWrapper(a) + return self._wrap(a) else: return a args = tuple([wrap(a) for a in args]) @@ -110,7 +123,7 @@ class LocalWrapper(object): return d2 return _really_call() - d = fireEventually() + d = self._fireEventually() d.addCallback(lambda res: _call()) def _wrap_exception(f): return Failure(RemoteException(f)) @@ -124,10 +137,10 @@ class LocalWrapper(object): if methname == "allocate_buckets": (alreadygot, allocated) = res for shnum in allocated: - allocated[shnum] = LocalWrapper(allocated[shnum]) + allocated[shnum] = self._wrap(allocated[shnum]) if methname == "get_buckets": for shnum in res: - res[shnum] = LocalWrapper(res[shnum]) + res[shnum] = self._wrap(res[shnum]) return res d.addCallback(_return_membrane) if self.post_call_notifier: @@ -141,6 +154,10 @@ class LocalWrapper(object): def dontNotifyOnDisconnect(self, marker): del self.disconnectors[marker] + def _wrap(self, value): + return LocalWrapper(value, self._fireEventually) + + def wrap_storage_server(original): # Much of the upload/download code uses rref.version (which normally # comes from rrefutil.add_version_to_remote_reference). To avoid using a