mirror of
https://github.com/tahoe-lafs/tahoe-lafs.git
synced 2025-04-14 06:06:40 +00:00
hashtree: fix O(N**2) behavior, to improve fatal alacrity problems in a 10GB file (#670). Also improve docstring.
This commit is contained in:
parent
438bc67548
commit
466014f66f
@ -162,6 +162,15 @@ class CompleteBinaryTreeMixin:
|
||||
def get_leaf(self, leafnum):
|
||||
return self[self.first_leaf_num + leafnum]
|
||||
|
||||
def depth_of(self, i):
    """Return the depth (level) of node number i.

    Level 0 contains node 0 (the root). Level 1 contains nodes 1 and 2.
    Level 2 contains nodes 3,4,5,6.

    The depth is computed by following self.parent() upwards until node 0
    is reached, counting one level per step. Any exception raised by
    self.parent() for a node number it rejects propagates to the caller.
    """
    depth = 0
    while i != 0:
        # each hop to the parent moves exactly one level closer to the root
        depth += 1
        i = self.parent(i)
    return depth
|
||||
|
||||
def empty_leaf_hash(i):
    """Compute the tagged hash used for empty leaf number i.

    The integer i is rendered as a decimal string and hashed with
    tagged_hash() under the 'Merkle tree empty leaf' tag.
    """
    leaf_index_text = "%d" % i
    return tagged_hash('Merkle tree empty leaf', leaf_index_text)
|
||||
def pair_hash(a, b):
|
||||
@ -337,27 +346,30 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
|
||||
from 0 (the root of the tree) to 2*num_leaves-2 (the right-most
|
||||
leaf). leaf[i] is the same as hash[num_leaves-1+i].
|
||||
|
||||
The best way to use me is to obtain the root hash from some 'good'
|
||||
channel, and use the 'bad' channel to obtain data block 0 and the
|
||||
The best way to use me is to start by obtaining the root hash from
|
||||
some 'good' channel and populate me with it:
|
||||
|
||||
iht = IncompleteHashTree(numleaves)
|
||||
roothash = trusted_channel.get_roothash()
|
||||
iht.set_hashes(hashes={0: roothash})
|
||||
|
||||
Then use the 'bad' channel to obtain data block 0 and the
|
||||
corresponding hash chain (a dict with the same hashes that
|
||||
needed_hashes(0) tells you, e.g. {0:h0, 2:h2, 4:h4, 8:h8} when
|
||||
len(L)=8). Hash the data block to create leaf0, then feed everything
|
||||
into set_hashes() and see if it raises an exception or not::
|
||||
|
||||
iht = IncompleteHashTree(numleaves)
|
||||
roothash = trusted_channel.get_roothash()
|
||||
otherhashes = untrusted_channel.get_hashes()
|
||||
# otherhashes.keys() should == iht.needed_hashes(leaves=[0])
|
||||
datablock0 = untrusted_channel.get_data(0)
|
||||
leaf0 = HASH(datablock0)
|
||||
# HASH() is probably hashutil.tagged_hash(tag, datablock0)
|
||||
hashes = otherhashes.copy()
|
||||
hashes[0] = roothash # from 'good' channel
|
||||
iht.set_hashes(hashes, leaves={0: leaf0})
|
||||
otherhashes = untrusted_channel.get_hashes()
|
||||
# otherhashes.keys() should == iht.needed_hashes(leaves=[0])
|
||||
datablock0 = untrusted_channel.get_data(0)
|
||||
leaf0 = HASH(datablock0)
|
||||
# HASH() is probably hashutil.tagged_hash(tag, datablock0)
|
||||
iht.set_hashes(otherhashes, leaves={0: leaf0})
|
||||
|
||||
If the set_hashes() call doesn't raise an exception, the data block
|
||||
was valid. If it raises BadHashError, then either the data block was
|
||||
corrupted or one of the received hashes was corrupted.
|
||||
corrupted or one of the received hashes was corrupted. If it raises
|
||||
NotEnoughHashesError, then the otherhashes dictionary was incomplete.
|
||||
"""
|
||||
|
||||
assert isinstance(hashes, dict)
|
||||
@ -376,73 +388,87 @@ class IncompleteHashTree(CompleteBinaryTreeMixin, list):
|
||||
% (leafnum, hashnum))
|
||||
new_hashes[hashnum] = leafhash
|
||||
|
||||
added = set() # we'll remove these if the check fails
|
||||
remove_upon_failure = set() # we'll remove these if the check fails
|
||||
|
||||
# visualize this method in the following way:
|
||||
# A: start with the empty or partially-populated tree as shown in
|
||||
# the HashTree docstring
|
||||
# B: add all of our input hashes to the tree, filling in some of the
|
||||
# holes. Don't overwrite anything, but new values must equal the
|
||||
# existing ones. Mark everything that was added with a red dot
|
||||
# (meaning "not yet validated")
|
||||
# C: start with the lowest/deepest level. Pick any red-dotted node,
|
||||
# hash it with its sibling to compute the parent hash. Add the
|
||||
# parent to the tree just like in step B (if the parent already
|
||||
# exists, the values must be equal; if not, add our computed
|
||||
# value with a red dot). If we have no sibling, throw
|
||||
# NotEnoughHashesError, since we won't be able to validate this
|
||||
# node. Remove the red dot. If there was a red dot on our
|
||||
# sibling, remove it too.
|
||||
# D: finish all red-dotted nodes in one level before moving up to
|
||||
# the next.
|
||||
# E: if we hit NotEnoughHashesError or BadHashError before getting
|
||||
# to the root, discard every hash we've added.
|
||||
|
||||
try:
|
||||
num_levels = self.depth_of(len(self)-1)
|
||||
# hashes_to_check[level] is set(index). This holds the "red dots"
|
||||
# described above
|
||||
hashes_to_check = [set() for level in range(num_levels+1)]
|
||||
|
||||
# first we provisionally add all hashes to the tree, comparing
|
||||
# any duplicates
|
||||
for i in new_hashes:
|
||||
for i,h in new_hashes.iteritems():
|
||||
level = self.depth_of(i)
|
||||
hashes_to_check[level].add(i)
|
||||
|
||||
if self[i]:
|
||||
if self[i] != new_hashes[i]:
|
||||
msg = "new hash %s does not match existing hash %s at " % (base32.b2a(new_hashes[i]), base32.b2a(self[i]))
|
||||
msg += self._name_hash(i)
|
||||
raise BadHashError(msg)
|
||||
if self[i] != h:
|
||||
raise BadHashError("new hash %s does not match "
|
||||
"existing hash %s at %s"
|
||||
% (base32.b2a(h),
|
||||
base32.b2a(self[i]),
|
||||
self._name_hash(i)))
|
||||
else:
|
||||
self[i] = new_hashes[i]
|
||||
added.add(i)
|
||||
self[i] = h
|
||||
remove_upon_failure.add(i)
|
||||
|
||||
# then we start from the bottom and compute new parent hashes
|
||||
# upwards, comparing any that already exist. When this phase
|
||||
# ends, all nodes that have a sibling will also have a parent.
|
||||
for level in reversed(range(len(hashes_to_check))):
|
||||
this_level = hashes_to_check[level]
|
||||
while this_level:
|
||||
i = this_level.pop()
|
||||
if i == 0:
|
||||
# The root has no sibling. How lonely. TODO: consider
|
||||
# setting the root in our constructor, then throw
|
||||
# NotEnoughHashesError here, because if we've
|
||||
# generated the root from below, we don't have
|
||||
# anything to validate it against.
|
||||
continue
|
||||
siblingnum = self.sibling(i)
|
||||
if self[siblingnum] is None:
|
||||
# without a sibling, we can't compute a parent, and
|
||||
# we can't verify this node
|
||||
raise NotEnoughHashesError("unable to validate [%d]"%i)
|
||||
parentnum = self.parent(i)
|
||||
# make sure we know right from left
|
||||
leftnum, rightnum = sorted([i, siblingnum])
|
||||
new_parent_hash = pair_hash(self[leftnum], self[rightnum])
|
||||
if self[parentnum]:
|
||||
if self[parentnum] != new_parent_hash:
|
||||
raise BadHashError("h([%d]+[%d]) != h[%d]" %
|
||||
(leftnum, rightnum, parentnum))
|
||||
else:
|
||||
self[parentnum] = new_parent_hash
|
||||
remove_upon_failure.add(parentnum)
|
||||
parent_level = self.depth_of(parentnum)
|
||||
assert parent_level == level-1
|
||||
hashes_to_check[parent_level].add(parentnum)
|
||||
|
||||
hashes_to_check = list(new_hashes.keys())
|
||||
# leaf-most first means reverse sorted order
|
||||
while hashes_to_check:
|
||||
hashes_to_check.sort()
|
||||
i = hashes_to_check.pop(-1)
|
||||
if i == 0:
|
||||
# The root has no sibling. How lonely.
|
||||
continue
|
||||
if self[self.sibling(i)] is None:
|
||||
# without a sibling, we can't compute a parent
|
||||
continue
|
||||
parentnum = self.parent(i)
|
||||
# make sure we know right from left
|
||||
leftnum, rightnum = sorted([i, self.sibling(i)])
|
||||
new_parent_hash = pair_hash(self[leftnum], self[rightnum])
|
||||
if self[parentnum]:
|
||||
if self[parentnum] != new_parent_hash:
|
||||
raise BadHashError("h([%d]+[%d]) != h[%d]" %
|
||||
(leftnum, rightnum, parentnum))
|
||||
else:
|
||||
self[parentnum] = new_parent_hash
|
||||
added.add(parentnum)
|
||||
hashes_to_check.insert(0, parentnum)
|
||||
|
||||
# then we walk downwards from the top (root), and anything that
|
||||
# is reachable is validated. If any of the hashes that we've
|
||||
# added are unreachable, then they are unvalidated.
|
||||
|
||||
reachable = set()
|
||||
if self[0]:
|
||||
reachable.add(0)
|
||||
# TODO: this could be done more efficiently, by starting from
|
||||
# each element of new_hashes and walking upwards instead,
|
||||
# remembering a set of validated nodes so that the searches for
|
||||
# later new_hashes goes faster. This approach is O(n), whereas
|
||||
# O(ln(n)) should be feasible.
|
||||
for i in range(1, len(self)):
|
||||
if self[i] and self.parent(i) in reachable:
|
||||
reachable.add(i)
|
||||
|
||||
# were we unable to validate any of the new hashes?
|
||||
unvalidated = set(new_hashes.keys()) - reachable
|
||||
if unvalidated:
|
||||
those = ",".join([str(i) for i in sorted(unvalidated)])
|
||||
raise NotEnoughHashesError("unable to validate hashes %s"
|
||||
% those)
|
||||
# our sibling is now as valid as this node
|
||||
this_level.discard(siblingnum)
|
||||
# we're done!
|
||||
|
||||
except (BadHashError, NotEnoughHashesError):
|
||||
for i in added:
|
||||
for i in remove_upon_failure:
|
||||
self[i] = None
|
||||
raise
|
||||
|
@ -80,6 +80,46 @@ class Incomplete(unittest.TestCase):
|
||||
self.failUnlessEqual(ht.needed_hashes(5, False), set([11, 6, 1]))
|
||||
self.failUnlessEqual(ht.needed_hashes(5, True), set([12, 11, 6, 1]))
|
||||
|
||||
def test_depth_of(self):
    # Check depth_of() for every node of an 8-leaf tree (nodes 0..14),
    # and that asking about a node past the end raises IndexError.
    tree = hashtree.IncompleteHashTree(8)
    self.failUnlessEqual(tree.depth_of(0), 0)
    expected_levels = [
        (1, [1, 2]),
        (2, [3, 4, 5, 6]),
        (3, [7, 8, 9, 10, 11, 12, 13, 14]),
    ]
    for level, nodes in expected_levels:
        for i in nodes:
            self.failUnlessEqual(tree.depth_of(i), level, "i=%d" % i)
    self.failUnlessRaises(IndexError, tree.depth_of, 15)
|
||||
|
||||
def test_large(self):
    # IncompleteHashTree.set_hashes() used to take O(N**2) time. This test
    # exercises the fixed version, which should be O(N) or maybe
    # O(N*ln(N)). There is no good way to assert that directly (such as
    # counting VM operations): the cost was inside list.sort(), so
    # set_hashes() cannot be instrumented to count what we care about.
    #
    # Timings reported by the original author, on a laptop:
    #   10k leaves: 1.1s in the fixed version, 11.6s in the old broken one
    #   80k leaves (a 10GB file with a 128KiB segsize): 10s in the fixed
    #     version, several hours in the broken one
    # The 80k case (10s plus 20s of setup code) probably means 200s on the
    # dapper buildslave, which is painfully long for a unit test, so only
    # the 10k case is run here.
    leaf_count = 10000
    self.do_test_speed(leaf_count)
|
||||
|
||||
def do_test_speed(self, SIZE):
    # Build a complete HashTree with SIZE leaves, gather the union of
    # needed_hashes() over every leaf, and feed all of those hashes to a
    # fresh IncompleteHashTree in a single set_hashes() call.
    #
    # On the original author's laptop, SIZE=80k (corresponding to a 10GB
    # file with a 128KiB segsize) takes:
    #  7s to build the (complete) HashTree
    #  13s to set up the dictionary
    #  10s to run set_hashes()
    ht = make_tree(SIZE)
    iht = hashtree.IncompleteHashTree(SIZE)

    needed = set()
    for leafnum in range(SIZE):
        needed.update(ht.needed_hashes(leafnum, True))
    # named all_hashes (not 'all') so the builtin all() is not shadowed
    all_hashes = dict((hashnum, ht[hashnum]) for hashnum in needed)
    iht.set_hashes(hashes=all_hashes)
|
||||
|
||||
def test_check(self):
|
||||
# first create a complete hash tree
|
||||
ht = make_tree(6)
|
||||
|
Loading…
x
Reference in New Issue
Block a user