rrefutil: add trap_remote utility and friends

This commit is contained in:
Brian Warner 2009-02-27 00:55:24 -07:00
parent 8251572e01
commit 1b3e635936
2 changed files with 105 additions and 9 deletions

View File

@ -5,12 +5,13 @@ import os, time
from StringIO import StringIO
from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.python import failure
from twisted.python.failure import Failure
from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
from allmydata.util import limiter, time_format, pollmixin, cachedir
from allmydata.util import statistics, dictutil
from allmydata.util import statistics, dictutil, rrefutil
from allmydata.util.rrefutil import ServerFailure
class Base32(unittest.TestCase):
def test_b2a_matches_Pythons(self):
@ -535,7 +536,7 @@ class DeferredUtilTests(unittest.TestCase):
self.failUnlessEqual(good, [])
self.failUnlessEqual(len(bad), 1)
f = bad[0]
self.failUnless(isinstance(f, failure.Failure))
self.failUnless(isinstance(f, Failure))
self.failUnless(f.check(ValueError))
class HashUtilTests(unittest.TestCase):
@ -1195,3 +1196,79 @@ class DictUtil(unittest.TestCase):
self.failUnlessEqual(x, "b")
self.failUnlessEqual(d.items(), [("c", 1), ("a", 3)])
class FakeRemoteReference:
def callRemote(self, methname, *args, **kwargs):
return defer.maybeDeferred(self.oops)
def oops(self):
raise IndexError("remote missing key")
class RemoteFailures(unittest.TestCase):
def test_check(self):
try:
raise IndexError("local missing key")
except IndexError:
localf = Failure()
self.failUnlessEqual(localf.check(IndexError, KeyError), IndexError)
self.failUnlessEqual(localf.check(ValueError, KeyError), None)
self.failUnlessEqual(localf.check(ServerFailure), None)
frr = FakeRemoteReference()
wrr = rrefutil.WrappedRemoteReference(frr)
d = wrr.callRemote("oops")
def _check(f):
self.failUnlessEqual(f.check(IndexError, KeyError), None)
self.failUnlessEqual(f.check(ServerFailure, KeyError),
ServerFailure)
d.addErrback(_check)
return d
def test_is_remote(self):
try:
raise IndexError("local missing key")
except IndexError:
localf = Failure()
self.failIf(rrefutil.is_remote(localf))
self.failUnless(rrefutil.is_local(localf))
frr = FakeRemoteReference()
wrr = rrefutil.WrappedRemoteReference(frr)
d = wrr.callRemote("oops")
def _check(f):
self.failUnless(rrefutil.is_remote(f))
self.failIf(rrefutil.is_local(f))
d.addErrback(_check)
return d
def test_trap(self):
try:
raise IndexError("local missing key")
except IndexError:
localf = Failure()
self.failUnlessRaises(Failure, localf.trap, ValueError, KeyError)
self.failUnlessRaises(Failure, localf.trap, ServerFailure)
self.failUnlessEqual(localf.trap(IndexError, KeyError), IndexError)
self.failUnlessEqual(rrefutil.trap_local(localf, IndexError, KeyError),
IndexError)
self.failUnlessRaises(Failure,
rrefutil.trap_remote, localf, ValueError, KeyError)
frr = FakeRemoteReference()
wrr = rrefutil.WrappedRemoteReference(frr)
d = wrr.callRemote("oops")
def _check(f):
self.failUnlessRaises(Failure,
f.trap, ValueError, KeyError)
self.failUnlessRaises(Failure,
f.trap, IndexError)
self.failUnlessEqual(f.trap(ServerFailure), ServerFailure)
self.failUnlessRaises(Failure,
rrefutil.trap_remote, f, ValueError, KeyError)
self.failUnlessEqual(rrefutil.trap_remote(f, IndexError, KeyError),
IndexError)
self.failUnlessRaises(Failure,
rrefutil.trap_local, f, ValueError, KeyError)
self.failUnlessRaises(Failure,
rrefutil.trap_local, f, IndexError)
d.addErrback(_check)
return d

View File

@ -3,9 +3,9 @@ import exceptions
from foolscap.tokens import Violation
class ServerFailure(exceptions.Exception):
# If the server returns a Failure instead of the normal response to a protocol, then this
# exception will be raised, with the Failure that the server returned as its .remote_failure
# attribute.
# If the server returns a Failure instead of the normal response to a
# protocol, then this exception will be raised, with the Failure that the
# server returned as its .remote_failure attribute.
def __init__(self, remote_failure):
self.remote_failure = remote_failure
def __repr__(self):
@ -13,11 +13,30 @@ class ServerFailure(exceptions.Exception):
def __str__(self):
return str(self.remote_failure)
def is_remote(f):
if isinstance(f.value, ServerFailure):
return True
return False
def is_local(f):
return not is_remote(f)
def trap_remote(f, *errorTypes):
if is_remote(f):
return f.value.remote_failure.trap(*errorTypes)
raise f
def trap_local(f, *errorTypes):
if is_local(f):
return f.trap(*errorTypes)
raise f
def _wrap_server_failure(f):
raise ServerFailure(f)
class WrappedRemoteReference(object):
"""I intercept any errback from the server and wrap it in a ServerFailure."""
"""I intercept any errback from the server and wrap it in a
ServerFailure."""
def __init__(self, original):
self.rref = original
@ -37,8 +56,8 @@ class WrappedRemoteReference(object):
return self.rref.dontNotifyOnDisconnect(*args, **kwargs)
class VersionedRemoteReference(WrappedRemoteReference):
"""I wrap a RemoteReference, and add a .version attribute. I also intercept any errback from
the server and wrap it in a ServerFailure."""
"""I wrap a RemoteReference, and add a .version attribute. I also
intercept any errback from the server and wrap it in a ServerFailure."""
def __init__(self, original, version):
WrappedRemoteReference.__init__(self, original)