diff --git a/src/allmydata/test/test_iputil.py b/src/allmydata/test/test_iputil.py index e4e524fa7..1963ac2d4 100644 --- a/src/allmydata/test/test_iputil.py +++ b/src/allmydata/test/test_iputil.py @@ -1,14 +1,22 @@ +#!/usr/bin/env python + +from allmydata.util import iputil, testutil from twisted.trial import unittest -from allmydata.util import iputil +import re, sys + +DOTTED_QUAD_RE=re.compile("^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$") class ListAddresses(unittest.TestCase): + def test_get_local_ip_for(self): + addr = iputil.get_local_ip_for('127.0.0.1') + self.failUnless(DOTTED_QUAD_RE.match(addr)) + def test_list_async(self): d = iputil.get_local_addresses_async() def _check(addresses): self.failUnless(len(addresses) >= 1) # always have localhost - self.failUnless("127.0.0.1" in addresses, - "didn't see 127.0.0.1 in %s" % (addresses,)) + self.failUnless("127.0.0.1" in addresses, addresses) d.addCallbacks(_check) return d - + test_list_async.timeout=2 diff --git a/src/allmydata/util/iputil.py b/src/allmydata/util/iputil.py index f1180bff2..82752cd26 100644 --- a/src/allmydata/util/iputil.py +++ b/src/allmydata/util/iputil.py @@ -8,6 +8,7 @@ import re, socket, sys # from Twisted from twisted.internet import defer +from twisted.python import log from twisted.internet import reactor from twisted.internet.protocol import DatagramProtocol from twisted.internet.utils import getProcessOutput @@ -23,6 +24,9 @@ def get_local_addresses_async(target='A.ROOT-SERVERS.NET'): pass the address of a host that you are actually trying to be reachable to. """ + if sys.platform == "cygwin": + return _cygwin_hack(target) + addresses = set() addresses.add(get_local_ip_for(target)) @@ -37,13 +41,10 @@ def get_local_addresses_async(target='A.ROOT-SERVERS.NET'): def get_local_ip_for(target): """Find out what our IP address is for use by a given target. - Returns a string that holds the IP address which could be used by - 'target' to connect to us. It might work for them, it might not. + @returns: the IP address as a dotted-quad string which could be used by + 'target' to connect to us. It might work for them, it might not """ - try: - target_ipaddr = socket.gethostbyname(target) - except socket.gaierror: - return "127.0.0.1" + target_ipaddr = socket.gethostbyname(target) udpprot = DatagramProtocol() port = reactor.listenUDP(0, udpprot) udpprot.transport.connect(target_ipaddr, 7) @@ -91,12 +92,20 @@ _netbsd_re = re.compile('^\s+inet (?P
\d+\.\d+\.\d+\.\d+)\s.+$', flags=r # Irix 6.5 _irix_path = '/usr/etc/ifconfig' -_irix_args = ('-a',) # Solaris 2.x _sunos_path = '/usr/sbin/ifconfig' -_sunos_args = ('-a',) +# k: platform string as provided in the value of _platform_map +# v: tuple of (path_to_tool, args, regex,) +_tool_map = { + "linux": (_linux_path, (), _linux_re,), + "win32": (_win32_path, _win32_args, _win32_re,), + "cygwin": (_win32_path, _win32_args, _win32_re,), + "bsd": (_netbsd_path, _netbsd_args, _netbsd_re,), + "irix": (_irix_path, _netbsd_args, _netbsd_re,), + "sunos": (_sunos_path, _netbsd_args, _netbsd_re,), + } def _find_addresses_via_config(): # originally by Greg Smith, hacked by Zooko to conform to Brian's API @@ -104,29 +113,22 @@ def _find_addresses_via_config(): if not platform: raise UnsupportedPlatformError(sys.platform) - if platform in ('win32', 'cygwin',): - l = [] - for executable in which(_win32_path): - l.append(_query(executable, _win32_re, _win32_args)) - dl = defer.DeferredList(l) - def _gather_results(res): - addresses = set() - for r in res: - if r[0]: - addresses.update(r[1]) - return addresses - dl.addCallback(_gather_results) - return dl - elif platform == 'linux': - return _query(_linux_path, _linux_re) - elif platform == 'bsd': - return _query(_netbsd_path, _netbsd_re, _netbsd_args) - elif platform == 'irix' : - return _query(_irix_path, _netbsd_re, _irix_args) - elif platform == 'sunos': - return _query(_sunos_path, _netbsd_re, _sunos_args) + (pathtotool, args, regex,) = _tool_map[platform] + + l = [] + for executable in which(pathtotool): + l.append(_query(executable, args, regex)) + dl = defer.DeferredList(l) + def _gather_results(res): + addresses = set() + for r in res: + if r[0]: + addresses.update(r[1]) + return addresses + dl.addCallback(_gather_results) + return dl -def _query(path, regex, args=()): +def _query(path, args, regex): d = getProcessOutput(path, args) def _parse(output): addresses = set() @@ -140,3 +142,15 @@ def _query(path, regex, args=()): return addresses d.addCallback(_parse) return d + +def _cygwin_hack(target): + res = set() + for h in [target, "localhost", "127.0.0.1",]: + try: + res.add(get_local_ip_for(h)) + except socket.gaierror: + pass + + return defer.succeed(res) + +