diff --git a/src/allmydata/storage_client.py b/src/allmydata/storage_client.py index 63859a368..58aa749de 100644 --- a/src/allmydata/storage_client.py +++ b/src/allmydata/storage_client.py @@ -65,6 +65,8 @@ from allmydata.util.assertutil import precondition from allmydata.util.observer import ObserverList from allmydata.util.rrefutil import add_version_to_remote_reference from allmydata.util.hashutil import permute_server_hash +from allmydata.util.dictutil import BytesKeyDict + # who is responsible for de-duplication? # both? @@ -92,7 +94,7 @@ class StorageClientConfig(object): decreasing preference. See the *[client]peers.preferred* documentation for details. - :ivar dict[unicode, dict[bytes, bytes]] storage_plugins: A mapping from + :ivar dict[unicode, dict[unicode, unicode]] storage_plugins: A mapping from names of ``IFoolscapStoragePlugin`` configured in *tahoe.cfg* to the respective configuration. """ @@ -107,24 +109,24 @@ class StorageClientConfig(object): :param _Config config: The loaded Tahoe-LAFS node configuration. """ - ps = config.get_config("client", "peers.preferred", b"").split(b",") - preferred_peers = tuple([p.strip() for p in ps if p != b""]) + ps = config.get_config("client", "peers.preferred", "").split(",") + preferred_peers = tuple([p.strip() for p in ps if p != ""]) enabled_storage_plugins = ( name.strip() for name in config.get_config( - b"client", - b"storage.plugins", - b"", - ).decode("utf-8").split(u",") + "client", + "storage.plugins", + "", + ).split(u",") if name.strip() ) storage_plugins = {} for plugin_name in enabled_storage_plugins: try: - plugin_config = config.items(b"storageclient.plugins." + plugin_name) + plugin_config = config.items("storageclient.plugins." + plugin_name) except NoSectionError: plugin_config = [] storage_plugins[plugin_name] = dict(plugin_config) @@ -173,7 +175,7 @@ class StorageFarmBroker(service.MultiService): # storage servers that we've heard about. Each descriptor manages its # own Reconnector, and will give us a RemoteReference when we ask # them for it. - self.servers = {} + self.servers = BytesKeyDict() self._static_server_ids = set() # ignore announcements for these self.introducer_client = None self._threshold_listeners = [] # tuples of (threshold, Deferred) @@ -198,8 +200,10 @@ class StorageFarmBroker(service.MultiService): # written tests will still fail if a surprising exception # arrives here but they might be harder to debug without this # information. - pass + raise else: + if isinstance(server_id, unicode): + server_id = server_id.encode("utf-8") self._static_server_ids.add(server_id) self.servers[server_id] = storage_server storage_server.setServiceParent(self) @@ -555,7 +559,7 @@ class _FoolscapStorage(object): if isinstance(seed, unicode): seed = seed.encode("utf-8") ps = base32.a2b(seed) - elif re.search(r'^v0-[0-9a-zA-Z]{52}$', server_id): + elif re.search(br'^v0-[0-9a-zA-Z]{52}$', server_id): ps = base32.a2b(server_id[3:]) else: log.msg("unable to parse serverid '%(server_id)s as pubkey, " diff --git a/src/allmydata/test/test_storage_client.py b/src/allmydata/test/test_storage_client.py index 344c03e29..11784adfe 100644 --- a/src/allmydata/test/test_storage_client.py +++ b/src/allmydata/test/test_storage_client.py @@ -504,14 +504,14 @@ class TestStorageFarmBroker(unittest.TestCase): def test_static_servers(self): broker = make_broker() - key_s = 'v0-1234-1' + key_s = b'v0-1234-1' servers_yaml = """\ storage: v0-1234-1: ann: anonymous-storage-FURL: {furl} permutation-seed-base32: aaaaaaaaaaaaaaaaaaaaaaaa -""".format(furl=SOME_FURL) +""".format(furl=SOME_FURL.decode("utf-8")) servers = yamlutil.safe_load(servers_yaml) permseed = base32.a2b(b"aaaaaaaaaaaaaaaaaaaaaaaa") broker.set_static_servers(servers["storage"]) @@ -527,7 +527,7 @@ storage: ann2 = { "service-name": "storage", - "anonymous-storage-FURL": "pb://{}@nowhere/fake2".format(base32.b2a(str(1))), + "anonymous-storage-FURL": "pb://{}@nowhere/fake2".format(base32.b2a(b"1")), "permutation-seed-base32": "bbbbbbbbbbbbbbbbbbbbbbbb", } broker._got_announcement(key_s, ann2) @@ -538,7 +538,7 @@ storage: def test_static_permutation_seed_pubkey(self): broker = make_broker() server_id = b"v0-4uazse3xb6uu5qpkb7tel2bm6bpea4jhuigdhqcuvvse7hugtsia" - k = "4uazse3xb6uu5qpkb7tel2bm6bpea4jhuigdhqcuvvse7hugtsia" + k = b"4uazse3xb6uu5qpkb7tel2bm6bpea4jhuigdhqcuvvse7hugtsia" ann = { "anonymous-storage-FURL": SOME_FURL, } @@ -549,7 +549,7 @@ storage: def test_static_permutation_seed_explicit(self): broker = make_broker() server_id = b"v0-4uazse3xb6uu5qpkb7tel2bm6bpea4jhuigdhqcuvvse7hugtsia" - k = "w5gl5igiexhwmftwzhai5jy2jixn7yx7" + k = b"w5gl5igiexhwmftwzhai5jy2jixn7yx7" ann = { "anonymous-storage-FURL": SOME_FURL, "permutation-seed-base32": k, diff --git a/src/allmydata/util/dictutil.py b/src/allmydata/util/dictutil.py index 3ace8fca4..5fc85fbd1 100644 --- a/src/allmydata/util/dictutil.py +++ b/src/allmydata/util/dictutil.py @@ -14,6 +14,7 @@ if PY2: # subclassing dict, so we'd end up exposing Python 3 dict APIs to lots of # code that doesn't support it. from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, list, object, range, str, max, min # noqa: F401 +from six import ensure_str class DictOfSets(dict): @@ -76,3 +77,34 @@ class AuxValueDict(dict): have an auxvalue.""" super(AuxValueDict, self).__setitem__(key, value) self.auxilliary[key] = auxilliary + + +class _TypedKeyDict(dict): + """Dictionary that enforces key type. + + Doesn't override everything, but probably good enough to catch most + problems. + + Subclass and override KEY_TYPE. + """ + + KEY_TYPE = object + + +def _make_enforcing_override(K, method_name): + def f(self, key, *args, **kwargs): + assert isinstance(key, self.KEY_TYPE) + return getattr(dict, method_name)(self, key, *args, **kwargs) + f.__name__ = ensure_str(method_name) + setattr(K, method_name, f) + +for _method_name in ["__setitem__", "__getitem__", "setdefault", "get", + "__delitem__"]: + _make_enforcing_override(_TypedKeyDict, _method_name) +del _method_name + + +class BytesKeyDict(_TypedKeyDict): + """Keys should be bytes.""" + + KEY_TYPE = bytes