From 98c8e25709579c17c6b37be181dee059cc2a016d Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Fri, 26 Sep 2008 09:57:54 -0700 Subject: [PATCH] netstring: add required_trailer= argument --- src/allmydata/test/test_netstring.py | 4 ++++ src/allmydata/util/netstring.py | 17 ++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/allmydata/test/test_netstring.py b/src/allmydata/test/test_netstring.py index 3923a0e19..5c8199a6f 100644 --- a/src/allmydata/test/test_netstring.py +++ b/src/allmydata/test/test_netstring.py @@ -12,6 +12,10 @@ class Netstring(unittest.TestCase): self.failUnlessRaises(ValueError, split_netstring, a, 3) self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2) self.failUnlessRaises(ValueError, split_netstring, a+" extra", 2, False) + self.failUnlessEqual(split_netstring(a+"++", 2, required_trailer="++"), + ("hello", "world")) + self.failUnlessRaises(ValueError, + split_netstring, a+"+", 2, required_trailer="not") def test_extra(self): a = netstring("hello") diff --git a/src/allmydata/util/netstring.py b/src/allmydata/util/netstring.py index 70a14e01b..a1fe8cb98 100644 --- a/src/allmydata/util/netstring.py +++ b/src/allmydata/util/netstring.py @@ -4,11 +4,18 @@ def netstring(s): assert isinstance(s, str), s # no unicode here return "%d:%s," % (len(s), s,) -def split_netstring(data, numstrings, allow_leftover=False): +def split_netstring(data, numstrings, + allow_leftover=False, + required_trailer=""): """like string.split(), but extracts netstrings. If allow_leftover=False, - returns numstrings elements, and throws ValueError if there was leftover - data. If allow_leftover=True, returns numstrings+1 elements, in which the - last element is the leftover data (possibly an empty string)""" + I return numstrings elements, and throw ValueError if there was leftover + data that does not exactly equal 'required_trailer'. If + allow_leftover=True, required_trailer must be empty, and I return + numstrings+1 elements, in which the last element is the leftover data + (possibly an empty string)""" + + assert not (allow_leftover and required_trailer) + elements = [] assert numstrings >= 0 while data: @@ -25,7 +32,7 @@ def split_netstring(data, numstrings, allow_leftover=False): raise ValueError("ran out of netstrings") if allow_leftover: return tuple(elements + [data]) - if data: + if data != required_trailer: raise ValueError("leftover data in netstrings") return tuple(elements)