mutable: handle bad hashes, improve test coverage, rearrange slightly to facilitate these

This commit is contained in:
Brian Warner 2007-11-13 23:08:15 -07:00
parent 05253dbe72
commit d6f2dbbac7
4 changed files with 183 additions and 73 deletions

View File

@ -31,10 +31,11 @@ class UncoordinatedWriteError(Exception):
class CorruptShareError(Exception):
def __init__(self, peerid, shnum, reason):
self.args = (peerid, shnum, reason)
self.peerid = peerid
self.shnum = shnum
self.reason = reason
def __repr__(self):
def __str__(self):
short_peerid = idlib.nodeid_b2a(self.peerid)[:8]
return "<CorruptShareError peerid=%s shnum[%d]: %s" % (short_peerid,
self.shnum,
@ -104,6 +105,7 @@ def unpack_share(data):
chunk = share_hash_chain_s[i:i+hsize]
(hid, h) = struct.unpack(share_hash_format, chunk)
share_hash_chain.append( (hid, h) )
share_hash_chain = dict(share_hash_chain)
block_hash_tree_s = data[o['block_hash_tree']:o['share_data']]
assert len(block_hash_tree_s) % 32 == 0, len(block_hash_tree_s)
block_hash_tree = []
@ -167,6 +169,32 @@ def pack_offsets(verification_key_length, signature_length,
offsets['enc_privkey'],
offsets['EOF'])
def pack_share(prefix, verification_key, signature,
share_hash_chain, block_hash_tree,
share_data, encprivkey):
share_hash_chain_s = "".join([struct.pack(">H32s", i, share_hash_chain[i])
for i in sorted(share_hash_chain.keys())])
for h in block_hash_tree:
assert len(h) == 32
block_hash_tree_s = "".join(block_hash_tree)
offsets = pack_offsets(len(verification_key),
len(signature),
len(share_hash_chain_s),
len(block_hash_tree_s),
len(share_data),
len(encprivkey))
final_share = "".join([prefix,
offsets,
verification_key,
signature,
share_hash_chain_s,
block_hash_tree_s,
share_data,
encprivkey])
return final_share
class Retrieve:
def __init__(self, filenode):
self._node = filenode
@ -224,6 +252,8 @@ class Retrieve:
# 7: if we discover corrupt shares during the reconstruction process,
# remove that share from the sharemap. and start step#6 again.
self.log("starting retrieval")
initial_query_count = 5
self._read_size = 2000
@ -245,11 +275,12 @@ class Retrieve:
# continuing through the last byte of sharedata.
self._valid_versions = {}
# self._valid_shares is a set (peerid,data) tuples. Each time we
# examine the hash chains inside a share and validate them against a
# signed root_hash, we add the share to self._valid_shares . We use
# this to avoid re-checking the hashes over and over again.
self._valid_shares = set()
# self._valid_shares is a dict mapping (peerid,data) tuples to
# validated sharedata strings. Each time we examine the hash chains
# inside a share and validate them against a signed root_hash, we add
# the share to self._valid_shares . We use this to avoid re-checking
# the hashes over and over again.
self._valid_shares = {}
self._done_deferred = defer.Deferred()
@ -332,6 +363,8 @@ class Retrieve:
for shnum,datav in datavs.items():
data = datav[0]
self.log("_got_results: got shnum #%d from peerid %s"
% (shnum, idlib.shortnodeid_b2a(peerid)))
(seqnum, root_hash, IV, k, N, segsize, datalength,
pubkey_s, signature, prefix) = unpack_prefix_and_signature(data)
@ -339,7 +372,7 @@ class Retrieve:
fingerprint = hashutil.ssk_pubkey_fingerprint_hash(pubkey_s)
if fingerprint != self._node._fingerprint:
# bad share
raise CorruptShareError(peerid,
raise CorruptShareError(peerid, shnum,
"pubkey doesn't match fingerprint")
self._pubkey = self._deserialize_pubkey(pubkey_s)
self._node._populate_pubkey(self._pubkey)
@ -349,7 +382,7 @@ class Retrieve:
# it's a new pair. Verify the signature.
valid = self._pubkey.verify(prefix, signature)
if not valid:
raise CorruptShareError(peerid,
raise CorruptShareError(peerid, shnum,
"signature is invalid")
# ok, it's a valid verinfo. Add it to the list of validated
# versions.
@ -372,6 +405,7 @@ class Retrieve:
# there's enough data present. If not, raise NeedMoreDataError,
# which will trigger a re-fetch.
_ignored = unpack_share(data)
self.log(" found enough data to add share contents")
self._valid_versions[verinfo][1].add(shnum, (peerid, data))
@ -391,9 +425,11 @@ class Retrieve:
self._bad_peerids.add(peerid)
short_sid = idlib.b2a(self._storage_index)[:6]
if f.check(CorruptShareError):
self.log("WEIRD: bad share for %s: %s" % (short_sid, f))
self.log("WEIRD: bad share for %s: %s %s" % (short_sid, f,
f.value))
else:
self.log("WEIRD: other error for %s: %s" % (short_sid, f))
self.log("WEIRD: other error for %s: %s %s" % (short_sid, f,
f.value))
def _check_for_done(self, res):
if not self._running:
@ -422,13 +458,17 @@ class Retrieve:
def _problem(f):
self._last_failure = f
if f.check(CorruptShareError):
# log(WEIRD)
self.log("WEIRD: saw corrupt share, rescheduling")
# _attempt_decode is responsible for removing the bad
# share, so we can just try again
eventually(self._check_for_done)
eventually(self._check_for_done, None)
return
return f
d.addCallbacks(self._done, _problem)
# TODO: create an errback-routing mechanism to make sure that
# weird coding errors will cause the retrieval to fail rather
# than hanging forever. Any otherwise-unhandled exceptions
# should follow this path.
return
# we don't have enough shares yet. Should we send out more queries?
@ -478,13 +518,26 @@ class Retrieve:
# sharemap is a dict which maps shnum to [(peerid,data)..] sets.
(seqnum, root_hash, IV, segsize, datalength) = verinfo
assert len(sharemap) >= self._required_shares, len(sharemap)
shares_s = []
for shnum in sorted(sharemap.keys()):
for shareinfo in sharemap[shnum]:
shares_s.append("#%d" % shnum)
shares_s = ",".join(shares_s)
self.log("_attempt_decode: version %d-%s, shares: %s" %
(seqnum, idlib.b2a(root_hash)[:4], shares_s))
# first, validate each share that we haven't validated yet. We use
# self._valid_shares to remember which ones we've already checked.
shares = {}
for shnum, shareinfos in sharemap.items():
assert len(shareinfos) > 0
for shareinfo in shareinfos:
# have we already validated the hashes on this share?
if shareinfo not in self._valid_shares:
# nope: must check the hashes and extract the actual data
(peerid,data) = shareinfo
try:
# The (seqnum+root_hash+IV) tuple for this share was
@ -500,24 +553,36 @@ class Retrieve:
# validate the prefix on all shares) from using
# anything else in the share.
validator = self._validate_share_and_extract_data
sharedata = validator(root_hash, shnum, data)
sharedata = validator(peerid, root_hash, shnum, data)
assert isinstance(sharedata, str)
except CorruptShareError, e:
self.log("WEIRD: share was corrupt: %s" % e)
sharemap[shnum].discard(shareinfo)
if not sharemap[shnum]:
# remove the key so the test in _check_for_done
# can accurately decide that we don't have enough
# shares to try again right now.
del sharemap[shnum]
# If there are enough remaining shares,
# _check_for_done() will try again
raise
self._valid_shares.add(shareinfo)
shares[shnum] = sharedata
# at this point, all shares in the sharemap are valid, and they're
# all for the same seqnum+root_hash version, so it's now down to
# doing FEC and decrypt.
# share is valid: remember it so we won't need to check
# (or extract) it again
self._valid_shares[shareinfo] = sharedata
# the share is now in _valid_shares, so just copy over the
# sharedata
shares[shnum] = self._valid_shares[shareinfo]
# now that the big loop is done, all shares in the sharemap are
# valid, and they're all for the same seqnum+root_hash version, so
# it's now down to doing FEC and decrypt.
assert len(shares) >= self._required_shares, len(shares)
d = defer.maybeDeferred(self._decode, shares, segsize, datalength)
d.addCallback(self._decrypt, IV, seqnum, root_hash)
return d
def _validate_share_and_extract_data(self, root_hash, shnum, data):
def _validate_share_and_extract_data(self, peerid, root_hash, shnum, data):
# 'data' is the whole SMDF share
self.log("_validate_share_and_extract_data[%d]" % shnum)
assert data[0] == "\x00"
@ -531,7 +596,7 @@ class Retrieve:
leaves = [hashutil.block_hash(share_data)]
t = hashtree.HashTree(leaves)
if list(t) != block_hash_tree:
raise CorruptShareError("block hash tree failure")
raise CorruptShareError(peerid, shnum, "block hash tree failure")
share_hash_leaf = t[0]
# t2 = hashtree.IncompleteHashTree()
# TODO: use shnum, share_hash_leaf, share_hash_chain to compare against
@ -553,6 +618,7 @@ class Retrieve:
shareids.append(shareid)
shares.append(share)
assert len(shareids) >= self._required_shares, len(shareids)
# zfec really doesn't want extra shares
shareids = shareids[:self._required_shares]
shares = shares[:self._required_shares]
@ -650,7 +716,7 @@ class Publish:
# 4a: may need to run recovery algorithm
# 5: when enough responses are back, we're done
self.log("got enough peers")
self.log("starting publish, data is %r" % (newdata,))
self._storage_index = self._node.get_storage_index()
self._writekey = self._node.get_writekey()
@ -747,9 +813,12 @@ class Publish:
def _got_query_results(self, datavs, peerid, permutedid,
reachable_peers, current_share_peers):
self.log("_got_query_results")
assert isinstance(datavs, dict)
reachable_peers[peerid] = permutedid
for shnum, datav in datavs.items():
self.log(" peer has shnum %d" % shnum)
assert len(datav) == 1
data = datav[0]
# We want (seqnum, root_hash, IV) from all servers to know what
@ -999,29 +1068,14 @@ class Publish:
final_shares = {}
for shnum in range(total_shares):
shc = share_hash_chain[shnum]
share_hash_chain_s = "".join([struct.pack(">H32s", i, shc[i])
for i in sorted(shc.keys())])
bht = block_hash_trees[shnum]
for h in bht:
assert len(h) == 32
block_hash_tree_s = "".join(bht)
share_data = all_shares[shnum]
offsets = pack_offsets(len(verification_key),
len(signature),
len(share_hash_chain_s),
len(block_hash_tree_s),
len(share_data),
len(encprivkey))
final_shares[shnum] = "".join([prefix,
offsets,
final_share = pack_share(prefix,
verification_key,
signature,
share_hash_chain_s,
block_hash_tree_s,
share_data,
encprivkey])
share_hash_chain[shnum],
block_hash_trees[shnum],
all_shares[shnum],
encprivkey)
final_shares[shnum] = final_share
return (seqnum, root_hash, final_shares, target_info)

View File

@ -179,7 +179,8 @@ def dump_SDMF_share(offset, length, config, out, err):
print >>out, " total_shares: %d" % N
print >>out, " segsize: %d" % segsize
print >>out, " datalen: %d" % datalen
share_hash_ids = ",".join([str(hid) for (hid,hash) in share_hash_chain])
share_hash_ids = ",".join(sorted([str(hid)
for hid in share_hash_chain.keys()]))
print >>out, " share_hash_chain: %s" % share_hash_ids
print >>out, " block_hash_tree: %d nodes" % len(block_hash_tree)

View File

@ -264,13 +264,12 @@ class Publish(unittest.TestCase):
k, N, segsize, datalen)
self.failUnlessEqual(signature,
FakePrivKey(0).sign(sig_material))
self.failUnless(isinstance(share_hash_chain, list))
self.failUnless(isinstance(share_hash_chain, dict))
self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
for i in share_hash_chain:
self.failUnless(isinstance(i, tuple))
self.failUnless(isinstance(i[0], int))
self.failUnless(isinstance(i[1], str))
self.failUnlessEqual(len(i[1]), 32)
for shnum,share_hash in share_hash_chain.items():
self.failUnless(isinstance(shnum, int))
self.failUnless(isinstance(share_hash, str))
self.failUnlessEqual(len(share_hash), 32)
self.failUnless(isinstance(block_hash_tree, list))
self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
self.failUnlessEqual(IV, "IV"*8)

View File

@ -6,7 +6,7 @@ from twisted.trial import unittest
from twisted.internet import defer, reactor
from twisted.internet import threads # CLI tests use deferToThread
from twisted.application import service
from allmydata import client, uri, download, upload
from allmydata import client, uri, download, upload, storage, mutable
from allmydata.introducer import IntroducerNode
from allmydata.util import deferredutil, fileutil, idlib, mathutil, testutil
from allmydata.scripts import runner
@ -237,6 +237,72 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
return d
test_upload_and_download.timeout = 4800
def _find_shares(self, basedir):
shares = []
for (dirpath, dirnames, filenames) in os.walk(basedir):
if "storage" not in dirpath:
continue
if not filenames:
continue
pieces = dirpath.split(os.sep)
if pieces[-3] == "storage" and pieces[-2] == "shares":
# we're sitting in .../storage/shares/$SINDEX , and there
# are sharefiles here
assert pieces[-4].startswith("client")
client_num = int(pieces[-4][-1])
storage_index_s = pieces[-1]
storage_index = idlib.a2b(storage_index_s)
for sharename in filenames:
shnum = int(sharename)
filename = os.path.join(dirpath, sharename)
data = (client_num, storage_index, filename, shnum)
shares.append(data)
if not shares:
self.fail("unable to find any share files in %s" % basedir)
return shares
def _corrupt_mutable_share(self, filename, which):
msf = storage.MutableShareFile(filename)
datav = msf.readv([ (0, 1000000) ])
final_share = datav[0]
assert len(final_share) < 1000000 # ought to be truncated
pieces = mutable.unpack_share(final_share)
(seqnum, root_hash, IV, k, N, segsize, datalen,
verification_key, signature, share_hash_chain, block_hash_tree,
share_data, enc_privkey) = pieces
if which == "seqnum":
seqnum = seqnum + 15
elif which == "R":
root_hash = self.flip_bit(root_hash)
elif which == "IV":
IV = self.flip_bit(IV)
elif which == "segsize":
segsize = segsize + 15
elif which == "pubkey":
verification_key = self.flip_bit(verification_key)
elif which == "signature":
signature = self.flip_bit(signature)
elif which == "share_hash_chain":
nodenum = share_hash_chain.keys()[0]
share_hash_chain[nodenum] = self.flip_bit(share_hash_chain[nodenum])
elif which == "block_hash_tree":
block_hash_tree[-1] = self.flip_bit(block_hash_tree[-1])
elif which == "share_data":
share_data = self.flip_bit(share_data)
elif which == "encprivkey":
enc_privkey = self.flip_bit(enc_privkey)
prefix = mutable.pack_prefix(seqnum, root_hash, IV, k, N,
segsize, datalen)
final_share = mutable.pack_share(prefix,
verification_key,
signature,
share_hash_chain,
block_hash_tree,
share_data,
enc_privkey)
msf.writev( [(0, final_share)], None)
def test_mutable(self):
self.basedir = "system/SystemTest/test_mutable"
@ -260,22 +326,8 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
def _test_debug(res):
# find a share. It is important to run this while there is only
# one slot in the grid.
for (dirpath, dirnames, filenames) in os.walk(self.basedir):
if "storage" not in dirpath:
continue
if not filenames:
continue
pieces = dirpath.split(os.sep)
if pieces[-3] == "storage" and pieces[-2] == "shares":
# we're sitting in .../storage/shares/$SINDEX , and there
# are sharefiles here
assert pieces[-4].startswith("client")
client_num = int(pieces[-4][-1])
filename = os.path.join(dirpath, filenames[0])
break
else:
self.fail("unable to find any share files in %s"
% self.basedir)
shares = self._find_shares(self.basedir)
(client_num, storage_index, filename, shnum) = shares[0]
log.msg("test_system.SystemTest.test_mutable._test_debug using %s"
% filename)
log.msg(" for clients[%d]" % client_num)
@ -367,6 +419,7 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
uri = self._mutable_node_1.get_uri()
newnode1 = self.clients[2].create_node_from_uri(uri)
newnode2 = self.clients[3].create_node_from_uri(uri)
self._newnode3 = self.clients[3].create_node_from_uri(uri)
log.msg("starting replace2")
d1 = newnode1.replace(NEWERDATA, wait_for_numpeers=self.numclients)
d1.addCallback(lambda res: newnode2.download_to_data())
@ -376,13 +429,13 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
def _check_download_5(res):
log.msg("finished replace2")
self.failUnlessEqual(res, NEWERDATA)
# Make sure we can create empty files -- this can screw up the
# segsize math.
d1 = self.clients[2].create_mutable_file("", wait_for_numpeers=self.numclients)
# make sure we can create empty files, this usually screws up the
# segsize math
d1 = self.clients[2].create_mutable_file("")
d1.addCallback(lambda newnode: newnode.download_to_data())
d1.addCallback(lambda res: self.failUnlessEqual("", res))
return d1
d.addCallback(_check_download_5)
d.addCallback(_check_empty_file)
d.addCallback(lambda res: self.clients[0].create_empty_dirnode(wait_for_numpeers=self.numclients))
def _created_dirnode(dnode):
@ -394,6 +447,9 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
d1.addCallback(lambda res: dnode.set_node("see recursive", dnode, wait_for_numpeers=self.numclients))
d1.addCallback(lambda res: dnode.has_child("see recursive"))
d1.addCallback(lambda answer: self.failUnlessEqual(answer, True))
d1.addCallback(lambda res: dnode.build_manifest())
d1.addCallback(lambda manifest:
self.failUnlessEqual(len(manifest), 1))
return d1
d.addCallback(_created_dirnode)