From 053109b28bef7184df24446fc6915da80c544a51 Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Thu, 7 Jun 2007 19:32:29 -0700 Subject: [PATCH] add tests for bad/inconsistent plaintext/crypttext merkle tree hashes --- src/allmydata/download.py | 47 ++++++-- src/allmydata/test/test_encode.py | 176 +++++++++++++++++++++++++++--- 2 files changed, 197 insertions(+), 26 deletions(-) diff --git a/src/allmydata/download.py b/src/allmydata/download.py index f4c6fbee3..79bd2e721 100644 --- a/src/allmydata/download.py +++ b/src/allmydata/download.py @@ -277,6 +277,13 @@ class FileDownloader: self._thingA_data = None + self._fetch_failures = {"thingA": 0, + "plaintext_hashroot": 0, + "plaintext_hashtree": 0, + "crypttext_hashroot": 0, + "crypttext_hashtree": 0, + } + def start(self): log.msg("starting download [%s]" % idlib.b2a(self._storage_index)) @@ -345,6 +352,7 @@ class FileDownloader: def _validate(proposal, bucket): h = hashtree.thingA_hash(proposal) if h != self._thingA_hash: + self._fetch_failures["thingA"] += 1 msg = ("The copy of thingA we received from %s was bad" % bucket) raise BadThingAHashValue(msg) @@ -357,14 +365,17 @@ class FileDownloader: def _obtain_validated_thing(self, ignored, sources, name, methname, args, validatorfunc): if not sources: - raise NotEnoughPeersError("ran out of peers while fetching %s" % - name) + raise NotEnoughPeersError("started with zero peers while fetching " + "%s" % name) bucket = sources[0] sources = sources[1:] d = bucket.callRemote(methname, *args) d.addCallback(validatorfunc, bucket) def _bad(f): log.msg("%s from vbucket %s failed: %s" % (name, bucket, f)) # WEIRD + if not sources: + raise NotEnoughPeersError("ran out of peers, last error was %s" + % (f,)) # try again with a different one return self._obtain_validated_thing(None, sources, name, methname, args, validatorfunc) @@ -402,10 +413,20 @@ class FileDownloader: def _get_plaintext_hashtrees(self): def _validate_plaintext_hashtree(proposal, bucket): if proposal[0] != self._thingA_data['plaintext_root_hash']: + self._fetch_failures["plaintext_hashroot"] += 1 msg = ("The copy of the plaintext_root_hash we received from" " %s was bad" % bucket) raise BadPlaintextHashValue(msg) - self._plaintext_hashes = proposal + pt_hashtree = hashtree.IncompleteHashTree(self._total_segments) + pt_hashes = dict(list(enumerate(proposal))) + try: + pt_hashtree.set_hashes(pt_hashes) + except hashtree.BadHashError: + # the hashes they gave us were not self-consistent, even + # though the root matched what we saw in the thingA block + self._fetch_failures["plaintext_hashtree"] += 1 + raise + self._plaintext_hashtree = pt_hashtree d = self._obtain_validated_thing(None, self._thingA_sources, "plaintext_hashes", @@ -416,10 +437,19 @@ class FileDownloader: def _get_crypttext_hashtrees(self, res): def _validate_crypttext_hashtree(proposal, bucket): if proposal[0] != self._thingA_data['crypttext_root_hash']: + self._fetch_failures["crypttext_hashroot"] += 1 msg = ("The copy of the crypttext_root_hash we received from" " %s was bad" % bucket) raise BadCrypttextHashValue(msg) - self._crypttext_hashes = proposal + ct_hashtree = hashtree.IncompleteHashTree(self._total_segments) + ct_hashes = dict(list(enumerate(proposal))) + try: + ct_hashtree.set_hashes(ct_hashes) + except hashtree.BadHashError: + self._fetch_failures["crypttext_hashtree"] += 1 + raise + ct_hashtree.set_hashes(ct_hashes) + self._crypttext_hashtree = ct_hashtree d = self._obtain_validated_thing(None, self._thingA_sources, "crypttext_hashes", @@ -428,13 +458,8 @@ class FileDownloader: return d def _setup_hashtrees(self, res): - plaintext_hashtree = hashtree.IncompleteHashTree(self._total_segments) - plaintext_hashes = dict(list(enumerate(self._plaintext_hashes))) - plaintext_hashtree.set_hashes(plaintext_hashes) - crypttext_hashtree = hashtree.IncompleteHashTree(self._total_segments) - crypttext_hashes = dict(list(enumerate(self._crypttext_hashes))) - crypttext_hashtree.set_hashes(crypttext_hashes) - self._output.setup_hashtrees(plaintext_hashtree, crypttext_hashtree) + self._output.setup_hashtrees(self._plaintext_hashtree, + self._crypttext_hashtree) def _create_validated_buckets(self, ignored=None): diff --git a/src/allmydata/test/test_encode.py b/src/allmydata/test/test_encode.py index d108b9bfa..b39751aaa 100644 --- a/src/allmydata/test/test_encode.py +++ b/src/allmydata/test/test_encode.py @@ -96,7 +96,7 @@ class FakeBucketWriter: assert not self.closed self.closed = True - def flip_bit(self, good): + def flip_bit(self, good): # flips the last bit return good[:-1] + chr(ord(good[-1]) ^ 0x01) def get_block(self, blocknum): @@ -106,17 +106,20 @@ class FakeBucketWriter: return self.blocks[blocknum] def get_plaintext_hashes(self): - if self.mode == "bad plaintexthash": - hashes = self.plaintext_hashes[:] + hashes = self.plaintext_hashes[:] + if self.mode == "bad plaintext hashroot": + hashes[0] = self.flip_bit(hashes[0]) + if self.mode == "bad plaintext hash": hashes[1] = self.flip_bit(hashes[1]) - return hashes - return self.plaintext_hashes + return hashes + def get_crypttext_hashes(self): - if self.mode == "bad crypttexthash": - hashes = self.crypttext_hashes[:] + hashes = self.crypttext_hashes[:] + if self.mode == "bad crypttext hashroot": + hashes[0] = self.flip_bit(hashes[0]) + if self.mode == "bad crypttext hash": hashes[1] = self.flip_bit(hashes[1]) - return hashes - return self.crypttext_hashes + return hashes def get_block_hashes(self): if self.mode == "bad blockhash": @@ -136,6 +139,11 @@ class FakeBucketWriter: return [] return self.share_hashes + def get_thingA(self): + if self.mode == "bad thingA": + return self.flip_bit(self.thingA) + return self.thingA + def make_data(length): data = "happy happy joy joy" * 100 @@ -250,6 +258,7 @@ class Roundtrip(unittest.TestCase): datalen=76, max_segment_size=25, bucket_modes={}, + recover_mode="recover", ): if AVAILABLE_SHARES is None: AVAILABLE_SHARES = k_and_happy_and_n[2] @@ -257,10 +266,16 @@ class Roundtrip(unittest.TestCase): d = self.send(k_and_happy_and_n, AVAILABLE_SHARES, max_segment_size, bucket_modes, data) # that fires with (thingA_hash, e, shareholders) - d.addCallback(self.recover, AVAILABLE_SHARES) + if recover_mode == "recover": + d.addCallback(self.recover, AVAILABLE_SHARES) + elif recover_mode == "thingA": + d.addCallback(self.recover_with_thingA, AVAILABLE_SHARES) + else: + raise RuntimeError, "unknown recover_mode '%s'" % recover_mode # that fires with newdata - def _downloaded(newdata): + def _downloaded((newdata, fd)): self.failUnless(newdata == data) + return fd d.addCallback(_downloaded) return d @@ -305,8 +320,17 @@ class Roundtrip(unittest.TestCase): client = None target = download.Data() fd = download.FileDownloader(client, URI, target) - fd.check_verifierid = False - fd.check_fileid = False + + # we manually cycle the FileDownloader through a number of steps that + # would normally be sequenced by a Deferred chain in + # FileDownloader.start(), to give us more control over the process. + # In particular, by bypassing _get_all_shareholders, we skip + # permuted-peerlist selection. + for shnum, bucket in shareholders.items(): + if shnum < AVAILABLE_SHARES and bucket.closed: + fd.add_share_bucket(shnum, bucket) + fd._got_all_shareholders(None) + # grab a copy of thingA from one of the shareholders thingA = shareholders[0].thingA thingA_data = bencode.bdecode(thingA) @@ -321,13 +345,64 @@ class Roundtrip(unittest.TestCase): 'total_shares': e.num_shares, } fd._got_thingA(thingA_data) + # we skip _get_hashtrees here, and the lack of hashtree attributes + # will cause the download.Output object to skip the + # plaintext/crypttext merkle tree checks. We instruct the downloader + # to skip the full-file checks as well. + fd.check_verifierid = False + fd.check_fileid = False + + fd._create_validated_buckets(None) + d = fd._download_all_segments(None) + d.addCallback(fd._done) + def _done(newdata): + return (newdata, fd) + d.addCallback(_done) + return d + + def recover_with_thingA(self, (thingA_hash, e, shareholders), + AVAILABLE_SHARES): + URI = pack_uri(storage_index="S" * 20, + key=e.key, + thingA_hash=thingA_hash, + needed_shares=e.required_shares, + total_shares=e.num_shares, + size=e.file_size) + client = None + target = download.Data() + fd = download.FileDownloader(client, URI, target) + + # we manually cycle the FileDownloader through a number of steps that + # would normally be sequenced by a Deferred chain in + # FileDownloader.start(), to give us more control over the process. + # In particular, by bypassing _get_all_shareholders, we skip + # permuted-peerlist selection. for shnum, bucket in shareholders.items(): if shnum < AVAILABLE_SHARES and bucket.closed: fd.add_share_bucket(shnum, bucket) fd._got_all_shareholders(None) - fd._create_validated_buckets(None) - d = fd._download_all_segments(None) + + # ask shareholders for thingA as usual, validating the responses. + # Arrange for shareholders[0] to be the first, so we can selectively + # corrupt the data it returns. + fd._thingA_sources = shareholders.values() + fd._thingA_sources.remove(shareholders[0]) + fd._thingA_sources.insert(0, shareholders[0]) + # the thingA block contains plaintext/crypttext hash trees, but does + # not have a fileid or verifierid, so we have to disable those checks + fd.check_verifierid = False + fd.check_fileid = False + + d = fd._obtain_thingA(None) + d.addCallback(fd._got_thingA) + + d.addCallback(fd._get_hashtrees) + d.addCallback(fd._create_validated_buckets) + d.addCallback(fd._download_all_segments) d.addCallback(fd._done) + def _done(newdata): + return (newdata, fd) + d.addCallback(_done) return d def test_not_enough_shares(self): @@ -419,6 +494,77 @@ class Roundtrip(unittest.TestCase): for i in range(6, 10)]) return self.send_and_recover((4,8,10), bucket_modes=modemap) + def assertFetchFailureIn(self, fd, where): + expected = {"thingA": 0, + "plaintext_hashroot": 0, + "plaintext_hashtree": 0, + "crypttext_hashroot": 0, + "crypttext_hashtree": 0, + } + if where is not None: + expected[where] += 1 + self.failUnlessEqual(fd._fetch_failures, expected) + + def test_good_thingA(self): + # exercise recover_mode="thingA", just to make sure the test works + modemap = dict([(i, "good") for i in range(1)] + + [(i, "good") for i in range(1, 10)]) + d = self.send_and_recover((4,8,10), bucket_modes=modemap, + recover_mode="thingA") + d.addCallback(self.assertFetchFailureIn, None) + return d + + def test_bad_thingA(self): + # the first server has a bad thingA block, so we will fail over to a + # different server. + modemap = dict([(i, "bad thingA") for i in range(1)] + + [(i, "good") for i in range(1, 10)]) + d = self.send_and_recover((4,8,10), bucket_modes=modemap, + recover_mode="thingA") + d.addCallback(self.assertFetchFailureIn, "thingA") + return d + + def test_bad_plaintext_hashroot(self): + # the first server has a bad plaintext hashroot, so we will fail over + # to a different server. + modemap = dict([(i, "bad plaintext hashroot") for i in range(1)] + + [(i, "good") for i in range(1, 10)]) + d = self.send_and_recover((4,8,10), bucket_modes=modemap, + recover_mode="thingA") + d.addCallback(self.assertFetchFailureIn, "plaintext_hashroot") + return d + + def test_bad_crypttext_hashroot(self): + # the first server has a bad crypttext hashroot, so we will fail + # over to a different server. + modemap = dict([(i, "bad crypttext hashroot") for i in range(1)] + + [(i, "good") for i in range(1, 10)]) + d = self.send_and_recover((4,8,10), bucket_modes=modemap, + recover_mode="thingA") + d.addCallback(self.assertFetchFailureIn, "crypttext_hashroot") + return d + + def test_bad_plaintext_hashes(self): + # the first server has a bad plaintext hash block, so we will fail + # over to a different server. + modemap = dict([(i, "bad plaintext hash") for i in range(1)] + + [(i, "good") for i in range(1, 10)]) + d = self.send_and_recover((4,8,10), bucket_modes=modemap, + recover_mode="thingA") + d.addCallback(self.assertFetchFailureIn, "plaintext_hashtree") + return d + + def test_bad_crypttext_hashes(self): + # the first server has a bad crypttext hash block, so we will fail + # over to a different server. + modemap = dict([(i, "bad crypttext hash") for i in range(1)] + + [(i, "good") for i in range(1, 10)]) + d = self.send_and_recover((4,8,10), bucket_modes=modemap, + recover_mode="thingA") + d.addCallback(self.assertFetchFailureIn, "crypttext_hashtree") + return d + + def test_bad_sharehashes_failure(self): # the first 7 servers have bad block hashes, so the sharehash tree # will not validate, and the download will fail