diff --git a/src/allmydata/test/test_util.py b/src/allmydata/test/test_util.py index 284bf5e5f..b59e47a26 100644 --- a/src/allmydata/test/test_util.py +++ b/src/allmydata/test/test_util.py @@ -374,19 +374,25 @@ class PollMixinTests(unittest.TestCase): def setUp(self): self.pm = testutil.PollMixin() - def _check(self, d): - def fail_unless_arg_is_true(arg): - self.failUnless(arg is True, repr(arg)) - d.addCallback(fail_unless_arg_is_true) - return d - def test_PollMixin_True(self): d = self.pm.poll(check_f=lambda : True, pollinterval=0.1) - return self._check(d) + return d def test_PollMixin_False_then_True(self): i = iter([False, True]) d = self.pm.poll(check_f=i.next, pollinterval=0.1) - return self._check(d) + return d + + def test_timeout(self): + d = self.pm.poll(check_f=lambda: False, + pollinterval=0.01, + timeout=1) + def _suc(res): + self.fail("poll should have failed, not returned %s" % (res,)) + def _err(f): + f.trap(testutil.TimeoutError) + return None # success + d.addCallbacks(_suc, _err) + return d diff --git a/src/allmydata/util/testutil.py b/src/allmydata/util/testutil.py index 867cf00c1..13e72911a 100644 --- a/src/allmydata/util/testutil.py +++ b/src/allmydata/util/testutil.py @@ -1,6 +1,6 @@ import os, signal, time -from twisted.internet import reactor, defer +from twisted.internet import reactor, defer, task from twisted.python import failure @@ -30,22 +30,32 @@ class SignalMixin: if self.sigchldHandler: signal.signal(signal.SIGCHLD, self.sigchldHandler) +class TimeoutError(Exception): + pass + class PollMixin: - def poll(self, check_f, pollinterval=0.01): + def poll(self, check_f, pollinterval=0.01, timeout=None): # Return a Deferred, then call check_f periodically until it returns # True, at which point the Deferred will fire.. If check_f raises an - # exception, the Deferred will errback. - d = defer.maybeDeferred(self._poll, None, check_f, pollinterval) + # exception, the Deferred will errback. If the check_f does not + # indicate success within timeout= seconds, the Deferred will + # errback. If timeout=None, no timeout will be enforced. + cutoff = None + if timeout is not None: + cutoff = time.time() + timeout + stash = [] # ick. We have to pass the LoopingCall into itself + lc = task.LoopingCall(self._poll, check_f, stash, cutoff) + stash.append(lc) + d = lc.start(pollinterval) return d - def _poll(self, res, check_f, pollinterval): + def _poll(self, check_f, stash, cutoff): + if cutoff is not None and time.time() > cutoff: + raise TimeoutError() + lc = stash[0] if check_f(): - return True - d = defer.Deferred() - d.addCallback(self._poll, check_f, pollinterval) - reactor.callLater(pollinterval, d.callback, None) - return d + lc.stop() class ShouldFailMixin: