PollMixin: add timeout= argument, rewrite to avoid tail-recursion problems

This commit is contained in:
Brian Warner 2008-02-04 20:35:07 -07:00
parent f4cbd5ca34
commit 3a5ba35215
2 changed files with 34 additions and 18 deletions

View File

@ -374,19 +374,25 @@ class PollMixinTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.pm = testutil.PollMixin() self.pm = testutil.PollMixin()
def _check(self, d):
def fail_unless_arg_is_true(arg):
self.failUnless(arg is True, repr(arg))
d.addCallback(fail_unless_arg_is_true)
return d
def test_PollMixin_True(self): def test_PollMixin_True(self):
d = self.pm.poll(check_f=lambda : True, d = self.pm.poll(check_f=lambda : True,
pollinterval=0.1) pollinterval=0.1)
return self._check(d) return d
def test_PollMixin_False_then_True(self): def test_PollMixin_False_then_True(self):
i = iter([False, True]) i = iter([False, True])
d = self.pm.poll(check_f=i.next, d = self.pm.poll(check_f=i.next,
pollinterval=0.1) pollinterval=0.1)
return self._check(d) return d
def test_timeout(self):
d = self.pm.poll(check_f=lambda: False,
pollinterval=0.01,
timeout=1)
def _suc(res):
self.fail("poll should have failed, not returned %s" % (res,))
def _err(f):
f.trap(testutil.TimeoutError)
return None # success
d.addCallbacks(_suc, _err)
return d

View File

@ -1,6 +1,6 @@
import os, signal, time import os, signal, time
from twisted.internet import reactor, defer from twisted.internet import reactor, defer, task
from twisted.python import failure from twisted.python import failure
@ -30,22 +30,32 @@ class SignalMixin:
if self.sigchldHandler: if self.sigchldHandler:
signal.signal(signal.SIGCHLD, self.sigchldHandler) signal.signal(signal.SIGCHLD, self.sigchldHandler)
class TimeoutError(Exception):
pass
class PollMixin: class PollMixin:
def poll(self, check_f, pollinterval=0.01): def poll(self, check_f, pollinterval=0.01, timeout=None):
# Return a Deferred, then call check_f periodically until it returns # Return a Deferred, then call check_f periodically until it returns
# True, at which point the Deferred will fire.. If check_f raises an # True, at which point the Deferred will fire.. If check_f raises an
# exception, the Deferred will errback. # exception, the Deferred will errback. If the check_f does not
d = defer.maybeDeferred(self._poll, None, check_f, pollinterval) # indicate success within timeout= seconds, the Deferred will
# errback. If timeout=None, no timeout will be enforced.
cutoff = None
if timeout is not None:
cutoff = time.time() + timeout
stash = [] # ick. We have to pass the LoopingCall into itself
lc = task.LoopingCall(self._poll, check_f, stash, cutoff)
stash.append(lc)
d = lc.start(pollinterval)
return d return d
def _poll(self, res, check_f, pollinterval): def _poll(self, check_f, stash, cutoff):
if cutoff is not None and time.time() > cutoff:
raise TimeoutError()
lc = stash[0]
if check_f(): if check_f():
return True lc.stop()
d = defer.Deferred()
d.addCallback(self._poll, check_f, pollinterval)
reactor.callLater(pollinterval, d.callback, None)
return d
class ShouldFailMixin: class ShouldFailMixin: