directories: keep track of your position as you decode netstring after netstring from an input buffer instead of copying the trailing part

This makes decoding linear in the number of netstrings instead of O(N^2).
This commit is contained in:
Zooko O'Whielacronx 2009-07-04 19:51:09 -07:00
parent 4206a2c1c7
commit efafcfb91a
3 changed files with 39 additions and 43 deletions

View File

@ -204,15 +204,16 @@ class NewDirectoryNode:
# rocap, rwcap, metadata), in which the name,rocap,metadata are in # rocap, rwcap, metadata), in which the name,rocap,metadata are in
# cleartext. The 'name' is UTF-8 encoded. The rwcap is formatted as: # cleartext. The 'name' is UTF-8 encoded. The rwcap is formatted as:
# pack("16ss32s", iv, AES(H(writekey+iv), plaintextrwcap), mac) # pack("16ss32s", iv, AES(H(writekey+iv), plaintextrwcap), mac)
assert isinstance(data, str) assert isinstance(data, str), (repr(data), type(data))
# an empty directory is serialized as an empty string # an empty directory is serialized as an empty string
if data == "": if data == "":
return {} return {}
writeable = not self.is_readonly() writeable = not self.is_readonly()
children = {} children = {}
while len(data) > 0: position = 0
entry, data = split_netstring(data, 1, True) while position < len(data):
name, rocap, rwcapdata, metadata_s = split_netstring(entry, 4) entries, position = split_netstring(data, 1, position)
(name, rocap, rwcapdata, metadata_s), subpos = split_netstring(entries[0], 4)
name = name.decode("utf-8") name = name.decode("utf-8")
rwcap = None rwcap = None
if writeable: if writeable:

View File

@ -5,37 +5,32 @@ from allmydata.util.netstring import netstring, split_netstring
class Netstring(unittest.TestCase): class Netstring(unittest.TestCase):
def test_split(self): def test_split(self):
a = netstring("hello") + netstring("world") a = netstring("hello") + netstring("world")
self.failUnlessEqual(split_netstring(a, 2), ("hello", "world")) self.failUnlessEqual(split_netstring(a, 2), (["hello", "world"], len(a)))
self.failUnlessEqual(split_netstring(a, 2, False), ("hello", "world")) self.failUnlessEqual(split_netstring(a, 2, required_trailer=""), (["hello", "world"], len(a)))
self.failUnlessEqual(split_netstring(a, 2, True),
("hello", "world", ""))
self.failUnlessRaises(ValueError, split_netstring, a, 3) self.failUnlessRaises(ValueError, split_netstring, a, 3)
self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2) self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2, required_trailer="")
self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2, False) self.failUnlessEqual(split_netstring(a+" extra", 2), (["hello", "world"], len(a)))
self.failUnlessEqual(split_netstring(a+"++", 2, required_trailer="++"), self.failUnlessEqual(split_netstring(a+"++", 2, required_trailer="++"),
("hello", "world")) (["hello", "world"], len(a)+2))
self.failUnlessRaises(ValueError, self.failUnlessRaises(ValueError,
split_netstring, a+"+", 2, required_trailer="not") split_netstring, a+"+", 2, required_trailer="not")
def test_extra(self): def test_extra(self):
a = netstring("hello") a = netstring("hello")
self.failUnlessEqual(split_netstring(a, 1, True), ("hello", "")) self.failUnlessEqual(split_netstring(a, 1), (["hello"], len(a)))
b = netstring("hello") + "extra stuff" b = netstring("hello") + "extra stuff"
self.failUnlessEqual(split_netstring(b, 1, True), self.failUnlessEqual(split_netstring(b, 1),
("hello", "extra stuff")) (["hello"], len(a)))
def test_nested(self): def test_nested(self):
a = netstring("hello") + netstring("world") + "extra stuff" a = netstring("hello") + netstring("world") + "extra stuff"
b = netstring("a") + netstring("is") + netstring(a) + netstring(".") b = netstring("a") + netstring("is") + netstring(a) + netstring(".")
top = split_netstring(b, 4) (top, pos) = split_netstring(b, 4)
self.failUnlessEqual(len(top), 4) self.failUnlessEqual(len(top), 4)
self.failUnlessEqual(top[0], "a") self.failUnlessEqual(top[0], "a")
self.failUnlessEqual(top[1], "is") self.failUnlessEqual(top[1], "is")
self.failUnlessEqual(top[2], a) self.failUnlessEqual(top[2], a)
self.failUnlessEqual(top[3], ".") self.failUnlessEqual(top[3], ".")
self.failUnlessRaises(ValueError, split_netstring, a, 2) self.failUnlessRaises(ValueError, split_netstring, a, 2, required_trailer="")
self.failUnlessRaises(ValueError, split_netstring, a, 2, False) bottom = split_netstring(a, 2)
bottom = split_netstring(a, 2, True) self.failUnlessEqual(bottom, (["hello", "world"], len(netstring("hello")+netstring("world"))))
self.failUnlessEqual(bottom, ("hello", "world", "extra stuff"))

View File

@ -5,34 +5,34 @@ def netstring(s):
return "%d:%s," % (len(s), s,) return "%d:%s," % (len(s), s,)
def split_netstring(data, numstrings, def split_netstring(data, numstrings,
allow_leftover=False, position=0,
required_trailer=""): required_trailer=None):
"""like string.split(), but extracts netstrings. If allow_leftover=False, """like string.split(), but extracts netstrings. Ignore all bytes of data
I return numstrings elements, and throw ValueError if there was leftover before the 'position' byte. Return a tuple of (list of elements (numstrings
data that does not exactly equal 'required_trailer'. If in length), new position index). The new position index points to the first
allow_leftover=True, required_trailer must be empty, and I return byte which was not consumed (the 'required_trailer', if any, counts as
numstrings+1 elements, in which the last element is the leftover data consumed). If 'required_trailer' is not None, throw ValueError if leftover
(possibly an empty string)""" data does not exactly equal 'required_trailer'."""
assert not (allow_leftover and required_trailer)
assert type(position) in (int, long), (repr(position), type(position))
elements = [] elements = []
assert numstrings >= 0 assert numstrings >= 0
while data: while position < len(data):
colon = data.index(":") colon = data.index(":", position)
length = int(data[:colon]) length = int(data[position:colon])
string = data[colon+1:colon+1+length] string = data[colon+1:colon+1+length]
assert len(string) == length assert len(string) == length, (len(string), length)
elements.append(string) elements.append(string)
assert data[colon+1+length] == "," position = colon+1+length
data = data[colon+1+length+1:] assert data[position] == ",", position
position += 1
if len(elements) == numstrings: if len(elements) == numstrings:
break break
if len(elements) < numstrings: if len(elements) < numstrings:
raise ValueError("ran out of netstrings") raise ValueError("ran out of netstrings")
if allow_leftover: if required_trailer is not None:
return tuple(elements + [data]) if ((len(data) - position) != len(required_trailer)) or (data[position:] != required_trailer):
if data != required_trailer:
raise ValueError("leftover data in netstrings") raise ValueError("leftover data in netstrings")
return tuple(elements) return (elements, position + len(required_trailer))
else:
return (elements, position)