hashtree: fix O(N**2) behavior, to improve fatal alacrity problems in a 10GB file (#670). Also improve docstring.

This commit is contained in:
Brian Warner 2009-03-31 13:21:27 -07:00
parent 438bc67548
commit 466014f66f
2 changed files with 137 additions and 71 deletions

View File

@ -162,6 +162,15 @@ class CompleteBinaryTreeMixin:
def get_leaf(self, leafnum):
return self[self.first_leaf_num + leafnum]
def depth_of(self, i):
"""Return the depth or level of the given node. Level 0 contains node
Level 1 contains nodes 1 and 2. Level 2 contains nodes 3,4,5,6."""
depth = 0
while i != 0:
depth += 1
i = self.parent(i)
return depth
def empty_leaf_hash(i):
return tagged_hash('Merkle tree empty leaf', "%d" % i)
def pair_hash(a, b):
@ -337,27 +346,30 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
from 0 (the root of the tree) to 2*num_leaves-2 (the right-most
leaf). leaf[i] is the same as hash[num_leaves-1+i].
The best way to use me is to obtain the root hash from some 'good'
channel, and use the 'bad' channel to obtain data block 0 and the
The best way to use me is to start by obtaining the root hash from
some 'good' channel and populate me with it:
iht = IncompleteHashTree(numleaves)
roothash = trusted_channel.get_roothash()
iht.set_hashes(hashes={0: roothash})
Then use the 'bad' channel to obtain data block 0 and the
corresponding hash chain (a dict with the same hashes that
needed_hashes(0) tells you, e.g. {0:h0, 2:h2, 4:h4, 8:h8} when
len(L)=8). Hash the data block to create leaf0, then feed everything
into set_hashes() and see if it raises an exception or not::
iht = IncompleteHashTree(numleaves)
roothash = trusted_channel.get_roothash()
otherhashes = untrusted_channel.get_hashes()
# otherhashes.keys() should == iht.needed_hashes(leaves=[0])
datablock0 = untrusted_channel.get_data(0)
leaf0 = HASH(datablock0)
# HASH() is probably hashutil.tagged_hash(tag, datablock0)
hashes = otherhashes.copy()
hashes[0] = roothash # from 'good' channel
iht.set_hashes(hashes, leaves={0: leaf0})
otherhashes = untrusted_channel.get_hashes()
# otherhashes.keys() should == iht.needed_hashes(leaves=[0])
datablock0 = untrusted_channel.get_data(0)
leaf0 = HASH(datablock0)
# HASH() is probably hashutil.tagged_hash(tag, datablock0)
iht.set_hashes(otherhashes, leaves={0: leaf0})
If the set_hashes() call doesn't raise an exception, the data block
was valid. If it raises BadHashError, then either the data block was
corrupted or one of the received hashes was corrupted.
corrupted or one of the received hashes was corrupted. If it raises
NotEnoughHashesError, then the otherhashes dictionary was incomplete.
"""
assert isinstance(hashes, dict)
@ -376,73 +388,87 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
% (leafnum, hashnum))
new_hashes[hashnum] = leafhash
added = set() # we'll remove these if the check fails
remove_upon_failure = set() # we'll remove these if the check fails
# visualize this method in the following way:
# A: start with the empty or partially-populated tree as shown in
# the HashTree docstring
# B: add all of our input hashes to the tree, filling in some of the
# holes. Don't overwrite anything, but new values must equal the
# existing ones. Mark everything that was added with a red dot
# (meaning "not yet validated")
# C: start with the lowest/deepest level. Pick any red-dotted node,
# hash it with its sibling to compute the parent hash. Add the
# parent to the tree just like in step B (if the parent already
# exists, the values must be equal; if not, add our computed
# value with a red dot). If we have no sibling, throw
# NotEnoughHashesError, since we won't be able to validate this
# node. Remove the red dot. If there was a red dot on our
# sibling, remove it too.
# D: finish all red-dotted nodes in one level before moving up to
# the next.
# E: if we hit NotEnoughHashesError or BadHashError before getting
# to the root, discard every hash we've added.
try:
num_levels = self.depth_of(len(self)-1)
# hashes_to_check[level] is set(index). This holds the "red dots"
# described above
hashes_to_check = [set() for level in range(num_levels+1)]
# first we provisionally add all hashes to the tree, comparing
# any duplicates
for i in new_hashes:
for i,h in new_hashes.iteritems():
level = self.depth_of(i)
hashes_to_check[level].add(i)
if self[i]:
if self[i] != new_hashes[i]:
msg = "new hash %s does not match existing hash %s at " % (base32.b2a(new_hashes[i]), base32.b2a(self[i]))
msg += self._name_hash(i)
raise BadHashError(msg)
if self[i] != h:
raise BadHashError("new hash %s does not match "
"existing hash %s at %s"
% (base32.b2a(h),
base32.b2a(self[i]),
self._name_hash(i)))
else:
self[i] = new_hashes[i]
added.add(i)
self[i] = h
remove_upon_failure.add(i)
# then we start from the bottom and compute new parent hashes
# upwards, comparing any that already exist. When this phase
# ends, all nodes that have a sibling will also have a parent.
for level in reversed(range(len(hashes_to_check))):
this_level = hashes_to_check[level]
while this_level:
i = this_level.pop()
if i == 0:
# The root has no sibling. How lonely. TODO: consider
# setting the root in our constructor, then throw
# NotEnoughHashesError here, because if we've
# generated the root from below, we don't have
# anything to validate it against.
continue
siblingnum = self.sibling(i)
if self[siblingnum] is None:
# without a sibling, we can't compute a parent, and
# we can't verify this node
raise NotEnoughHashesError("unable to validate [%d]"%i)
parentnum = self.parent(i)
# make sure we know right from left
leftnum, rightnum = sorted([i, siblingnum])
new_parent_hash = pair_hash(self[leftnum], self[rightnum])
if self[parentnum]:
if self[parentnum] != new_parent_hash:
raise BadHashError("h([%d]+[%d]) != h[%d]" %
(leftnum, rightnum, parentnum))
else:
self[parentnum] = new_parent_hash
remove_upon_failure.add(parentnum)
parent_level = self.depth_of(parentnum)
assert parent_level == level-1
hashes_to_check[parent_level].add(parentnum)
hashes_to_check = list(new_hashes.keys())
# leaf-most first means reverse sorted order
while hashes_to_check:
hashes_to_check.sort()
i = hashes_to_check.pop(-1)
if i == 0:
# The root has no sibling. How lonely.
continue
if self[self.sibling(i)] is None:
# without a sibling, we can't compute a parent
continue
parentnum = self.parent(i)
# make sure we know right from left
leftnum, rightnum = sorted([i, self.sibling(i)])
new_parent_hash = pair_hash(self[leftnum], self[rightnum])
if self[parentnum]:
if self[parentnum] != new_parent_hash:
raise BadHashError("h([%d]+[%d]) != h[%d]" %
(leftnum, rightnum, parentnum))
else:
self[parentnum] = new_parent_hash
added.add(parentnum)
hashes_to_check.insert(0, parentnum)
# then we walk downwards from the top (root), and anything that
# is reachable is validated. If any of the hashes that we've
# added are unreachable, then they are unvalidated.
reachable = set()
if self[0]:
reachable.add(0)
# TODO: this could be done more efficiently, by starting from
# each element of new_hashes and walking upwards instead,
# remembering a set of validated nodes so that the searches for
# later new_hashes goes faster. This approach is O(n), whereas
# O(ln(n)) should be feasible.
for i in range(1, len(self)):
if self[i] and self.parent(i) in reachable:
reachable.add(i)
# were we unable to validate any of the new hashes?
unvalidated = set(new_hashes.keys()) - reachable
if unvalidated:
those = ",".join([str(i) for i in sorted(unvalidated)])
raise NotEnoughHashesError("unable to validate hashes %s"
% those)
# our sibling is now as valid as this node
this_level.discard(siblingnum)
# we're done!
except (BadHashError, NotEnoughHashesError):
for i in added:
for i in remove_upon_failure:
self[i] = None
raise

View File

@ -80,6 +80,46 @@ class Incomplete(unittest.TestCase):
self.failUnlessEqual(ht.needed_hashes(5, False), set([11, 6, 1]))
self.failUnlessEqual(ht.needed_hashes(5, True), set([12, 11, 6, 1]))
def test_depth_of(self):
ht = hashtree.IncompleteHashTree(8)
self.failUnlessEqual(ht.depth_of(0), 0)
for i in [1,2]:
self.failUnlessEqual(ht.depth_of(i), 1, "i=%d"%i)
for i in [3,4,5,6]:
self.failUnlessEqual(ht.depth_of(i), 2, "i=%d"%i)
for i in [7,8,9,10,11,12,13,14]:
self.failUnlessEqual(ht.depth_of(i), 3, "i=%d"%i)
self.failUnlessRaises(IndexError, ht.depth_of, 15)
def test_large(self):
# IncompleteHashTree.set_hashes() used to take O(N**2). This test is
# meant to show that it now takes O(N) or maybe O(N*ln(N)). I wish
# there were a good way to assert this (like counting VM operations
# or something): the problem was inside list.sort(), so there's no
# good way to instrument set_hashes() to count what we care about. On
# my laptop, 10k leaves takes 1.1s in this fixed version, and 11.6s
# in the old broken version. An 80k-leaf test (corresponding to a
# 10GB file with a 128KiB segsize) 10s in the fixed version, and
# several hours in the broken version, but 10s on my laptop (plus the
# 20s of setup code) probably means 200s on our dapper buildslave,
# which is painfully long for a unit test.
self.do_test_speed(10000)
def do_test_speed(self, SIZE):
# on my laptop, SIZE=80k (corresponding to a 10GB file with a 128KiB
# segsize) takes:
# 7s to build the (complete) HashTree
# 13s to set up the dictionary
# 10s to run set_hashes()
ht = make_tree(SIZE)
iht = hashtree.IncompleteHashTree(SIZE)
needed = set()
for i in range(SIZE):
needed.update(ht.needed_hashes(i, True))
all = dict([ (i, ht[i]) for i in needed])
iht.set_hashes(hashes=all)
def test_check(self):
# first create a complete hash tree
ht = make_tree(6)