mutable: handle bad hashes, improve test coverage, rearrange slightly to facilitate these

Brian Warner 2007-11-13 23:08:15 -07:00
parent 05253dbe72
commit d6f2dbbac7
4 changed files with 183 additions and 73 deletions

---- changed file 1 of 4 ----

@@ -31,10 +31,11 @@ class UncoordinatedWriteError(Exception):
 class CorruptShareError(Exception):
     def __init__(self, peerid, shnum, reason):
+        self.args = (peerid, shnum, reason)
         self.peerid = peerid
         self.shnum = shnum
         self.reason = reason
-    def __repr__(self):
+    def __str__(self):
         short_peerid = idlib.nodeid_b2a(self.peerid)[:8]
         return "<CorruptShareError peerid=%s shnum[%d]: %s" % (short_peerid,
                                                                self.shnum,
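
Defining __str__ (instead of __repr__) and populating self.args is what lets the new "%s %s" log lines later in this patch render a CorruptShareError carried inside a Failure as a readable message. A small self-contained Python 2 sketch of that behaviour, using hex encoding as a stand-in for idlib and a made-up peerid:

    class CorruptShareError(Exception):
        def __init__(self, peerid, shnum, reason):
            self.args = (peerid, shnum, reason)   # keep Exception machinery consistent
            self.peerid = peerid
            self.shnum = shnum
            self.reason = reason
        def __str__(self):
            short_peerid = self.peerid.encode("hex")[:8]   # stand-in for idlib.nodeid_b2a
            return "<CorruptShareError peerid=%s shnum[%d]: %s" % (short_peerid,
                                                                   self.shnum,
                                                                   self.reason)

    e = CorruptShareError("\x01" * 20, 3, "signature is invalid")
    print "%s" % e   # <CorruptShareError peerid=01010101 shnum[3]: signature is invalid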
@@ -104,6 +105,7 @@ def unpack_share(data):
         chunk = share_hash_chain_s[i:i+hsize]
         (hid, h) = struct.unpack(share_hash_format, chunk)
         share_hash_chain.append( (hid, h) )
+    share_hash_chain = dict(share_hash_chain)
     block_hash_tree_s = data[o['block_hash_tree']:o['share_data']]
     assert len(block_hash_tree_s) % 32 == 0, len(block_hash_tree_s)
     block_hash_tree = []
@@ -167,6 +169,32 @@ def pack_offsets(verification_key_length, signature_length,
                        offsets['enc_privkey'],
                        offsets['EOF'])
 
+def pack_share(prefix, verification_key, signature,
+               share_hash_chain, block_hash_tree,
+               share_data, encprivkey):
+    share_hash_chain_s = "".join([struct.pack(">H32s", i, share_hash_chain[i])
+                                  for i in sorted(share_hash_chain.keys())])
+    for h in block_hash_tree:
+        assert len(h) == 32
+    block_hash_tree_s = "".join(block_hash_tree)
+
+    offsets = pack_offsets(len(verification_key),
+                           len(signature),
+                           len(share_hash_chain_s),
+                           len(block_hash_tree_s),
+                           len(share_data),
+                           len(encprivkey))
+
+    final_share = "".join([prefix,
+                           offsets,
+                           verification_key,
+                           signature,
+                           share_hash_chain_s,
+                           block_hash_tree_s,
+                           share_data,
+                           encprivkey])
+    return final_share
+
 class Retrieve:
     def __init__(self, filenode):
         self._node = filenode
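
The new pack_share() is the write-side mirror of unpack_share(): share_hash_chain is now a dict mapping hash-tree node numbers to 32-byte hashes, serialized as big-endian (node, hash) records in node order. A minimal Python 2 sketch of just that encoding and its inverse, using made-up hash values:

    import struct

    share_hash_chain = {3: "\x11" * 32, 5: "\x22" * 32}   # hypothetical node hashes
    share_hash_chain_s = "".join([struct.pack(">H32s", i, share_hash_chain[i])
                                  for i in sorted(share_hash_chain.keys())])
    assert len(share_hash_chain_s) == 34 * len(share_hash_chain)   # 2 + 32 bytes per entry

    # decoding reverses it, the same way unpack_share now builds a dict
    decoded = {}
    for off in range(0, len(share_hash_chain_s), 34):
        (hid, h) = struct.unpack(">H32s", share_hash_chain_s[off:off+34])
        decoded[hid] = h
    assert decoded == share_hash_chain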
@@ -224,6 +252,8 @@ class Retrieve:
         # 7: if we discover corrupt shares during the reconstruction process,
         # remove that share from the sharemap. and start step#6 again.
 
+        self.log("starting retrieval")
+
         initial_query_count = 5
         self._read_size = 2000
@@ -245,11 +275,12 @@ class Retrieve:
         # continuing through the last byte of sharedata.
         self._valid_versions = {}
 
-        # self._valid_shares is a set (peerid,data) tuples. Each time we
-        # examine the hash chains inside a share and validate them against a
-        # signed root_hash, we add the share to self._valid_shares . We use
-        # this to avoid re-checking the hashes over and over again.
-        self._valid_shares = set()
+        # self._valid_shares is a dict mapping (peerid,data) tuples to
+        # validated sharedata strings. Each time we examine the hash chains
+        # inside a share and validate them against a signed root_hash, we add
+        # the share to self._valid_shares . We use this to avoid re-checking
+        # the hashes over and over again.
+        self._valid_shares = {}
 
         self._done_deferred = defer.Deferred()
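
Switching self._valid_shares from a set to a dict turns the "have we checked this share?" test into a small cache: the key is still the (peerid, data) tuple, but the value now carries the already-extracted sharedata so a later decode attempt can reuse it without re-validating. The pattern, reduced to a sketch with illustrative names:

    _valid_shares = {}   # (peerid, data) -> validated sharedata

    def get_validated_sharedata(shareinfo, validate_and_extract):
        # validate_and_extract stands in for _validate_share_and_extract_data
        if shareinfo not in _valid_shares:
            _valid_shares[shareinfo] = validate_and_extract(shareinfo)
        return _valid_shares[shareinfo]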
@@ -332,6 +363,8 @@ class Retrieve:
         for shnum,datav in datavs.items():
             data = datav[0]
+            self.log("_got_results: got shnum #%d from peerid %s"
+                     % (shnum, idlib.shortnodeid_b2a(peerid)))
             (seqnum, root_hash, IV, k, N, segsize, datalength,
              pubkey_s, signature, prefix) = unpack_prefix_and_signature(data)
@@ -339,7 +372,7 @@ class Retrieve:
             fingerprint = hashutil.ssk_pubkey_fingerprint_hash(pubkey_s)
             if fingerprint != self._node._fingerprint:
                 # bad share
-                raise CorruptShareError(peerid,
+                raise CorruptShareError(peerid, shnum,
                                         "pubkey doesn't match fingerprint")
             self._pubkey = self._deserialize_pubkey(pubkey_s)
             self._node._populate_pubkey(self._pubkey)
@@ -349,7 +382,7 @@ class Retrieve:
             # it's a new pair. Verify the signature.
             valid = self._pubkey.verify(prefix, signature)
             if not valid:
-                raise CorruptShareError(peerid,
+                raise CorruptShareError(peerid, shnum,
                                         "signature is invalid")
             # ok, it's a valid verinfo. Add it to the list of validated
             # versions.
@@ -372,6 +405,7 @@ class Retrieve:
             # there's enough data present. If not, raise NeedMoreDataError,
             # which will trigger a re-fetch.
             _ignored = unpack_share(data)
+            self.log(" found enough data to add share contents")
 
             self._valid_versions[verinfo][1].add(shnum, (peerid, data))
@@ -391,9 +425,11 @@ class Retrieve:
             self._bad_peerids.add(peerid)
             short_sid = idlib.b2a(self._storage_index)[:6]
             if f.check(CorruptShareError):
-                self.log("WEIRD: bad share for %s: %s" % (short_sid, f))
+                self.log("WEIRD: bad share for %s: %s %s" % (short_sid, f,
+                                                             f.value))
             else:
-                self.log("WEIRD: other error for %s: %s" % (short_sid, f))
+                self.log("WEIRD: other error for %s: %s %s" % (short_sid, f,
+                                                               f.value))
 
     def _check_for_done(self, res):
         if not self._running:
@@ -422,13 +458,17 @@ class Retrieve:
             def _problem(f):
                 self._last_failure = f
                 if f.check(CorruptShareError):
-                    # log(WEIRD)
+                    self.log("WEIRD: saw corrupt share, rescheduling")
                     # _attempt_decode is responsible for removing the bad
                     # share, so we can just try again
-                    eventually(self._check_for_done)
+                    eventually(self._check_for_done, None)
                     return
                 return f
             d.addCallbacks(self._done, _problem)
+            # TODO: create an errback-routing mechanism to make sure that
+            # weird coding errors will cause the retrieval to fail rather
+            # than hanging forever. Any otherwise-unhandled exceptions
+            # should follow this path.
             return
 
         # we don't have enough shares yet. Should we send out more queries?
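
The TODO above asks for a way to keep unexpected coding errors from stalling the retrieve forever. One possible shape for that, sketched here rather than taken from the patch, is a pair of Retrieve methods that run each step through maybeDeferred and route any unhandled Failure into the same done-deferred that success uses (the _guarded_call and _fatal_error names are hypothetical):

    from twisted.internet import defer

    def _guarded_call(self, func, *args, **kwargs):
        d = defer.maybeDeferred(func, *args, **kwargs)
        d.addErrback(self._fatal_error)
        return d

    def _fatal_error(self, f):
        # any otherwise-unhandled Failure ends the retrieve instead of hanging it
        self.log("WEIRD: unhandled failure, terminating retrieve: %s" % (f,))
        self._done_deferred.errback(f)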
@@ -478,13 +518,26 @@ class Retrieve:
         # sharemap is a dict which maps shnum to [(peerid,data)..] sets.
         (seqnum, root_hash, IV, segsize, datalength) = verinfo
 
+        assert len(sharemap) >= self._required_shares, len(sharemap)
+
+        shares_s = []
+        for shnum in sorted(sharemap.keys()):
+            for shareinfo in sharemap[shnum]:
+                shares_s.append("#%d" % shnum)
+        shares_s = ",".join(shares_s)
+        self.log("_attempt_decode: version %d-%s, shares: %s" %
+                 (seqnum, idlib.b2a(root_hash)[:4], shares_s))
+
         # first, validate each share that we haven't validated yet. We use
         # self._valid_shares to remember which ones we've already checked.
 
         shares = {}
         for shnum, shareinfos in sharemap.items():
+            assert len(shareinfos) > 0
             for shareinfo in shareinfos:
+                # have we already validated the hashes on this share?
                 if shareinfo not in self._valid_shares:
+                    # nope: must check the hashes and extract the actual data
                     (peerid,data) = shareinfo
                     try:
                         # The (seqnum+root_hash+IV) tuple for this share was
@@ -500,24 +553,36 @@ class Retrieve:
                         # validate the prefix on all shares) from using
                         # anything else in the share.
                         validator = self._validate_share_and_extract_data
-                        sharedata = validator(root_hash, shnum, data)
+                        sharedata = validator(peerid, root_hash, shnum, data)
                         assert isinstance(sharedata, str)
                     except CorruptShareError, e:
                         self.log("WEIRD: share was corrupt: %s" % e)
                         sharemap[shnum].discard(shareinfo)
+                        if not sharemap[shnum]:
+                            # remove the key so the test in _check_for_done
+                            # can accurately decide that we don't have enough
+                            # shares to try again right now.
+                            del sharemap[shnum]
                         # If there are enough remaining shares,
                         # _check_for_done() will try again
                         raise
-                    self._valid_shares.add(shareinfo)
-                    shares[shnum] = sharedata
-        # at this point, all shares in the sharemap are valid, and they're
-        # all for the same seqnum+root_hash version, so it's now down to
-        # doing FEC and decrypt.
+                    # share is valid: remember it so we won't need to check
+                    # (or extract) it again
+                    self._valid_shares[shareinfo] = sharedata
+
+                # the share is now in _valid_shares, so just copy over the
+                # sharedata
+                shares[shnum] = self._valid_shares[shareinfo]
+
+        # now that the big loop is done, all shares in the sharemap are
+        # valid, and they're all for the same seqnum+root_hash version, so
+        # it's now down to doing FEC and decrypt.
+        assert len(shares) >= self._required_shares, len(shares)
         d = defer.maybeDeferred(self._decode, shares, segsize, datalength)
         d.addCallback(self._decrypt, IV, seqnum, root_hash)
         return d
 
-    def _validate_share_and_extract_data(self, root_hash, shnum, data):
+    def _validate_share_and_extract_data(self, peerid, root_hash, shnum, data):
         # 'data' is the whole SMDF share
         self.log("_validate_share_and_extract_data[%d]" % shnum)
         assert data[0] == "\x00"
@@ -531,7 +596,7 @@ class Retrieve:
         leaves = [hashutil.block_hash(share_data)]
         t = hashtree.HashTree(leaves)
         if list(t) != block_hash_tree:
-            raise CorruptShareError("block hash tree failure")
+            raise CorruptShareError(peerid, shnum, "block hash tree failure")
         share_hash_leaf = t[0]
         # t2 = hashtree.IncompleteHashTree()
         # TODO: use shnum, share_hash_leaf, share_hash_chain to compare against
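
For context on the check above: the block hash tree stored in the share is a Merkle tree over the hashes of the data blocks, and validation rebuilds it from the block actually received and compares it node for node. The sketch below uses plain SHA-256; Tahoe's hashtree and hashutil modules use tagged hashes, so this is an illustration of the idea, not the project's exact layout:

    import hashlib

    def h(x):
        return hashlib.sha256(x).digest()

    def build_tree(leaves):
        # flat list of tree nodes, parents before leaves; assumes a
        # power-of-two leaf count, which is all this sketch needs
        nodes = list(leaves)
        level = list(leaves)
        while len(level) > 1:
            level = [h(level[i] + level[i+1]) for i in range(0, len(level), 2)]
            nodes = level + nodes
        return nodes

    block = "the single SDMF segment's share data"
    stored_tree = build_tree([h(block)])   # what the writer put in the share
    recomputed = build_tree([h(block)])    # rebuilt from the received block
    assert recomputed == stored_tree
    assert build_tree([h(block + "!")]) != stored_tree   # any corruption fails the compare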
@@ -553,6 +618,7 @@ class Retrieve:
             shareids.append(shareid)
             shares.append(share)
 
+        assert len(shareids) >= self._required_shares, len(shareids)
         # zfec really doesn't want extra shares
         shareids = shareids[:self._required_shares]
         shares = shares[:self._required_shares]
@@ -650,7 +716,7 @@ class Publish:
         # 4a: may need to run recovery algorithm
         # 5: when enough responses are back, we're done
 
-        self.log("got enough peers")
+        self.log("starting publish, data is %r" % (newdata,))
 
         self._storage_index = self._node.get_storage_index()
         self._writekey = self._node.get_writekey()
@@ -747,9 +813,12 @@ class Publish:
     def _got_query_results(self, datavs, peerid, permutedid,
                            reachable_peers, current_share_peers):
+        self.log("_got_query_results")
+
         assert isinstance(datavs, dict)
         reachable_peers[peerid] = permutedid
         for shnum, datav in datavs.items():
+            self.log(" peer has shnum %d" % shnum)
             assert len(datav) == 1
             data = datav[0]
             # We want (seqnum, root_hash, IV) from all servers to know what
@@ -999,29 +1068,14 @@ class Publish:
         final_shares = {}
         for shnum in range(total_shares):
-            shc = share_hash_chain[shnum]
-            share_hash_chain_s = "".join([struct.pack(">H32s", i, shc[i])
-                                          for i in sorted(shc.keys())])
-            bht = block_hash_trees[shnum]
-            for h in bht:
-                assert len(h) == 32
-            block_hash_tree_s = "".join(bht)
-            share_data = all_shares[shnum]
-            offsets = pack_offsets(len(verification_key),
-                                   len(signature),
-                                   len(share_hash_chain_s),
-                                   len(block_hash_tree_s),
-                                   len(share_data),
-                                   len(encprivkey))
-            final_shares[shnum] = "".join([prefix,
-                                           offsets,
-                                           verification_key,
-                                           signature,
-                                           share_hash_chain_s,
-                                           block_hash_tree_s,
-                                           share_data,
-                                           encprivkey])
+            final_share = pack_share(prefix,
+                                     verification_key,
+                                     signature,
+                                     share_hash_chain[shnum],
+                                     block_hash_trees[shnum],
+                                     all_shares[shnum],
+                                     encprivkey)
+            final_shares[shnum] = final_share
         return (seqnum, root_hash, final_shares, target_info)

---- changed file 2 of 4 ----

@@ -179,7 +179,8 @@ def dump_SDMF_share(offset, length, config, out, err):
     print >>out, " total_shares: %d" % N
     print >>out, " segsize: %d" % segsize
     print >>out, " datalen: %d" % datalen
-    share_hash_ids = ",".join([str(hid) for (hid,hash) in share_hash_chain])
+    share_hash_ids = ",".join(sorted([str(hid)
+                                      for hid in share_hash_chain.keys()]))
     print >>out, " share_hash_chain: %s" % share_hash_ids
     print >>out, " block_hash_tree: %d nodes" % len(block_hash_tree)

---- changed file 3 of 4 ----

@@ -264,13 +264,12 @@ class Publish(unittest.TestCase):
                                   k, N, segsize, datalen)
         self.failUnlessEqual(signature,
                              FakePrivKey(0).sign(sig_material))
-        self.failUnless(isinstance(share_hash_chain, list))
+        self.failUnless(isinstance(share_hash_chain, dict))
         self.failUnlessEqual(len(share_hash_chain), 4) # ln2(10)++
-        for i in share_hash_chain:
-            self.failUnless(isinstance(i, tuple))
-            self.failUnless(isinstance(i[0], int))
-            self.failUnless(isinstance(i[1], str))
-            self.failUnlessEqual(len(i[1]), 32)
+        for shnum,share_hash in share_hash_chain.items():
+            self.failUnless(isinstance(shnum, int))
+            self.failUnless(isinstance(share_hash, str))
+            self.failUnlessEqual(len(share_hash), 32)
         self.failUnless(isinstance(block_hash_tree, list))
         self.failUnlessEqual(len(block_hash_tree), 1) # very small tree
         self.failUnlessEqual(IV, "IV"*8)

---- changed file 4 of 4 ----

@@ -6,7 +6,7 @@ from twisted.trial import unittest
 from twisted.internet import defer, reactor
 from twisted.internet import threads # CLI tests use deferToThread
 from twisted.application import service
-from allmydata import client, uri, download, upload
+from allmydata import client, uri, download, upload, storage, mutable
 from allmydata.introducer import IntroducerNode
 from allmydata.util import deferredutil, fileutil, idlib, mathutil, testutil
 from allmydata.scripts import runner
@@ -237,6 +237,72 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
         return d
     test_upload_and_download.timeout = 4800
 
+    def _find_shares(self, basedir):
+        shares = []
+        for (dirpath, dirnames, filenames) in os.walk(basedir):
+            if "storage" not in dirpath:
+                continue
+            if not filenames:
+                continue
+            pieces = dirpath.split(os.sep)
+            if pieces[-3] == "storage" and pieces[-2] == "shares":
+                # we're sitting in .../storage/shares/$SINDEX , and there
+                # are sharefiles here
+                assert pieces[-4].startswith("client")
+                client_num = int(pieces[-4][-1])
+                storage_index_s = pieces[-1]
+                storage_index = idlib.a2b(storage_index_s)
+                for sharename in filenames:
+                    shnum = int(sharename)
+                    filename = os.path.join(dirpath, sharename)
+                    data = (client_num, storage_index, filename, shnum)
+                    shares.append(data)
+        if not shares:
+            self.fail("unable to find any share files in %s" % basedir)
+        return shares
+
+    def _corrupt_mutable_share(self, filename, which):
+        msf = storage.MutableShareFile(filename)
+        datav = msf.readv([ (0, 1000000) ])
+        final_share = datav[0]
+        assert len(final_share) < 1000000 # ought to be truncated
+        pieces = mutable.unpack_share(final_share)
+        (seqnum, root_hash, IV, k, N, segsize, datalen,
+         verification_key, signature, share_hash_chain, block_hash_tree,
+         share_data, enc_privkey) = pieces
+
+        if which == "seqnum":
+            seqnum = seqnum + 15
+        elif which == "R":
+            root_hash = self.flip_bit(root_hash)
+        elif which == "IV":
+            IV = self.flip_bit(IV)
+        elif which == "segsize":
+            segsize = segsize + 15
+        elif which == "pubkey":
+            verification_key = self.flip_bit(verification_key)
+        elif which == "signature":
+            signature = self.flip_bit(signature)
+        elif which == "share_hash_chain":
+            nodenum = share_hash_chain.keys()[0]
+            share_hash_chain[nodenum] = self.flip_bit(share_hash_chain[nodenum])
+        elif which == "block_hash_tree":
+            block_hash_tree[-1] = self.flip_bit(block_hash_tree[-1])
+        elif which == "share_data":
+            share_data = self.flip_bit(share_data)
+        elif which == "encprivkey":
+            enc_privkey = self.flip_bit(enc_privkey)
+
+        prefix = mutable.pack_prefix(seqnum, root_hash, IV, k, N,
+                                     segsize, datalen)
+        final_share = mutable.pack_share(prefix,
+                                         verification_key,
+                                         signature,
+                                         share_hash_chain,
+                                         block_hash_tree,
+                                         share_data,
+                                         enc_privkey)
+        msf.writev( [(0, final_share)], None)
+
     def test_mutable(self):
         self.basedir = "system/SystemTest/test_mutable"
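
The corruption helper above leans on self.flip_bit(), which is not part of this diff. A plausible stand-in (an assumption about the test utility, not the project's actual code) flips the low bit of the last byte of a field, which is enough to break any hash, signature, or fingerprint check covering it:

    def flip_bit(good):
        # flip one bit near the end of the string; returns a corrupted copy
        return good[:-1] + chr(ord(good[-1]) ^ 0x01)

    assert flip_bit("\x00" * 32) != "\x00" * 32
    assert flip_bit(flip_bit("abc")) == "abc"   # flipping twice restores the original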
@@ -260,22 +326,8 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
         def _test_debug(res):
             # find a share. It is important to run this while there is only
             # one slot in the grid.
-            for (dirpath, dirnames, filenames) in os.walk(self.basedir):
-                if "storage" not in dirpath:
-                    continue
-                if not filenames:
-                    continue
-                pieces = dirpath.split(os.sep)
-                if pieces[-3] == "storage" and pieces[-2] == "shares":
-                    # we're sitting in .../storage/shares/$SINDEX , and there
-                    # are sharefiles here
-                    assert pieces[-4].startswith("client")
-                    client_num = int(pieces[-4][-1])
-                    filename = os.path.join(dirpath, filenames[0])
-                    break
-            else:
-                self.fail("unable to find any share files in %s"
-                          % self.basedir)
+            shares = self._find_shares(self.basedir)
+            (client_num, storage_index, filename, shnum) = shares[0]
             log.msg("test_system.SystemTest.test_mutable._test_debug using %s"
                     % filename)
             log.msg(" for clients[%d]" % client_num)
@@ -367,6 +419,7 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
             uri = self._mutable_node_1.get_uri()
             newnode1 = self.clients[2].create_node_from_uri(uri)
             newnode2 = self.clients[3].create_node_from_uri(uri)
+            self._newnode3 = self.clients[3].create_node_from_uri(uri)
             log.msg("starting replace2")
             d1 = newnode1.replace(NEWERDATA, wait_for_numpeers=self.numclients)
             d1.addCallback(lambda res: newnode2.download_to_data())
@@ -376,13 +429,13 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
         def _check_download_5(res):
             log.msg("finished replace2")
             self.failUnlessEqual(res, NEWERDATA)
-            # Make sure we can create empty files -- this can screw up the
-            # segsize math.
-            d1 = self.clients[2].create_mutable_file("", wait_for_numpeers=self.numclients)
+            # make sure we can create empty files, this usually screws up the
+            # segsize math
+            d1 = self.clients[2].create_mutable_file("")
             d1.addCallback(lambda newnode: newnode.download_to_data())
             d1.addCallback(lambda res: self.failUnlessEqual("", res))
             return d1
-        d.addCallback(_check_download_5)
+        d.addCallback(_check_empty_file)
 
         d.addCallback(lambda res: self.clients[0].create_empty_dirnode(wait_for_numpeers=self.numclients))
         def _created_dirnode(dnode):
@@ -394,6 +447,9 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase):
             d1.addCallback(lambda res: dnode.set_node("see recursive", dnode, wait_for_numpeers=self.numclients))
             d1.addCallback(lambda res: dnode.has_child("see recursive"))
             d1.addCallback(lambda answer: self.failUnlessEqual(answer, True))
+            d1.addCallback(lambda res: dnode.build_manifest())
+            d1.addCallback(lambda manifest:
+                           self.failUnlessEqual(len(manifest), 1))
             return d1
         d.addCallback(_created_dirnode)