rrefutil: add check_remote utility function

This commit is contained in:
Brian Warner 2009-02-27 00:59:57 -07:00
parent 1b3e635936
commit 8c3013c4f7
2 changed files with 23 additions and 0 deletions

View File

@ -1204,13 +1204,21 @@ class FakeRemoteReference:
class RemoteFailures(unittest.TestCase): class RemoteFailures(unittest.TestCase):
def test_check(self): def test_check(self):
check_local = rrefutil.check_local
check_remote = rrefutil.check_remote
try: try:
raise IndexError("local missing key") raise IndexError("local missing key")
except IndexError: except IndexError:
localf = Failure() localf = Failure()
self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError) self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError)
self.failUnlessEqual(localf.check(ValueError, KeyError), None) self.failUnlessEqual(localf.check(ValueError, KeyError), None)
self.failUnlessEqual(localf.check(ServerFailure), None) self.failUnlessEqual(localf.check(ServerFailure), None)
self.failUnlessEqual(check_local(localf, IndexError, KeyError),
IndexError)
self.failUnlessEqual(check_local(localf, ValueError, KeyError), None)
self.failUnlessEqual(check_remote(localf, IndexError, KeyError), None)
self.failUnlessEqual(check_remote(localf, ValueError, KeyError), None)
frr = FakeRemoteReference() frr = FakeRemoteReference()
wrr = rrefutil.WrappedRemoteReference(frr) wrr = rrefutil.WrappedRemoteReference(frr)
@ -1219,6 +1227,11 @@ class RemoteFailures(unittest.TestCase):
self.failUnlessEqual(f.check(IndexError, KeyError), None) self.failUnlessEqual(f.check(IndexError, KeyError), None)
self.failUnlessEqual(f.check(ServerFailure, KeyError), self.failUnlessEqual(f.check(ServerFailure, KeyError),
ServerFailure) ServerFailure)
self.failUnlessEqual(check_remote(f, IndexError, KeyError),
IndexError)
self.failUnlessEqual(check_remote(f, ValueError, KeyError), None)
self.failUnlessEqual(check_local(f, IndexError, KeyError), None)
self.failUnlessEqual(check_local(f, ValueError, KeyError), None)
d.addErrback(_check) d.addErrback(_check)
return d return d

View File

@ -21,6 +21,16 @@ def is_remote(f):
def is_local(f): def is_local(f):
return not is_remote(f) return not is_remote(f)
def check_remote(f, *errorTypes):
if is_remote(f):
return f.value.remote_failure.check(*errorTypes)
return None
def check_local(f, *errorTypes):
if is_local(f):
return f.check(*errorTypes)
return None
def trap_remote(f, *errorTypes): def trap_remote(f, *errorTypes):
if is_remote(f): if is_remote(f):
return f.value.remote_failure.trap(*errorTypes) return f.value.remote_failure.trap(*errorTypes)