Tests and additional check for typed key dicts.

This commit is contained in:
Itamar Turner-Trauring 2020-11-04 13:36:08 -05:00
parent dc818757b6
commit 0a6321cc9a
2 changed files with 68 additions and 1 deletions

View File

@ -12,6 +12,8 @@ from future.utils import PY2
if PY2:
from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
from unittest import skipIf
from twisted.trial import unittest
from allmydata.util import dictutil
@ -88,3 +90,59 @@ class DictUtil(unittest.TestCase):
self.failUnlessEqual(sorted(d.keys()), ["one","two"])
self.failUnlessEqual(d["one"], 1)
self.failUnlessEqual(d.get_aux("one"), None)
class TypedKeyDict(unittest.TestCase):
"""Tests for dictionaries that limit keys."""
@skipIf(PY2, "Python 2 doesn't have issues mixing bytes and unicode.")
def setUp(self):
pass
def test_bytes(self):
"""BytesKeyDict is limited to just byte keys."""
self.assertRaises(TypeError, dictutil.BytesKeyDict, {u"hello": 123})
d = dictutil.BytesKeyDict({b"123": 200})
with self.assertRaises(TypeError):
d[u"hello"] = "blah"
with self.assertRaises(TypeError):
d[u"hello"]
with self.assertRaises(TypeError):
del d[u"hello"]
with self.assertRaises(TypeError):
d.setdefault(u"hello", "123")
with self.assertRaises(TypeError):
d.get(u"xcd")
# Byte keys are fine:
self.assertEqual(d, {b"123": 200})
d[b"456"] = 400
self.assertEqual(d[b"456"], 400)
del d[b"456"]
self.assertEqual(d.get(b"456", 50), 50)
self.assertEqual(d.setdefault(b"456", 300), 300)
self.assertEqual(d[b"456"], 300)
def test_unicode(self):
"""UnicodeKeyDict is limited to just byte keys."""
self.assertRaises(TypeError, dictutil.UnicodeKeyDict, {b"hello": 123})
d = dictutil.UnicodeKeyDict({u"123": 200})
with self.assertRaises(TypeError):
d[b"hello"] = "blah"
with self.assertRaises(TypeError):
d[b"hello"]
with self.assertRaises(TypeError):
del d[b"hello"]
with self.assertRaises(TypeError):
d.setdefault(b"hello", "123")
with self.assertRaises(TypeError):
d.get(b"xcd")
# Byte keys are fine:
self.assertEqual(d, {u"123": 200})
d[u"456"] = 400
self.assertEqual(d[u"456"], 400)
del d[u"456"]
self.assertEqual(d.get(u"456", 50), 50)
self.assertEqual(d.setdefault(u"456", 300), 300)
self.assertEqual(d[u"456"], 300)

View File

@ -90,10 +90,19 @@ class _TypedKeyDict(dict):
KEY_TYPE = object
def __init__(self, *args, **kwargs):
dict.__init__(self, *args, **kwargs)
for key in self:
if not isinstance(key, self.KEY_TYPE):
raise TypeError("{} must be of type {}".format(
repr(key), self.KEY_TYPE))
def _make_enforcing_override(K, method_name):
def f(self, key, *args, **kwargs):
assert isinstance(key, self.KEY_TYPE)
if not isinstance(key, self.KEY_TYPE):
raise TypeError("{} must be of type {}".format(
repr(key), self.KEY_TYPE))
return getattr(dict, method_name)(self, key, *args, **kwargs)
f.__name__ = ensure_str(method_name)
setattr(K, method_name, f)