test_mutable: test all hash-failure cases except a corrupted encrypted private key

This commit is contained in:
Brian Warner 2008-03-10 17:46:52 -07:00
parent 647734cd3b
commit be5a6147b4
2 changed files with 180 additions and 18 deletions

View File

@ -49,11 +49,8 @@ SIGNED_PREFIX = ">BQ32s16s BBQQ" # this is covered by the signature
HEADER = ">BQ32s16s BBQQ LLLLQQ" # includes offsets HEADER = ">BQ32s16s BBQQ LLLLQQ" # includes offsets
HEADER_LENGTH = struct.calcsize(HEADER) HEADER_LENGTH = struct.calcsize(HEADER)
def unpack_prefix_and_signature(data): def unpack_header(data):
assert len(data) >= HEADER_LENGTH
o = {} o = {}
prefix = data[:struct.calcsize(SIGNED_PREFIX)]
(version, (version,
seqnum, seqnum,
root_hash, root_hash,
@ -65,6 +62,18 @@ def unpack_prefix_and_signature(data):
o['share_data'], o['share_data'],
o['enc_privkey'], o['enc_privkey'],
o['EOF']) = struct.unpack(HEADER, data[:HEADER_LENGTH]) o['EOF']) = struct.unpack(HEADER, data[:HEADER_LENGTH])
return (version, seqnum, root_hash, IV, k, N, segsize, datalen, o)
def unpack_prefix_and_signature(data):
assert len(data) >= HEADER_LENGTH
prefix = data[:struct.calcsize(SIGNED_PREFIX)]
(version,
seqnum,
root_hash,
IV,
k, N, segsize, datalen,
o) = unpack_header(data)
assert version == 0 assert version == 0
if len(data) < o['share_hash_chain']: if len(data) < o['share_hash_chain']:
@ -535,7 +544,7 @@ class Retrieve:
self._pubkey = self._deserialize_pubkey(pubkey_s) self._pubkey = self._deserialize_pubkey(pubkey_s)
self._node._populate_pubkey(self._pubkey) self._node._populate_pubkey(self._pubkey)
verinfo = (seqnum, root_hash, IV, segsize, datalength) verinfo = (seqnum, root_hash, IV, segsize, datalength) #, k, N)
self._status.sharemap[peerid].add(verinfo) self._status.sharemap[peerid].add(verinfo)
if verinfo not in self._valid_versions: if verinfo not in self._valid_versions:
@ -694,12 +703,12 @@ class Retrieve:
# arbitrary, really I want this to be something like # arbitrary, really I want this to be something like
# k - max(known_version_sharecounts) + some extra # k - max(known_version_sharecounts) + some extra
break break
new_search_distance = max(max(peer_indicies),
self._status.get_search_distance())
self._status.set_search_distance(new_search_distance)
if new_query_peers: if new_query_peers:
self.log("sending %d new queries (read %d bytes)" % self.log("sending %d new queries (read %d bytes)" %
(len(new_query_peers), self._read_size), level=log.UNUSUAL) (len(new_query_peers), self._read_size), level=log.UNUSUAL)
new_search_distance = max(max(peer_indicies),
self._status.get_search_distance())
self._status.set_search_distance(new_search_distance)
for (peerid, ss) in new_query_peers: for (peerid, ss) in new_query_peers:
self._do_query(ss, peerid, self._storage_index, self._read_size) self._do_query(ss, peerid, self._storage_index, self._read_size)
# we'll retrigger when those queries come back # we'll retrigger when those queries come back
@ -802,7 +811,8 @@ class Retrieve:
try: try:
t2.set_hashes(hashes=share_hash_chain, t2.set_hashes(hashes=share_hash_chain,
leaves={shnum: share_hash_leaf}) leaves={shnum: share_hash_leaf})
except (hashtree.BadHashError, hashtree.NotEnoughHashesError), e: except (hashtree.BadHashError, hashtree.NotEnoughHashesError,
IndexError), e:
msg = "corrupt hashes: %s" % (e,) msg = "corrupt hashes: %s" % (e,)
raise CorruptShareError(peerid, shnum, msg) raise CorruptShareError(peerid, shnum, msg)
self.log(" data valid! len=%d" % len(share_data)) self.log(" data valid! len=%d" % len(share_data))
@ -864,16 +874,18 @@ class Retrieve:
self._node._populate_root_hash(root_hash) self._node._populate_root_hash(root_hash)
return plaintext return plaintext
def _done(self, contents): def _done(self, res):
# res is either the new contents, or a Failure
self.log("DONE") self.log("DONE")
self._running = False self._running = False
self._status.set_active(False) self._status.set_active(False)
self._status.set_status("Done") self._status.set_status("Done")
self._status.set_progress(1.0) self._status.set_progress(1.0)
self._status.set_size(len(contents)) if isinstance(res, str):
self._status.set_size(len(res))
elapsed = time.time() - self._started elapsed = time.time() - self._started
self._status.timings["total"] = elapsed self._status.timings["total"] = elapsed
eventually(self._done_deferred.callback, contents) eventually(self._done_deferred.callback, res)
def get_status(self): def get_status(self):
return self._status return self._status

View File

@ -1,5 +1,5 @@
import itertools, struct import itertools, struct, re
from cStringIO import StringIO from cStringIO import StringIO
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer from twisted.internet import defer
@ -52,6 +52,10 @@ class FakeStorage:
# tests to examine and manipulate the published shares. It also lets us # tests to examine and manipulate the published shares. It also lets us
# control the order in which read queries are answered, to exercise more # control the order in which read queries are answered, to exercise more
# of the error-handling code in mutable.Retrieve . # of the error-handling code in mutable.Retrieve .
#
# Note that we ignore the storage index: this FakeStorage instance can
# only be used for a single storage index.
def __init__(self): def __init__(self):
self._peers = {} self._peers = {}
@ -177,6 +181,12 @@ class FakePubKey:
def serialize(self): def serialize(self):
return "PUBKEY-%d" % self.count return "PUBKEY-%d" % self.count
def verify(self, msg, signature): def verify(self, msg, signature):
if signature[:5] != "SIGN(":
return False
if signature[5:-1] != msg:
return False
if signature[-1] != ")":
return False
return True return True
class FakePrivKey: class FakePrivKey:
@ -433,6 +443,14 @@ class FakeRetrieve(mutable.Retrieve):
vector.append(shares[shnum][offset:offset+length]) vector.append(shares[shnum][offset:offset+length])
return defer.succeed(response) return defer.succeed(response)
def _deserialize_pubkey(self, pubkey_s):
    """Parse a fake serialized pubkey of the form 'PUBKEY-<count>'.

    Raises RuntimeError when the string does not match that shape, which
    is how the corruption tests detect a mangled on-disk pubkey.
    """
    match = re.search(r"^PUBKEY-(\d+)$", pubkey_s)
    if match is None:
        raise RuntimeError("mangled pubkey")
    return FakePubKey(int(match.group(1)))
class Roundtrip(unittest.TestCase): class Roundtrip(unittest.TestCase):
def setup_for_publish(self, num_peers): def setup_for_publish(self, num_peers):
@ -444,16 +462,15 @@ class Roundtrip(unittest.TestCase):
fn.create("") fn.create("")
p = FakePublish(fn) p = FakePublish(fn)
p._storage = s p._storage = s
return c, fn, p r = FakeRetrieve(fn)
r._storage = s
return c, s, fn, p, r
def test_basic(self): def test_basic(self):
c, fn, p = self.setup_for_publish(20) c, s, fn, p, r = self.setup_for_publish(20)
contents = "New contents go here" contents = "New contents go here"
d = p.publish(contents) d = p.publish(contents)
def _published(res): def _published(res):
# TODO: examine peers and check on their shares
r = FakeRetrieve(fn)
r._storage = p._storage
return r.retrieve() return r.retrieve()
d.addCallback(_published) d.addCallback(_published)
def _retrieved(new_contents): def _retrieved(new_contents):
@ -461,3 +478,136 @@ class Roundtrip(unittest.TestCase):
d.addCallback(_retrieved) d.addCallback(_retrieved)
return d return d
def flip_bit(self, original, byte_offset):
    """Return a copy of 'original' with the low bit of the byte at
    'byte_offset' inverted. Used to corrupt a single byte of a share."""
    head = original[:byte_offset]
    tail = original[byte_offset+1:]
    flipped = chr(ord(original[byte_offset]) ^ 0x01)
    return head + flipped + tail
def shouldFail(self, expected_failure, which, substring,
               callable, *args, **kwargs):
    """Invoke callable(*args, **kwargs) and require that it fail.

    The resulting Deferred must errback with 'expected_failure'; if
    'substring' is given, str() of the failure must contain it. 'which'
    labels the call in the failure message when it unexpectedly
    succeeds.
    """
    assert substring is None or isinstance(substring, str)
    d = defer.maybeDeferred(callable, *args, **kwargs)
    def _check(res):
        if not isinstance(res, failure.Failure):
            # success is the wrong answer here
            self.fail("%s was supposed to raise %s, not get '%s'" %
                      (which, expected_failure, res))
        res.trap(expected_failure)
        if substring:
            self.failUnless(substring in str(res),
                            "substring '%s' not in '%s'"
                            % (substring, str(res)))
    d.addBoth(_check)
    return d
def _corrupt_all(self, offset, substring, refetch_pubkey=False,
                 should_succeed=False):
    """Publish, flip one bit in every share at 'offset', then retrieve.

    'offset' is either a share-offset-table key (e.g. 'share_hash_chain'),
    the literal string 'pubkey', a plain integer byte position, or a
    (key, delta) tuple. Unless should_succeed is set, the retrieve is
    expected to fail with NotEnoughPeersError containing 'substring'.
    With refetch_pubkey, the cached pubkey is cleared first so the
    retriever must re-read (and re-verify) it from the shares.
    """
    c, s, fn, p, r = self.setup_for_publish(20)
    contents = "New contents go here"
    d = p.publish(contents)
    # the (name, delta) split does not depend on the share, so do it once
    if isinstance(offset, tuple):
        offset1, offset2 = offset
    else:
        offset1, offset2 = offset, 0
    def _published(res):
        if refetch_pubkey:
            # clear the pubkey, to force a fetch
            r._pubkey = None
        for peerid, shares in s._peers.items():
            for shnum, data in shares.items():
                (version,
                 seqnum,
                 root_hash,
                 IV,
                 k, N, segsize, datalen,
                 o) = mutable.unpack_header(data)
                if offset1 == "pubkey":
                    # the pubkey lives at a fixed position in the share
                    real_offset = 107
                elif offset1 in o:
                    # per-share offset table entry
                    real_offset = o[offset1]
                else:
                    real_offset = offset1
                real_offset = int(real_offset) + offset2
                assert isinstance(real_offset, int), offset
                shares[shnum] = self.flip_bit(data, real_offset)
    d.addCallback(_published)
    if should_succeed:
        d.addCallback(lambda res: r.retrieve())
    else:
        d.addCallback(lambda res:
                      self.shouldFail(NotEnoughPeersError,
                                      "_corrupt_all(offset=%s)" % (offset,),
                                      substring,
                                      r.retrieve))
    return d
def test_corrupt_all_verbyte(self):
    """A version byte other than 0 trips the version assertion in
    unpack_share(), surfacing as an AssertionError."""
    return self._corrupt_all(0, "AssertionError")
def test_corrupt_all_seqnum(self):
    """A flipped bit in the sequence number (offset 1) must show up as
    an invalid signature."""
    return self._corrupt_all(1, "signature is invalid")
def test_corrupt_all_R(self):
    """A flipped bit in the root hash (offset 9) must show up as an
    invalid signature."""
    return self._corrupt_all(9, "signature is invalid")
def test_corrupt_all_IV(self):
    """A flipped bit in the salt/IV (offset 41) must show up as an
    invalid signature."""
    return self._corrupt_all(41, "signature is invalid")
def test_corrupt_all_k(self):
    """A flipped bit in 'k' (offset 57) must show up as an invalid
    signature."""
    return self._corrupt_all(57, "signature is invalid")
def test_corrupt_all_N(self):
    """A flipped bit in 'N' (offset 58) must show up as an invalid
    signature."""
    return self._corrupt_all(58, "signature is invalid")
def test_corrupt_all_segsize(self):
    """A flipped bit in the segment size (offset 59) must show up as an
    invalid signature."""
    return self._corrupt_all(59, "signature is invalid")
def test_corrupt_all_datalen(self):
    """A flipped bit in the data length (offset 67) must show up as an
    invalid signature."""
    return self._corrupt_all(67, "signature is invalid")
def test_corrupt_all_pubkey(self):
    """With the cached pubkey cleared (refetch_pubkey), a corrupted
    on-disk pubkey must fail the URI fingerprint check."""
    return self._corrupt_all("pubkey", "pubkey doesn't match fingerprint",
                             refetch_pubkey=True)
def test_corrupt_all_sig(self):
    """A flipped bit anywhere in the signature must be rejected as
    invalid. (The signature runs from about [543:799], depending upon
    the length of the pubkey.)"""
    return self._corrupt_all("signature", "signature is invalid",
                             refetch_pubkey=True)
def test_corrupt_all_share_hash_chain_number(self):
    """Mangling the first byte of a share-hash-chain entry corrupts the
    hash number, causing an IndexError that is reported as corrupt
    hashes."""
    return self._corrupt_all("share_hash_chain", "corrupt hashes")
def test_corrupt_all_share_hash_chain_hash(self):
    """Mangling a byte a few positions into a share-hash-chain entry
    corrupts the hash value itself, which looks like a bad hash."""
    return self._corrupt_all(("share_hash_chain", 4), "corrupt hashes")
def test_corrupt_all_block_hash_tree(self):
    """A corrupted block hash tree is detected as a block hash tree
    failure."""
    return self._corrupt_all("block_hash_tree", "block hash tree failure")
def test_corrupt_all_block(self):
    """A corrupted data block fails validation against the block hash
    tree."""
    return self._corrupt_all("share_data", "block hash tree failure")
def test_corrupt_all_encprivkey(self):
    """The reader never inspects the encrypted private key, so
    corrupting it must not prevent a successful retrieve."""
    return self._corrupt_all("enc_privkey", None, should_succeed=True)