test_encode.py: even more testing of merkle trees, getting fairly comprehensive now

Brian Warner 2007-06-07 21:24:39 -07:00
parent 053109b28b
commit cabba59fe7
2 changed files with 102 additions and 99 deletions

encode.py

@@ -122,6 +122,8 @@ class Encoder(object):
         data['size'] = self.file_size
         data['segment_size'] = self.segment_size
+        data['num_segments'] = mathutil.div_ceil(self.file_size,
+                                                 self.segment_size)
         data['needed_shares'] = self.required_shares
         data['total_shares'] = self.num_shares
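
The new num_segments field is just the ceiling division of the file size by the segment size. A minimal sketch of the arithmetic, assuming mathutil.div_ceil has the usual ceiling-division semantics:

    def div_ceil(n, d):
        # smallest integer >= n/d
        return (n + d - 1) // d

    # e.g. a 10000-byte file with 4096-byte segments occupies 3 segments
    assert div_ceil(10000, 4096) == 3
    assert div_ceil(8192, 4096) == 2
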

test_encode.py

@@ -1,14 +1,18 @@
 #! /usr/bin/env python
 
 from twisted.trial import unittest
 from twisted.internet import defer
 from twisted.python.failure import Failure
 from foolscap import eventual
-from allmydata import encode, download
-from allmydata.util import bencode
+from allmydata import encode, download, hashtree
+from allmydata.util import hashutil
 from allmydata.uri import pack_uri
+from allmydata.Crypto.Cipher import AES
+import sha
 from cStringIO import StringIO
 
+def netstring(s):
+    return "%d:%s," % (len(s), s)
+
 class FakePeer:
     def __init__(self, mode="good"):
         self.ss = FakeStorageServer(mode)
@@ -44,6 +48,9 @@ class FakeStorageServer:
 class LostPeerError(Exception):
     pass
 
+def flip_bit(good): # flips the last bit
+    return good[:-1] + chr(ord(good[-1]) ^ 0x01)
+
 class FakeBucketWriter:
     # these are used for both reading and writing
     def __init__(self, mode="good"):
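
flip_bit is the corruption primitive used by every "bad ..." bucket mode: it inverts the lowest bit of the final byte, producing a value that differs from the original by exactly one bit. A quick illustration (not part of the diff):

    >>> flip_bit("abc")        # 'c' is 0x63; 0x63 ^ 0x01 == 0x62 == 'b'
    'abb'
    >>> flip_bit(flip_bit("abc")) == "abc"   # flipping twice restores the input
    True
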
@@ -96,41 +103,38 @@ class FakeBucketWriter:
         assert not self.closed
         self.closed = True
 
-    def flip_bit(self, good): # flips the last bit
-        return good[:-1] + chr(ord(good[-1]) ^ 0x01)
-
     def get_block(self, blocknum):
         assert isinstance(blocknum, (int, long))
         if self.mode == "bad block":
-            return self.flip_bit(self.blocks[blocknum])
+            return flip_bit(self.blocks[blocknum])
         return self.blocks[blocknum]
 
     def get_plaintext_hashes(self):
         hashes = self.plaintext_hashes[:]
         if self.mode == "bad plaintext hashroot":
-            hashes[0] = self.flip_bit(hashes[0])
+            hashes[0] = flip_bit(hashes[0])
         if self.mode == "bad plaintext hash":
-            hashes[1] = self.flip_bit(hashes[1])
+            hashes[1] = flip_bit(hashes[1])
         return hashes
 
     def get_crypttext_hashes(self):
         hashes = self.crypttext_hashes[:]
         if self.mode == "bad crypttext hashroot":
-            hashes[0] = self.flip_bit(hashes[0])
+            hashes[0] = flip_bit(hashes[0])
         if self.mode == "bad crypttext hash":
-            hashes[1] = self.flip_bit(hashes[1])
+            hashes[1] = flip_bit(hashes[1])
         return hashes
 
     def get_block_hashes(self):
         if self.mode == "bad blockhash":
            hashes = self.block_hashes[:]
-           hashes[1] = self.flip_bit(hashes[1])
+           hashes[1] = flip_bit(hashes[1])
            return hashes
         return self.block_hashes
 
     def get_share_hashes(self):
         if self.mode == "bad sharehash":
            hashes = self.share_hashes[:]
-           hashes[1] = (hashes[1][0], self.flip_bit(hashes[1][1]))
+           hashes[1] = (hashes[1][0], flip_bit(hashes[1][1]))
            return hashes
         if self.mode == "missing sharehash":
            # one sneaky attack would be to pretend we don't know our own
@@ -141,7 +145,7 @@ class FakeBucketWriter:
     def get_thingA(self):
         if self.mode == "bad thingA":
-            return self.flip_bit(self.thingA)
+            return flip_bit(self.thingA)
         return self.thingA
@@ -266,12 +270,7 @@ class Roundtrip(unittest.TestCase):
         d = self.send(k_and_happy_and_n, AVAILABLE_SHARES,
                       max_segment_size, bucket_modes, data)
         # that fires with (thingA_hash, e, shareholders)
-        if recover_mode == "recover":
-            d.addCallback(self.recover, AVAILABLE_SHARES)
-        elif recover_mode == "thingA":
-            d.addCallback(self.recover_with_thingA, AVAILABLE_SHARES)
-        else:
-            raise RuntimeError, "unknown recover_mode '%s'" % recover_mode
+        d.addCallback(self.recover, AVAILABLE_SHARES, recover_mode)
         # that fires with newdata
         def _downloaded((newdata, fd)):
             self.failUnless(newdata == data)
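
The dispatch on recover_mode moves out of send_and_recover: extra positional arguments to addCallback are handed to the callback after the Deferred's result, so self.recover receives (result, AVAILABLE_SHARES, recover_mode) and does its own mode checks. A minimal sketch of that Twisted idiom, with hypothetical names:

    from twisted.internet import defer

    def recover(result, available_shares, recover_mode):
        # 'result' is whatever the previous callback returned;
        # the extra addCallback() arguments arrive after it
        return (result, available_shares, recover_mode)

    def check(out):
        assert out == ("send-result", 8, "corrupt_key")

    d = defer.succeed("send-result")
    d.addCallback(recover, 8, "corrupt_key")
    d.addCallback(check)
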
@@ -301,8 +300,15 @@ class Roundtrip(unittest.TestCase):
             peer = FakeBucketWriter(mode)
             shareholders[shnum] = peer
         e.set_shareholders(shareholders)
-        e.set_thingA_data({'verifierid': "V" * 20,
-                           'fileid': "F" * 20,
+        fileid_hasher = sha.new(netstring("allmydata_fileid_v1"))
+        fileid_hasher.update(data)
+        cryptor = AES.new(key=nonkey, mode=AES.MODE_CTR,
+                          counterstart="\x00"*16)
+        verifierid_hasher = sha.new(netstring("allmydata_verifierid_v1"))
+        verifierid_hasher.update(cryptor.encrypt(data))
+        e.set_thingA_data({'verifierid': verifierid_hasher.digest(),
+                           'fileid': fileid_hasher.digest(),
                           })
         d = e.start()
         def _sent(thingA_hash):
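
Instead of the placeholder "V"*20 / "F"*20 values, the test now computes real tags: fileid is the SHA-1 of the plaintext and verifierid is the SHA-1 of the AES-CTR ciphertext, each hash prefixed with a netstring-encoded tag so hashes computed for different purposes stay domain-separated. A sketch of the tagging scheme (the AES step is replaced by a stand-in transform here, since only the hashing structure is being illustrated):

    import sha   # Python 2 stdlib; hashlib.sha1 in later versions

    def netstring(s):
        return "%d:%s," % (len(s), s)

    def tagged_sha(tag, data):
        h = sha.new(netstring(tag))
        h.update(data)
        return h.digest()

    data = "some plaintext"
    ciphertext = data[::-1]   # stand-in for AES-CTR encryption of 'data'

    fileid = tagged_sha("allmydata_fileid_v1", data)
    verifierid = tagged_sha("allmydata_verifierid_v1", ciphertext)
    assert fileid != verifierid and len(fileid) == 20
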
@@ -310,9 +316,14 @@ class Roundtrip(unittest.TestCase):
         d.addCallback(_sent)
         return d
 
-    def recover(self, (thingA_hash, e, shareholders), AVAILABLE_SHARES):
+    def recover(self, (thingA_hash, e, shareholders), AVAILABLE_SHARES,
+                recover_mode):
+        key = e.key
+        if "corrupt_key" in recover_mode:
+            key = flip_bit(key)
+
         URI = pack_uri(storage_index="S" * 20,
-                       key=e.key,
+                       key=key,
                        thingA_hash=thingA_hash,
                        needed_shares=e.required_shares,
                        total_shares=e.num_shares,
@@ -331,72 +342,39 @@ class Roundtrip(unittest.TestCase):
                 fd.add_share_bucket(shnum, bucket)
         fd._got_all_shareholders(None)
 
-        # grab a copy of thingA from one of the shareholders
-        thingA = shareholders[0].thingA
-        thingA_data = bencode.bdecode(thingA)
-        NOTthingA = {'codec_name': e._codec.get_encoder_type(),
-                     'codec_params': e._codec.get_serialized_params(),
-                     'tail_codec_params': e._tail_codec.get_serialized_params(),
-                     'verifierid': "V" * 20,
-                     'fileid': "F" * 20,
-                     #'share_root_hash': roothash,
-                     'segment_size': e.segment_size,
-                     'needed_shares': e.required_shares,
-                     'total_shares': e.num_shares,
-                     }
-        fd._got_thingA(thingA_data)
-        # we skip _get_hashtrees here, and the lack of hashtree attributes
-        # will cause the download.Output object to skip the
-        # plaintext/crypttext merkle tree checks. We instruct the downloader
-        # to skip the full-file checks as well.
-        fd.check_verifierid = False
-        fd.check_fileid = False
-        fd._create_validated_buckets(None)
-        d = fd._download_all_segments(None)
-        d.addCallback(fd._done)
-        def _done(newdata):
-            return (newdata, fd)
-        d.addCallback(_done)
-        return d
-
-    def recover_with_thingA(self, (thingA_hash, e, shareholders),
-                            AVAILABLE_SHARES):
-        URI = pack_uri(storage_index="S" * 20,
-                       key=e.key,
-                       thingA_hash=thingA_hash,
-                       needed_shares=e.required_shares,
-                       total_shares=e.num_shares,
-                       size=e.file_size)
-        client = None
-        target = download.Data()
-        fd = download.FileDownloader(client, URI, target)
-
-        # we manually cycle the FileDownloader through a number of steps that
-        # would normally be sequenced by a Deferred chain in
-        # FileDownloader.start(), to give us more control over the process.
-        # In particular, by bypassing _get_all_shareholders, we skip
-        # permuted-peerlist selection.
-        for shnum, bucket in shareholders.items():
-            if shnum < AVAILABLE_SHARES and bucket.closed:
-                fd.add_share_bucket(shnum, bucket)
-        fd._got_all_shareholders(None)
-
-        # ask shareholders for thingA as usual, validating the responses.
-        # Arrange for shareholders[0] to be the first, so we can selectively
-        # corrupt the data it returns.
+        # Make it possible to obtain thingA from the shareholders. Arrange
+        # for shareholders[0] to be the first, so we can selectively corrupt
+        # the data it returns.
         fd._thingA_sources = shareholders.values()
         fd._thingA_sources.remove(shareholders[0])
         fd._thingA_sources.insert(0, shareholders[0])
 
-        # the thingA block contains plaintext/crypttext hash trees, but does
-        # not have a fileid or verifierid, so we have to disable those checks
-        fd.check_verifierid = False
-        fd.check_fileid = False
-
-        d = fd._obtain_thingA(None)
+        d = defer.succeed(None)
+        # have the FileDownloader retrieve a copy of thingA itself
+        d.addCallback(fd._obtain_thingA)
+
+        if "corrupt_crypttext_hashes" in recover_mode:
+            # replace everybody's crypttext hash trees with a different one
+            # (computed over a different file), then modify our thingA to
+            # reflect the new crypttext hash tree root
+            def _corrupt_crypttext_hashes(thingA):
+                assert isinstance(thingA, dict)
+                assert 'crypttext_root_hash' in thingA
+                badhash = hashutil.tagged_hash("bogus", "data")
+                bad_crypttext_hashes = [badhash] * thingA['num_segments']
+                badtree = hashtree.HashTree(bad_crypttext_hashes)
+                for bucket in shareholders.values():
+                    bucket.crypttext_hashes = list(badtree)
+                thingA['crypttext_root_hash'] = badtree[0]
+                return thingA
+            d.addCallback(_corrupt_crypttext_hashes)
+
         d.addCallback(fd._got_thingA)
+        # also have the FileDownloader ask for hash trees
         d.addCallback(fd._get_hashtrees)
         d.addCallback(fd._create_validated_buckets)
         d.addCallback(fd._download_all_segments)
         d.addCallback(fd._done)
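
The corruption trick works because a Merkle tree commits to all of its leaves through a single root: rebuilding the tree over bogus leaf hashes produces a different root, and publishing that root in thingA makes the bad leaves internally consistent, so only the final comparison of the FEC output against the tree can expose the mismatch. A generic sketch of that structure (plain SHA-1 pairing, not allmydata.hashtree's exact construction):

    import sha

    def merkle_root(leaves):
        # pair up nodes and hash each pair until one root remains
        nodes = list(leaves)
        while len(nodes) > 1:
            if len(nodes) % 2:
                nodes.append(nodes[-1])   # duplicate the odd node out
            nodes = [sha.new(nodes[i] + nodes[i+1]).digest()
                     for i in range(0, len(nodes), 2)]
        return nodes[0]

    good_leaves = [sha.new("segment %d" % i).digest() for i in range(4)]
    bad_leaves = [sha.new("bogus").digest()] * 4

    # swapping in a tree over bogus leaves necessarily changes the root,
    # which is what _corrupt_crypttext_hashes publishes in thingA
    assert merkle_root(good_leaves) != merkle_root(bad_leaves)
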
@@ -505,12 +483,11 @@ class Roundtrip(unittest.TestCase):
             expected[where] += 1
         self.failUnlessEqual(fd._fetch_failures, expected)
 
-    def test_good_thingA(self):
-        # exercise recover_mode="thingA", just to make sure the test works
-        modemap = dict([(i, "good") for i in range(1)] +
-                       [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+    def test_good(self):
+        # just to make sure the test harness works when we aren't
+        # intentionally causing failures
+        modemap = dict([(i, "good") for i in range(0, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, None)
         return d
@@ -519,8 +496,7 @@ class Roundtrip(unittest.TestCase):
         # different server.
         modemap = dict([(i, "bad thingA") for i in range(1)] +
                        [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "thingA")
         return d
@@ -529,8 +505,7 @@ class Roundtrip(unittest.TestCase):
         # to a different server.
         modemap = dict([(i, "bad plaintext hashroot") for i in range(1)] +
                        [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "plaintext_hashroot")
         return d
@@ -539,8 +514,7 @@ class Roundtrip(unittest.TestCase):
         # over to a different server.
         modemap = dict([(i, "bad crypttext hashroot") for i in range(1)] +
                        [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "crypttext_hashroot")
         return d
@@ -549,8 +523,7 @@ class Roundtrip(unittest.TestCase):
         # over to a different server.
        modemap = dict([(i, "bad plaintext hash") for i in range(1)] +
                       [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "plaintext_hashtree")
         return d
@@ -559,11 +532,39 @@ class Roundtrip(unittest.TestCase):
         # over to a different server.
         modemap = dict([(i, "bad crypttext hash") for i in range(1)] +
                        [(i, "good") for i in range(1, 10)])
-        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
-                                  recover_mode="thingA")
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap)
         d.addCallback(self.assertFetchFailureIn, "crypttext_hashtree")
         return d
 
+    def test_bad_crypttext_hashes_failure(self):
+        # to test that the crypttext merkle tree is really being applied, we
+        # sneak into the download process and corrupt two things: we replace
+        # everybody's crypttext hashtree with a bad version (computed over
+        # bogus data), and we modify the supposedly-validated thingA block to
+        # match the new crypttext hashtree root. The download process should
+        # notice that the crypttext coming out of FEC doesn't match the tree,
+        # and fail.
+        modemap = dict([(i, "good") for i in range(0, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode=("corrupt_crypttext_hashes"))
+        def _done(res):
+            self.failUnless(isinstance(res, Failure))
+            self.failUnless(res.check(hashtree.BadHashError), res)
+        d.addBoth(_done)
+        return d
+
+    def test_bad_plaintext(self):
+        # faking a decryption failure is easier: just corrupt the key
+        modemap = dict([(i, "good") for i in range(0, 10)])
+        d = self.send_and_recover((4,8,10), bucket_modes=modemap,
+                                  recover_mode=("corrupt_key"))
+        def _done(res):
+            self.failUnless(isinstance(res, Failure))
+            self.failUnless(res.check(hashtree.BadHashError))
+        d.addBoth(_done)
+        return d
+
     def test_bad_sharehashes_failure(self):
         # the first 7 servers have bad block hashes, so the sharehash tree