add tests for bad/inconsistent plaintext/crypttext merkle tree hashes

This commit is contained in:
Brian Warner 2007-06-07 19:32:29 -07:00
parent 4f001bedb3
commit 053109b28b
2 changed files with 197 additions and 26 deletions

View File

@ -277,6 +277,13 @@ class FileDownloader:
self._thingA_data = None self._thingA_data = None
self._fetch_failures = {"thingA": 0,
"plaintext_hashroot": 0,
"plaintext_hashtree": 0,
"crypttext_hashroot": 0,
"crypttext_hashtree": 0,
}
def start(self): def start(self):
log.msg("starting download [%s]" % idlib.b2a(self._storage_index)) log.msg("starting download [%s]" % idlib.b2a(self._storage_index))
@ -345,6 +352,7 @@ class FileDownloader:
def _validate(proposal, bucket): def _validate(proposal, bucket):
h = hashtree.thingA_hash(proposal) h = hashtree.thingA_hash(proposal)
if h != self._thingA_hash: if h != self._thingA_hash:
self._fetch_failures["thingA"] += 1
msg = ("The copy of thingA we received from %s was bad" % msg = ("The copy of thingA we received from %s was bad" %
bucket) bucket)
raise BadThingAHashValue(msg) raise BadThingAHashValue(msg)
@ -357,14 +365,17 @@ class FileDownloader:
def _obtain_validated_thing(self, ignored, sources, name, methname, args, def _obtain_validated_thing(self, ignored, sources, name, methname, args,
validatorfunc): validatorfunc):
if not sources: if not sources:
raise NotEnoughPeersError("ran out of peers while fetching %s" % raise NotEnoughPeersError("started with zero peers while fetching "
name) "%s" % name)
bucket = sources[0] bucket = sources[0]
sources = sources[1:] sources = sources[1:]
d = bucket.callRemote(methname, *args) d = bucket.callRemote(methname, *args)
d.addCallback(validatorfunc, bucket) d.addCallback(validatorfunc, bucket)
def _bad(f): def _bad(f):
log.msg("%s from vbucket %s failed: %s" % (name, bucket, f)) # WEIRD log.msg("%s from vbucket %s failed: %s" % (name, bucket, f)) # WEIRD
if not sources:
raise NotEnoughPeersError("ran out of peers, last error was %s"
% (f,))
# try again with a different one # try again with a different one
return self._obtain_validated_thing(None, sources, name, return self._obtain_validated_thing(None, sources, name,
methname, args, validatorfunc) methname, args, validatorfunc)
@ -402,10 +413,20 @@ class FileDownloader:
def _get_plaintext_hashtrees(self): def _get_plaintext_hashtrees(self):
def _validate_plaintext_hashtree(proposal, bucket): def _validate_plaintext_hashtree(proposal, bucket):
if proposal[0] != self._thingA_data['plaintext_root_hash']: if proposal[0] != self._thingA_data['plaintext_root_hash']:
self._fetch_failures["plaintext_hashroot"] += 1
msg = ("The copy of the plaintext_root_hash we received from" msg = ("The copy of the plaintext_root_hash we received from"
" %s was bad" % bucket) " %s was bad" % bucket)
raise BadPlaintextHashValue(msg) raise BadPlaintextHashValue(msg)
self._plaintext_hashes = proposal pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
pt_hashes = dict(list(enumerate(proposal)))
try:
pt_hashtree.set_hashes(pt_hashes)
except hashtree.BadHashError:
# the hashes they gave us were not self-consistent, even
# though the root matched what we saw in the thingA block
self._fetch_failures["plaintext_hashtree"] += 1
raise
self._plaintext_hashtree = pt_hashtree
d = self._obtain_validated_thing(None, d = self._obtain_validated_thing(None,
self._thingA_sources, self._thingA_sources,
"plaintext_hashes", "plaintext_hashes",
@ -416,10 +437,19 @@ class FileDownloader:
def _get_crypttext_hashtrees(self, res): def _get_crypttext_hashtrees(self, res):
def _validate_crypttext_hashtree(proposal, bucket): def _validate_crypttext_hashtree(proposal, bucket):
if proposal[0] != self._thingA_data['crypttext_root_hash']: if proposal[0] != self._thingA_data['crypttext_root_hash']:
self._fetch_failures["crypttext_hashroot"] += 1
msg = ("The copy of the crypttext_root_hash we received from" msg = ("The copy of the crypttext_root_hash we received from"
" %s was bad" % bucket) " %s was bad" % bucket)
raise BadCrypttextHashValue(msg) raise BadCrypttextHashValue(msg)
self._crypttext_hashes = proposal ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
ct_hashes = dict(list(enumerate(proposal)))
try:
ct_hashtree.set_hashes(ct_hashes)
except hashtree.BadHashError:
self._fetch_failures["crypttext_hashtree"] += 1
raise
ct_hashtree.set_hashes(ct_hashes)
self._crypttext_hashtree = ct_hashtree
d = self._obtain_validated_thing(None, d = self._obtain_validated_thing(None,
self._thingA_sources, self._thingA_sources,
"crypttext_hashes", "crypttext_hashes",
@ -428,13 +458,8 @@ class FileDownloader:
return d return d
def _setup_hashtrees(self, res): def _setup_hashtrees(self, res):
plaintext_hashtree = hashtree.IncompleteHashTree(self._total_segments) self._output.setup_hashtrees(self._plaintext_hashtree,
plaintext_hashes = dict(list(enumerate(self._plaintext_hashes))) self._crypttext_hashtree)
plaintext_hashtree.set_hashes(plaintext_hashes)
crypttext_hashtree = hashtree.IncompleteHashTree(self._total_segments)
crypttext_hashes = dict(list(enumerate(self._crypttext_hashes)))
crypttext_hashtree.set_hashes(crypttext_hashes)
self._output.setup_hashtrees(plaintext_hashtree, crypttext_hashtree)
def _create_validated_buckets(self, ignored=None): def _create_validated_buckets(self, ignored=None):

View File

@ -96,7 +96,7 @@ class FakeBucketWriter:
assert not self.closed assert not self.closed
self.closed = True self.closed = True
def flip_bit(self, good): def flip_bit(self, good): # flips the last bit
return good[:-1] + chr(ord(good[-1]) ^ 0x01) return good[:-1] + chr(ord(good[-1]) ^ 0x01)
def get_block(self, blocknum): def get_block(self, blocknum):
@ -106,17 +106,20 @@ class FakeBucketWriter:
return self.blocks[blocknum] return self.blocks[blocknum]
def get_plaintext_hashes(self): def get_plaintext_hashes(self):
if self.mode == "bad plaintexthash": hashes = self.plaintext_hashes[:]
hashes = self.plaintext_hashes[:] if self.mode == "bad plaintext hashroot":
hashes[0] = self.flip_bit(hashes[0])
if self.mode == "bad plaintext hash":
hashes[1] = self.flip_bit(hashes[1]) hashes[1] = self.flip_bit(hashes[1])
return hashes return hashes
return self.plaintext_hashes
def get_crypttext_hashes(self): def get_crypttext_hashes(self):
if self.mode == "bad crypttexthash": hashes = self.crypttext_hashes[:]
hashes = self.crypttext_hashes[:] if self.mode == "bad crypttext hashroot":
hashes[0] = self.flip_bit(hashes[0])
if self.mode == "bad crypttext hash":
hashes[1] = self.flip_bit(hashes[1]) hashes[1] = self.flip_bit(hashes[1])
return hashes return hashes
return self.crypttext_hashes
def get_block_hashes(self): def get_block_hashes(self):
if self.mode == "bad blockhash": if self.mode == "bad blockhash":
@ -136,6 +139,11 @@ class FakeBucketWriter:
return [] return []
return self.share_hashes return self.share_hashes
def get_thingA(self):
if self.mode == "bad thingA":
return self.flip_bit(self.thingA)
return self.thingA
def make_data(length): def make_data(length):
data = "happy happy joy joy" * 100 data = "happy happy joy joy" * 100
@ -250,6 +258,7 @@ class Roundtrip(unittest.TestCase):
datalen=76, datalen=76,
max_segment_size=25, max_segment_size=25,
bucket_modes={}, bucket_modes={},
recover_mode="recover",
): ):
if AVAILABLE_SHARES is None: if AVAILABLE_SHARES is None:
AVAILABLE_SHARES = k_and_happy_and_n[2] AVAILABLE_SHARES = k_and_happy_and_n[2]
@ -257,10 +266,16 @@ class Roundtrip(unittest.TestCase):
d = self.send(k_and_happy_and_n, AVAILABLE_SHARES, d = self.send(k_and_happy_and_n, AVAILABLE_SHARES,
max_segment_size, bucket_modes, data) max_segment_size, bucket_modes, data)
# that fires with (thingA_hash, e, shareholders) # that fires with (thingA_hash, e, shareholders)
d.addCallback(self.recover, AVAILABLE_SHARES) if recover_mode == "recover":
d.addCallback(self.recover, AVAILABLE_SHARES)
elif recover_mode == "thingA":
d.addCallback(self.recover_with_thingA, AVAILABLE_SHARES)
else:
raise RuntimeError, "unknown recover_mode '%s'" % recover_mode
# that fires with newdata # that fires with newdata
def _downloaded(newdata): def _downloaded((newdata, fd)):
self.failUnless(newdata == data) self.failUnless(newdata == data)
return fd
d.addCallback(_downloaded) d.addCallback(_downloaded)
return d return d
@ -305,8 +320,17 @@ class Roundtrip(unittest.TestCase):
client = None client = None
target = download.Data() target = download.Data()
fd = download.FileDownloader(client, URI, target) fd = download.FileDownloader(client, URI, target)
fd.check_verifierid = False
fd.check_fileid = False # we manually cycle the FileDownloader through a number of steps that
# would normally be sequenced by a Deferred chain in
# FileDownloader.start(), to give us more control over the process.
# In particular, by bypassing _get_all_shareholders, we skip
# permuted-peerlist selection.
for shnum, bucket in shareholders.items():
if shnum < AVAILABLE_SHARES and bucket.closed:
fd.add_share_bucket(shnum, bucket)
fd._got_all_shareholders(None)
# grab a copy of thingA from one of the shareholders # grab a copy of thingA from one of the shareholders
thingA = shareholders[0].thingA thingA = shareholders[0].thingA
thingA_data = bencode.bdecode(thingA) thingA_data = bencode.bdecode(thingA)
@ -321,13 +345,64 @@ class Roundtrip(unittest.TestCase):
'total_shares': e.num_shares, 'total_shares': e.num_shares,
} }
fd._got_thingA(thingA_data) fd._got_thingA(thingA_data)
# we skip _get_hashtrees here, and the lack of hashtree attributes
# will cause the download.Output object to skip the
# plaintext/crypttext merkle tree checks. We instruct the downloader
# to skip the full-file checks as well.
fd.check_verifierid = False
fd.check_fileid = False
fd._create_validated_buckets(None)
d = fd._download_all_segments(None)
d.addCallback(fd._done)
def _done(newdata):
return (newdata, fd)
d.addCallback(_done)
return d
def recover_with_thingA(self, (thingA_hash, e, shareholders),
AVAILABLE_SHARES):
URI = pack_uri(storage_index="S" * 20,
key=e.key,
thingA_hash=thingA_hash,
needed_shares=e.required_shares,
total_shares=e.num_shares,
size=e.file_size)
client = None
target = download.Data()
fd = download.FileDownloader(client, URI, target)
# we manually cycle the FileDownloader through a number of steps that
# would normally be sequenced by a Deferred chain in
# FileDownloader.start(), to give us more control over the process.
# In particular, by bypassing _get_all_shareholders, we skip
# permuted-peerlist selection.
for shnum, bucket in shareholders.items(): for shnum, bucket in shareholders.items():
if shnum < AVAILABLE_SHARES and bucket.closed: if shnum < AVAILABLE_SHARES and bucket.closed:
fd.add_share_bucket(shnum, bucket) fd.add_share_bucket(shnum, bucket)
fd._got_all_shareholders(None) fd._got_all_shareholders(None)
fd._create_validated_buckets(None)
d = fd._download_all_segments(None) # ask shareholders for thingA as usual, validating the responses.
# Arrange for shareholders[0] to be the first, so we can selectively
# corrupt the data it returns.
fd._thingA_sources = shareholders.values()
fd._thingA_sources.remove(shareholders[0])
fd._thingA_sources.insert(0, shareholders[0])
# the thingA block contains plaintext/crypttext hash trees, but does
# not have a fileid or verifierid, so we have to disable those checks
fd.check_verifierid = False
fd.check_fileid = False
d = fd._obtain_thingA(None)
d.addCallback(fd._got_thingA)
d.addCallback(fd._get_hashtrees)
d.addCallback(fd._create_validated_buckets)
d.addCallback(fd._download_all_segments)
d.addCallback(fd._done) d.addCallback(fd._done)
def _done(newdata):
return (newdata, fd)
d.addCallback(_done)
return d return d
def test_not_enough_shares(self): def test_not_enough_shares(self):
@ -419,6 +494,77 @@ class Roundtrip(unittest.TestCase):
for i in range(6, 10)]) for i in range(6, 10)])
return self.send_and_recover((4,8,10), bucket_modes=modemap) return self.send_and_recover((4,8,10), bucket_modes=modemap)
def assertFetchFailureIn(self, fd, where):
expected = {"thingA": 0,
"plaintext_hashroot": 0,
"plaintext_hashtree": 0,
"crypttext_hashroot": 0,
"crypttext_hashtree": 0,
}
if where is not None:
expected[where] += 1
self.failUnlessEqual(fd._fetch_failures, expected)
def test_good_thingA(self):
# exercise recover_mode="thingA", just to make sure the test works
modemap = dict([(i, "good") for i in range(1)] +
[(i, "good") for i in range(1, 10)])
d = self.send_and_recover((4,8,10), bucket_modes=modemap,
recover_mode="thingA")
d.addCallback(self.assertFetchFailureIn, None)
return d
def test_bad_thingA(self):
# the first server has a bad thingA block, so we will fail over to a
# different server.
modemap = dict([(i, "bad thingA") for i in range(1)] +
[(i, "good") for i in range(1, 10)])
d = self.send_and_recover((4,8,10), bucket_modes=modemap,
recover_mode="thingA")
d.addCallback(self.assertFetchFailureIn, "thingA")
return d
def test_bad_plaintext_hashroot(self):
# the first server has a bad plaintext hashroot, so we will fail over
# to a different server.
modemap = dict([(i, "bad plaintext hashroot") for i in range(1)] +
[(i, "good") for i in range(1, 10)])
d = self.send_and_recover((4,8,10), bucket_modes=modemap,
recover_mode="thingA")
d.addCallback(self.assertFetchFailureIn, "plaintext_hashroot")
return d
def test_bad_crypttext_hashroot(self):
# the first server has a bad crypttext hashroot, so we will fail
# over to a different server.
modemap = dict([(i, "bad crypttext hashroot") for i in range(1)] +
[(i, "good") for i in range(1, 10)])
d = self.send_and_recover((4,8,10), bucket_modes=modemap,
recover_mode="thingA")
d.addCallback(self.assertFetchFailureIn, "crypttext_hashroot")
return d
def test_bad_plaintext_hashes(self):
# the first server has a bad plaintext hash block, so we will fail
# over to a different server.
modemap = dict([(i, "bad plaintext hash") for i in range(1)] +
[(i, "good") for i in range(1, 10)])
d = self.send_and_recover((4,8,10), bucket_modes=modemap,
recover_mode="thingA")
d.addCallback(self.assertFetchFailureIn, "plaintext_hashtree")
return d
def test_bad_crypttext_hashes(self):
# the first server has a bad crypttext hash block, so we will fail
# over to a different server.
modemap = dict([(i, "bad crypttext hash") for i in range(1)] +
[(i, "good") for i in range(1, 10)])
d = self.send_and_recover((4,8,10), bucket_modes=modemap,
recover_mode="thingA")
d.addCallback(self.assertFetchFailureIn, "crypttext_hashtree")
return d
def test_bad_sharehashes_failure(self): def test_bad_sharehashes_failure(self):
# the first 7 servers have bad block hashes, so the sharehash tree # the first 7 servers have bad block hashes, so the sharehash tree
# will not validate, and the download will fail # will not validate, and the download will fail