From d6f2dbbac7edbd09889606e84d8c49f7f040921e Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Tue, 13 Nov 2007 23:08:15 -0700 Subject: [PATCH] mutable: handle bad hashes, improve test coverage, rearrange slightly to facilitate these --- src/allmydata/mutable.py | 144 ++++++++++++++++++++--------- src/allmydata/scripts/debug.py | 3 +- src/allmydata/test/test_mutable.py | 11 +-- src/allmydata/test/test_system.py | 98 +++++++++++++++----- 4 files changed, 183 insertions(+), 73 deletions(-) diff --git a/src/allmydata/mutable.py b/src/allmydata/mutable.py index fadcea5a4..ba6b68e4a 100644 --- a/src/allmydata/mutable.py +++ b/src/allmydata/mutable.py @@ -31,10 +31,11 @@ class UncoordinatedWriteError(Exception): class CorruptShareError(Exception): def __init__(self, peerid, shnum, reason): + self.args = (peerid, shnum, reason) self.peerid = peerid self.shnum = shnum self.reason = reason - def __repr__(self): + def __str__(self): short_peerid = idlib.nodeid_b2a(self.peerid)[:8] return "H32s", i, share_hash_chain[i]) + for i in sorted(share_hash_chain.keys())]) + for h in block_hash_tree: + assert len(h) == 32 + block_hash_tree_s = "".join(block_hash_tree) + + offsets = pack_offsets(len(verification_key), + len(signature), + len(share_hash_chain_s), + len(block_hash_tree_s), + len(share_data), + len(encprivkey)) + final_share = "".join([prefix, + offsets, + verification_key, + signature, + share_hash_chain_s, + block_hash_tree_s, + share_data, + encprivkey]) + return final_share + + class Retrieve: def __init__(self, filenode): self._node = filenode @@ -224,6 +252,8 @@ class Retrieve: # 7: if we discover corrupt shares during the reconstruction process, # remove that share from the sharemap. and start step#6 again. + self.log("starting retrieval") + initial_query_count = 5 self._read_size = 2000 @@ -245,11 +275,12 @@ class Retrieve: # continuing through the last byte of sharedata. self._valid_versions = {} - # self._valid_shares is a set (peerid,data) tuples. Each time we - # examine the hash chains inside a share and validate them against a - # signed root_hash, we add the share to self._valid_shares . We use - # this to avoid re-checking the hashes over and over again. - self._valid_shares = set() + # self._valid_shares is a dict mapping (peerid,data) tuples to + # validated sharedata strings. Each time we examine the hash chains + # inside a share and validate them against a signed root_hash, we add + # the share to self._valid_shares . We use this to avoid re-checking + # the hashes over and over again. + self._valid_shares = {} self._done_deferred = defer.Deferred() @@ -332,6 +363,8 @@ class Retrieve: for shnum,datav in datavs.items(): data = datav[0] + self.log("_got_results: got shnum #%d from peerid %s" + % (shnum, idlib.shortnodeid_b2a(peerid))) (seqnum, root_hash, IV, k, N, segsize, datalength, pubkey_s, signature, prefix) = unpack_prefix_and_signature(data) @@ -339,7 +372,7 @@ class Retrieve: fingerprint = hashutil.ssk_pubkey_fingerprint_hash(pubkey_s) if fingerprint != self._node._fingerprint: # bad share - raise CorruptShareError(peerid, + raise CorruptShareError(peerid, shnum, "pubkey doesn't match fingerprint") self._pubkey = self._deserialize_pubkey(pubkey_s) self._node._populate_pubkey(self._pubkey) @@ -349,11 +382,11 @@ class Retrieve: # it's a new pair. Verify the signature. valid = self._pubkey.verify(prefix, signature) if not valid: - raise CorruptShareError(peerid, + raise CorruptShareError(peerid, shnum, "signature is invalid") # ok, it's a valid verinfo. Add it to the list of validated # versions. - self.log("found valid version %d-%s from %s-sh%d: %d-%d/%d/%d" + self.log(" found valid version %d-%s from %s-sh%d: %d-%d/%d/%d" % (seqnum, idlib.b2a(root_hash)[:4], idlib.shortnodeid_b2a(peerid), shnum, k, N, segsize, datalength)) @@ -372,6 +405,7 @@ class Retrieve: # there's enough data present. If not, raise NeedMoreDataError, # which will trigger a re-fetch. _ignored = unpack_share(data) + self.log(" found enough data to add share contents") self._valid_versions[verinfo][1].add(shnum, (peerid, data)) @@ -391,9 +425,11 @@ class Retrieve: self._bad_peerids.add(peerid) short_sid = idlib.b2a(self._storage_index)[:6] if f.check(CorruptShareError): - self.log("WEIRD: bad share for %s: %s" % (short_sid, f)) + self.log("WEIRD: bad share for %s: %s %s" % (short_sid, f, + f.value)) else: - self.log("WEIRD: other error for %s: %s" % (short_sid, f)) + self.log("WEIRD: other error for %s: %s %s" % (short_sid, f, + f.value)) def _check_for_done(self, res): if not self._running: @@ -422,13 +458,17 @@ class Retrieve: def _problem(f): self._last_failure = f if f.check(CorruptShareError): - # log(WEIRD) + self.log("WEIRD: saw corrupt share, rescheduling") # _attempt_decode is responsible for removing the bad # share, so we can just try again - eventually(self._check_for_done) + eventually(self._check_for_done, None) return return f d.addCallbacks(self._done, _problem) + # TODO: create an errback-routing mechanism to make sure that + # weird coding errors will cause the retrieval to fail rather + # than hanging forever. Any otherwise-unhandled exceptions + # should follow this path. return # we don't have enough shares yet. Should we send out more queries? @@ -478,13 +518,26 @@ class Retrieve: # sharemap is a dict which maps shnum to [(peerid,data)..] sets. (seqnum, root_hash, IV, segsize, datalength) = verinfo + assert len(sharemap) >= self._required_shares, len(sharemap) + + shares_s = [] + for shnum in sorted(sharemap.keys()): + for shareinfo in sharemap[shnum]: + shares_s.append("#%d" % shnum) + shares_s = ",".join(shares_s) + self.log("_attempt_decode: version %d-%s, shares: %s" % + (seqnum, idlib.b2a(root_hash)[:4], shares_s)) + # first, validate each share that we haven't validated yet. We use # self._valid_shares to remember which ones we've already checked. shares = {} for shnum, shareinfos in sharemap.items(): + assert len(shareinfos) > 0 for shareinfo in shareinfos: + # have we already validated the hashes on this share? if shareinfo not in self._valid_shares: + # nope: must check the hashes and extract the actual data (peerid,data) = shareinfo try: # The (seqnum+root_hash+IV) tuple for this share was @@ -500,24 +553,36 @@ class Retrieve: # validate the prefix on all shares) from using # anything else in the share. validator = self._validate_share_and_extract_data - sharedata = validator(root_hash, shnum, data) + sharedata = validator(peerid, root_hash, shnum, data) assert isinstance(sharedata, str) except CorruptShareError, e: self.log("WEIRD: share was corrupt: %s" % e) sharemap[shnum].discard(shareinfo) + if not sharemap[shnum]: + # remove the key so the test in _check_for_done + # can accurately decide that we don't have enough + # shares to try again right now. + del sharemap[shnum] # If there are enough remaining shares, # _check_for_done() will try again raise - self._valid_shares.add(shareinfo) - shares[shnum] = sharedata - # at this point, all shares in the sharemap are valid, and they're - # all for the same seqnum+root_hash version, so it's now down to - # doing FEC and decrypt. + # share is valid: remember it so we won't need to check + # (or extract) it again + self._valid_shares[shareinfo] = sharedata + + # the share is now in _valid_shares, so just copy over the + # sharedata + shares[shnum] = self._valid_shares[shareinfo] + + # now that the big loop is done, all shares in the sharemap are + # valid, and they're all for the same seqnum+root_hash version, so + # it's now down to doing FEC and decrypt. + assert len(shares) >= self._required_shares, len(shares) d = defer.maybeDeferred(self._decode, shares, segsize, datalength) d.addCallback(self._decrypt, IV, seqnum, root_hash) return d - def _validate_share_and_extract_data(self, root_hash, shnum, data): + def _validate_share_and_extract_data(self, peerid, root_hash, shnum, data): # 'data' is the whole SMDF share self.log("_validate_share_and_extract_data[%d]" % shnum) assert data[0] == "\x00" @@ -531,7 +596,7 @@ class Retrieve: leaves = [hashutil.block_hash(share_data)] t = hashtree.HashTree(leaves) if list(t) != block_hash_tree: - raise CorruptShareError("block hash tree failure") + raise CorruptShareError(peerid, shnum, "block hash tree failure") share_hash_leaf = t[0] # t2 = hashtree.IncompleteHashTree() # TODO: use shnum, share_hash_leaf, share_hash_chain to compare against @@ -553,6 +618,7 @@ class Retrieve: shareids.append(shareid) shares.append(share) + assert len(shareids) >= self._required_shares, len(shareids) # zfec really doesn't want extra shares shareids = shareids[:self._required_shares] shares = shares[:self._required_shares] @@ -650,7 +716,7 @@ class Publish: # 4a: may need to run recovery algorithm # 5: when enough responses are back, we're done - self.log("got enough peers") + self.log("starting publish, data is %r" % (newdata,)) self._storage_index = self._node.get_storage_index() self._writekey = self._node.get_writekey() @@ -747,9 +813,12 @@ class Publish: def _got_query_results(self, datavs, peerid, permutedid, reachable_peers, current_share_peers): + self.log("_got_query_results") + assert isinstance(datavs, dict) reachable_peers[peerid] = permutedid for shnum, datav in datavs.items(): + self.log(" peer has shnum %d" % shnum) assert len(datav) == 1 data = datav[0] # We want (seqnum, root_hash, IV) from all servers to know what @@ -999,29 +1068,14 @@ class Publish: final_shares = {} for shnum in range(total_shares): - shc = share_hash_chain[shnum] - share_hash_chain_s = "".join([struct.pack(">H32s", i, shc[i]) - for i in sorted(shc.keys())]) - bht = block_hash_trees[shnum] - for h in bht: - assert len(h) == 32 - block_hash_tree_s = "".join(bht) - share_data = all_shares[shnum] - offsets = pack_offsets(len(verification_key), - len(signature), - len(share_hash_chain_s), - len(block_hash_tree_s), - len(share_data), - len(encprivkey)) - - final_shares[shnum] = "".join([prefix, - offsets, - verification_key, - signature, - share_hash_chain_s, - block_hash_tree_s, - share_data, - encprivkey]) + final_share = pack_share(prefix, + verification_key, + signature, + share_hash_chain[shnum], + block_hash_trees[shnum], + all_shares[shnum], + encprivkey) + final_shares[shnum] = final_share return (seqnum, root_hash, final_shares, target_info) diff --git a/src/allmydata/scripts/debug.py b/src/allmydata/scripts/debug.py index 9c34d91a0..aa3027993 100644 --- a/src/allmydata/scripts/debug.py +++ b/src/allmydata/scripts/debug.py @@ -179,7 +179,8 @@ def dump_SDMF_share(offset, length, config, out, err): print >>out, " total_shares: %d" % N print >>out, " segsize: %d" % segsize print >>out, " datalen: %d" % datalen - share_hash_ids = ",".join([str(hid) for (hid,hash) in share_hash_chain]) + share_hash_ids = ",".join(sorted([str(hid) + for hid in share_hash_chain.keys()])) print >>out, " share_hash_chain: %s" % share_hash_ids print >>out, " block_hash_tree: %d nodes" % len(block_hash_tree) diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py index dccdaa0dc..ca4e7ae6b 100644 --- a/src/allmydata/test/test_mutable.py +++ b/src/allmydata/test/test_mutable.py @@ -264,13 +264,12 @@ class Publish(unittest.TestCase): k, N, segsize, datalen) self.failUnlessEqual(signature, FakePrivKey(0).sign(sig_material)) - self.failUnless(isinstance(share_hash_chain, list)) + self.failUnless(isinstance(share_hash_chain, dict)) self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++ - for i in share_hash_chain: - self.failUnless(isinstance(i, tuple)) - self.failUnless(isinstance(i[0], int)) - self.failUnless(isinstance(i[1], str)) - self.failUnlessEqual(len(i[1]), 32) + for shnum,share_hash in share_hash_chain.items(): + self.failUnless(isinstance(shnum, int)) + self.failUnless(isinstance(share_hash, str)) + self.failUnlessEqual(len(share_hash), 32) self.failUnless(isinstance(block_hash_tree, list)) self.failUnlessEqual(len(block_hash_tree), 1) # very small tree self.failUnlessEqual(IV, "IV"*8) diff --git a/src/allmydata/test/test_system.py b/src/allmydata/test/test_system.py index 7f56c4118..d404f9ba7 100644 --- a/src/allmydata/test/test_system.py +++ b/src/allmydata/test/test_system.py @@ -6,7 +6,7 @@ from twisted.trial import unittest from twisted.internet import defer, reactor from twisted.internet import threads # CLI tests use deferToThread from twisted.application import service -from allmydata import client, uri, download, upload +from allmydata import client, uri, download, upload, storage, mutable from allmydata.introducer import IntroducerNode from allmydata.util import deferredutil, fileutil, idlib, mathutil, testutil from allmydata.scripts import runner @@ -237,6 +237,72 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase): return d test_upload_and_download.timeout = 4800 + def _find_shares(self, basedir): + shares = [] + for (dirpath, dirnames, filenames) in os.walk(basedir): + if "storage" not in dirpath: + continue + if not filenames: + continue + pieces = dirpath.split(os.sep) + if pieces[-3] == "storage" and pieces[-2] == "shares": + # we're sitting in .../storage/shares/$SINDEX , and there + # are sharefiles here + assert pieces[-4].startswith("client") + client_num = int(pieces[-4][-1]) + storage_index_s = pieces[-1] + storage_index = idlib.a2b(storage_index_s) + for sharename in filenames: + shnum = int(sharename) + filename = os.path.join(dirpath, sharename) + data = (client_num, storage_index, filename, shnum) + shares.append(data) + if not shares: + self.fail("unable to find any share files in %s" % basedir) + return shares + + def _corrupt_mutable_share(self, filename, which): + msf = storage.MutableShareFile(filename) + datav = msf.readv([ (0, 1000000) ]) + final_share = datav[0] + assert len(final_share) < 1000000 # ought to be truncated + pieces = mutable.unpack_share(final_share) + (seqnum, root_hash, IV, k, N, segsize, datalen, + verification_key, signature, share_hash_chain, block_hash_tree, + share_data, enc_privkey) = pieces + + if which == "seqnum": + seqnum = seqnum + 15 + elif which == "R": + root_hash = self.flip_bit(root_hash) + elif which == "IV": + IV = self.flip_bit(IV) + elif which == "segsize": + segsize = segsize + 15 + elif which == "pubkey": + verification_key = self.flip_bit(verification_key) + elif which == "signature": + signature = self.flip_bit(signature) + elif which == "share_hash_chain": + nodenum = share_hash_chain.keys()[0] + share_hash_chain[nodenum] = self.flip_bit(share_hash_chain[nodenum]) + elif which == "block_hash_tree": + block_hash_tree[-1] = self.flip_bit(block_hash_tree[-1]) + elif which == "share_data": + share_data = self.flip_bit(share_data) + elif which == "encprivkey": + enc_privkey = self.flip_bit(enc_privkey) + + prefix = mutable.pack_prefix(seqnum, root_hash, IV, k, N, + segsize, datalen) + final_share = mutable.pack_share(prefix, + verification_key, + signature, + share_hash_chain, + block_hash_tree, + share_data, + enc_privkey) + msf.writev( [(0, final_share)], None) def test_mutable(self): self.basedir = "system/SystemTest/test_mutable" @@ -260,22 +326,8 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase): def _test_debug(res): # find a share. It is important to run this while there is only # one slot in the grid. - for (dirpath, dirnames, filenames) in os.walk(self.basedir): - if "storage" not in dirpath: - continue - if not filenames: - continue - pieces = dirpath.split(os.sep) - if pieces[-3] == "storage" and pieces[-2] == "shares": - # we're sitting in .../storage/shares/$SINDEX , and there - # are sharefiles here - assert pieces[-4].startswith("client") - client_num = int(pieces[-4][-1]) - filename = os.path.join(dirpath, filenames[0]) - break - else: - self.fail("unable to find any share files in %s" - % self.basedir) + shares = self._find_shares(self.basedir) + (client_num, storage_index, filename, shnum) = shares[0] log.msg("test_system.SystemTest.test_mutable._test_debug using %s" % filename) log.msg(" for clients[%d]" % client_num) @@ -367,6 +419,7 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase): uri = self._mutable_node_1.get_uri() newnode1 = self.clients[2].create_node_from_uri(uri) newnode2 = self.clients[3].create_node_from_uri(uri) + self._newnode3 = self.clients[3].create_node_from_uri(uri) log.msg("starting replace2") d1 = newnode1.replace(NEWERDATA, wait_for_numpeers=self.numclients) d1.addCallback(lambda res: newnode2.download_to_data()) @@ -376,13 +429,13 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase): def _check_download_5(res): log.msg("finished replace2") self.failUnlessEqual(res, NEWERDATA) - # Make sure we can create empty files -- this can screw up the - # segsize math. - d1 = self.clients[2].create_mutable_file("", wait_for_numpeers=self.numclients) + # make sure we can create empty files, this usually screws up the + # segsize math + d1 = self.clients[2].create_mutable_file("") d1.addCallback(lambda newnode: newnode.download_to_data()) d1.addCallback(lambda res: self.failUnlessEqual("", res)) return d1 - d.addCallback(_check_download_5) + d.addCallback(_check_empty_file) d.addCallback(lambda res: self.clients[0].create_empty_dirnode(wait_for_numpeers=self.numclients)) def _created_dirnode(dnode): @@ -394,6 +447,9 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase): d1.addCallback(lambda res: dnode.set_node("see recursive", dnode, wait_for_numpeers=self.numclients)) d1.addCallback(lambda res: dnode.has_child("see recursive")) d1.addCallback(lambda answer: self.failUnlessEqual(answer, True)) + d1.addCallback(lambda res: dnode.build_manifest()) + d1.addCallback(lambda manifest: + self.failUnlessEqual(len(manifest), 1)) return d1 d.addCallback(_created_dirnode)