tahoe-lafs/src/allmydata/download.py

719 lines
29 KiB
Python

import os, random
from zope.interface import implements
from twisted.python import log
from twisted.internet import defer
from twisted.application import service
from allmydata.util import idlib, mathutil, hashutil
from allmydata.util.assertutil import _assert
from allmydata import codec, hashtree, storage, uri
from allmydata.Crypto.Cipher import AES
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI
from allmydata.encode import NotEnoughPeersError
class HaveAllPeersError(Exception):
# we use this to jump out of the loop
pass
class BadURIExtensionHashValue(Exception):
pass
class BadPlaintextHashValue(Exception):
pass
class BadCrypttextHashValue(Exception):
pass
class Output:
def __init__(self, downloadable, key, total_length):
self.downloadable = downloadable
self._decryptor = AES.new(key=key, mode=AES.MODE_CTR,
counterstart="\x00"*16)
self._crypttext_hasher = hashutil.crypttext_hasher()
self._plaintext_hasher = hashutil.plaintext_hasher()
self.length = 0
self.total_length = total_length
self._segment_number = 0
self._plaintext_hash_tree = None
self._crypttext_hash_tree = None
self._opened = False
def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
self._plaintext_hash_tree = plaintext_hashtree
self._crypttext_hash_tree = crypttext_hashtree
def write_segment(self, crypttext):
self.length += len(crypttext)
# memory footprint: 'crypttext' is the only segment_size usage
# outstanding. While we decrypt it into 'plaintext', we hit
# 2*segment_size.
self._crypttext_hasher.update(crypttext)
if self._crypttext_hash_tree:
ch = hashutil.crypttext_segment_hasher()
ch.update(crypttext)
crypttext_leaves = {self._segment_number: ch.digest()}
self._crypttext_hash_tree.set_hashes(leaves=crypttext_leaves)
plaintext = self._decryptor.decrypt(crypttext)
del crypttext
# now we're back down to 1*segment_size.
self._plaintext_hasher.update(plaintext)
if self._plaintext_hash_tree:
ph = hashutil.plaintext_segment_hasher()
ph.update(plaintext)
plaintext_leaves = {self._segment_number: ph.digest()}
self._plaintext_hash_tree.set_hashes(leaves=plaintext_leaves)
self._segment_number += 1
# We're still at 1*segment_size. The Downloadable is responsible for
# any memory usage beyond this.
if not self._opened:
self._opened = True
self.downloadable.open(self.total_length)
self.downloadable.write(plaintext)
def fail(self, why):
log.msg("UNUSUAL: download failed: %s" % why)
self.downloadable.fail(why)
def close(self):
self.crypttext_hash = self._crypttext_hasher.digest()
self.plaintext_hash = self._plaintext_hasher.digest()
self.downloadable.close()
def finish(self):
return self.downloadable.finish()
class ValidatedBucket:
"""I am a front-end for a remote storage bucket, responsible for
retrieving and validating data from that bucket.
My get_block() method is used by BlockDownloaders.
"""
def __init__(self, sharenum, bucket,
share_hash_tree, roothash,
num_blocks):
self.sharenum = sharenum
self.bucket = bucket
self._share_hash = None # None means not validated yet
self.share_hash_tree = share_hash_tree
self._roothash = roothash
self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
self.started = False
def get_block(self, blocknum):
if not self.started:
d = self.bucket.start()
def _started(res):
self.started = True
return self.get_block(blocknum)
d.addCallback(_started)
return d
# the first time we use this bucket, we need to fetch enough elements
# of the share hash tree to validate it from our share hash up to the
# hashroot.
if not self._share_hash:
d1 = self.bucket.get_share_hashes()
else:
d1 = defer.succeed([])
# we might need to grab some elements of our block hash tree, to
# validate the requested block up to the share hash
needed = self.block_hash_tree.needed_hashes(blocknum)
if needed:
# TODO: get fewer hashes, use get_block_hashes(needed)
d2 = self.bucket.get_block_hashes()
else:
d2 = defer.succeed([])
d3 = self.bucket.get_block(blocknum)
d = defer.gatherResults([d1, d2, d3])
d.addCallback(self._got_data, blocknum)
return d
def _got_data(self, res, blocknum):
sharehashes, blockhashes, blockdata = res
try:
if not self._share_hash:
sh = dict(sharehashes)
sh[0] = self._roothash # always use our own root, from the URI
sht = self.share_hash_tree
if sht.get_leaf_index(self.sharenum) not in sh:
raise hashtree.NotEnoughHashesError
sht.set_hashes(sh)
self._share_hash = sht.get_leaf(self.sharenum)
blockhash = hashutil.block_hash(blockdata)
#log.msg("checking block_hash(shareid=%d, blocknum=%d) len=%d "
# "%r .. %r: %s" %
# (self.sharenum, blocknum, len(blockdata),
# blockdata[:50], blockdata[-50:], idlib.b2a(blockhash)))
# we always validate the blockhash
bh = dict(enumerate(blockhashes))
# replace blockhash root with validated value
bh[0] = self._share_hash
self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})
except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
# log.WEIRD: indicates undetected disk/network error, or more
# likely a programming error
log.msg("hash failure in block=%d, shnum=%d on %s" %
(blocknum, self.sharenum, self.bucket))
if self._share_hash:
log.msg(""" failure occurred when checking the block_hash_tree.
This suggests that either the block data was bad, or that the
block hashes we received along with it were bad.""")
else:
log.msg(""" the failure probably occurred when checking the
share_hash_tree, which suggests that the share hashes we
received from the remote peer were bad.""")
log.msg(" have self._share_hash: %s" % bool(self._share_hash))
log.msg(" block length: %d" % len(blockdata))
log.msg(" block hash: %s" % idlib.b2a_or_none(blockhash)) # not safe
if len(blockdata) < 100:
log.msg(" block data: %r" % (blockdata,))
else:
log.msg(" block data start/end: %r .. %r" %
(blockdata[:50], blockdata[-50:]))
log.msg(" root hash: %s" % idlib.b2a(self._roothash))
log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
lines = []
for i,h in sorted(sharehashes):
lines.append("%3d: %s" % (i, idlib.b2a_or_none(h)))
log.msg(" sharehashes:\n" + "\n".join(lines) + "\n")
lines = []
for i,h in enumerate(blockhashes):
lines.append("%3d: %s" % (i, idlib.b2a_or_none(h)))
log.msg(" blockhashes:\n" + "\n".join(lines) + "\n")
raise
# If we made it here, the block is good. If the hash trees didn't
# like what they saw, they would have raised a BadHashError, causing
# our caller to see a Failure and thus ignore this block (as well as
# dropping this bucket).
return blockdata
class BlockDownloader:
"""I am responsible for downloading a single block (from a single bucket)
for a single segment.
I am a child of the SegmentDownloader.
"""
def __init__(self, vbucket, blocknum, parent):
self.vbucket = vbucket
self.blocknum = blocknum
self.parent = parent
def start(self, segnum):
d = self.vbucket.get_block(segnum)
d.addCallbacks(self._hold_block, self._got_block_error)
return d
def _hold_block(self, data):
self.parent.hold_block(self.blocknum, data)
def _got_block_error(self, f):
log.msg("BlockDownloader[%d] got error: %s" % (self.blocknum, f))
self.parent.bucket_failed(self.vbucket)
class SegmentDownloader:
"""I am responsible for downloading all the blocks for a single segment
of data.
I am a child of the FileDownloader.
"""
def __init__(self, parent, segmentnumber, needed_shares):
self.parent = parent
self.segmentnumber = segmentnumber
self.needed_blocks = needed_shares
self.blocks = {} # k: blocknum, v: data
def start(self):
return self._download()
def _download(self):
d = self._try()
def _done(res):
if len(self.blocks) >= self.needed_blocks:
# we only need self.needed_blocks blocks
# we want to get the smallest blockids, because they are
# more likely to be fast "primary blocks"
blockids = sorted(self.blocks.keys())[:self.needed_blocks]
blocks = []
for blocknum in blockids:
blocks.append(self.blocks[blocknum])
return (blocks, blockids)
else:
return self._download()
d.addCallback(_done)
return d
def _try(self):
# fill our set of active buckets, maybe raising NotEnoughPeersError
active_buckets = self.parent._activate_enough_buckets()
# Now we have enough buckets, in self.parent.active_buckets.
# in test cases, bd.start might mutate active_buckets right away, so
# we need to put off calling start() until we've iterated all the way
# through it.
downloaders = []
for blocknum, vbucket in active_buckets.iteritems():
bd = BlockDownloader(vbucket, blocknum, self)
downloaders.append(bd)
l = [bd.start(self.segmentnumber) for bd in downloaders]
return defer.DeferredList(l, fireOnOneErrback=True)
def hold_block(self, blocknum, data):
self.blocks[blocknum] = data
def bucket_failed(self, vbucket):
self.parent.bucket_failed(vbucket)
class FileDownloader:
check_crypttext_hash = True
check_plaintext_hash = True
def __init__(self, client, u, downloadable):
self._client = client
u = IFileURI(u)
self._storage_index = u.storage_index
self._uri_extension_hash = u.uri_extension_hash
self._total_shares = u.total_shares
self._size = u.size
self._num_needed_shares = u.needed_shares
self._output = Output(downloadable, u.key, self._size)
self.active_buckets = {} # k: shnum, v: bucket
self._share_buckets = [] # list of (sharenum, bucket) tuples
self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
self._uri_extension_sources = []
self._uri_extension_data = None
self._fetch_failures = {"uri_extension": 0,
"plaintext_hashroot": 0,
"plaintext_hashtree": 0,
"crypttext_hashroot": 0,
"crypttext_hashtree": 0,
}
def start(self):
log.msg("starting download [%s]" % idlib.b2a(self._storage_index))
# first step: who should we download from?
d = defer.maybeDeferred(self._get_all_shareholders)
d.addCallback(self._got_all_shareholders)
# now get the uri_extension block from somebody and validate it
d.addCallback(self._obtain_uri_extension)
d.addCallback(self._got_uri_extension)
d.addCallback(self._get_hashtrees)
d.addCallback(self._create_validated_buckets)
# once we know that, we can download blocks from everybody
d.addCallback(self._download_all_segments)
def _failed(why):
self._output.fail(why)
return why
d.addErrback(_failed)
d.addCallback(self._done)
return d
def _get_all_shareholders(self):
dl = []
for (permutedpeerid, peerid, connection) in self._client.get_permuted_peers(self._storage_index):
d = connection.callRemote("get_service", "storageserver")
d.addCallback(lambda ss: ss.callRemote("get_buckets",
self._storage_index))
d.addCallbacks(self._got_response, self._got_error,
callbackArgs=(connection,))
dl.append(d)
return defer.DeferredList(dl)
def _got_response(self, buckets, connection):
_assert(isinstance(buckets, dict), buckets) # soon foolscap will check this for us with its DictOf schema constraint
for sharenum, bucket in buckets.iteritems():
b = storage.ReadBucketProxy(bucket)
self.add_share_bucket(sharenum, b)
self._uri_extension_sources.append(b)
def add_share_bucket(self, sharenum, bucket):
# this is split out for the benefit of test_encode.py
self._share_buckets.append( (sharenum, bucket) )
def _got_error(self, f):
self._client.log("Somebody failed. -- %s" % (f,))
def bucket_failed(self, vbucket):
shnum = vbucket.sharenum
del self.active_buckets[shnum]
s = self._share_vbuckets[shnum]
# s is a set of ValidatedBucket instances
s.remove(vbucket)
# ... which might now be empty
if not s:
# there are no more buckets which can provide this share, so
# remove the key. This may prompt us to use a different share.
del self._share_vbuckets[shnum]
def _got_all_shareholders(self, res):
if len(self._share_buckets) < self._num_needed_shares:
raise NotEnoughPeersError
#for s in self._share_vbuckets.values():
# for vb in s:
# assert isinstance(vb, ValidatedBucket), \
# "vb is %s but should be a ValidatedBucket" % (vb,)
def _unpack_uri_extension_data(self, data):
return uri.unpack_extension(data)
def _obtain_uri_extension(self, ignored):
# all shareholders are supposed to have a copy of uri_extension, and
# all are supposed to be identical. We compute the hash of the data
# that comes back, and compare it against the version in our URI. If
# they don't match, ignore their data and try someone else.
def _validate(proposal, bucket):
h = hashutil.uri_extension_hash(proposal)
if h != self._uri_extension_hash:
self._fetch_failures["uri_extension"] += 1
msg = ("The copy of uri_extension we received from "
"%s was bad" % bucket)
raise BadURIExtensionHashValue(msg)
return self._unpack_uri_extension_data(proposal)
return self._obtain_validated_thing(None,
self._uri_extension_sources,
"uri_extension",
"get_uri_extension", (), _validate)
def _obtain_validated_thing(self, ignored, sources, name, methname, args,
validatorfunc):
if not sources:
raise NotEnoughPeersError("started with zero peers while fetching "
"%s" % name)
bucket = sources[0]
sources = sources[1:]
#d = bucket.callRemote(methname, *args)
d = bucket.startIfNecessary()
d.addCallback(lambda res: getattr(bucket, methname)(*args))
d.addCallback(validatorfunc, bucket)
def _bad(f):
log.msg("%s from vbucket %s failed: %s" % (name, bucket, f)) # WEIRD
if not sources:
raise NotEnoughPeersError("ran out of peers, last error was %s"
% (f,))
# try again with a different one
return self._obtain_validated_thing(None, sources, name,
methname, args, validatorfunc)
d.addErrback(_bad)
return d
def _got_uri_extension(self, uri_extension_data):
d = self._uri_extension_data = uri_extension_data
self._codec = codec.get_decoder_by_name(d['codec_name'])
self._codec.set_serialized_params(d['codec_params'])
self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
self._tail_codec.set_serialized_params(d['tail_codec_params'])
crypttext_hash = d['crypttext_hash']
assert isinstance(crypttext_hash, str)
assert len(crypttext_hash) == 32
self._crypttext_hash = crypttext_hash
self._plaintext_hash = d['plaintext_hash']
self._roothash = d['share_root_hash']
self._segment_size = segment_size = d['segment_size']
self._total_segments = mathutil.div_ceil(self._size, segment_size)
self._current_segnum = 0
self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
self._share_hashtree.set_hashes({0: self._roothash})
def _get_hashtrees(self, res):
d = self._get_plaintext_hashtrees()
d.addCallback(self._get_crypttext_hashtrees)
d.addCallback(self._setup_hashtrees)
return d
def _get_plaintext_hashtrees(self):
def _validate_plaintext_hashtree(proposal, bucket):
if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
self._fetch_failures["plaintext_hashroot"] += 1
msg = ("The copy of the plaintext_root_hash we received from"
" %s was bad" % bucket)
raise BadPlaintextHashValue(msg)
pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
pt_hashes = dict(list(enumerate(proposal)))
try:
pt_hashtree.set_hashes(pt_hashes)
except hashtree.BadHashError:
# the hashes they gave us were not self-consistent, even
# though the root matched what we saw in the uri_extension
# block
self._fetch_failures["plaintext_hashtree"] += 1
raise
self._plaintext_hashtree = pt_hashtree
d = self._obtain_validated_thing(None,
self._uri_extension_sources,
"plaintext_hashes",
"get_plaintext_hashes", (),
_validate_plaintext_hashtree)
return d
def _get_crypttext_hashtrees(self, res):
def _validate_crypttext_hashtree(proposal, bucket):
if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
self._fetch_failures["crypttext_hashroot"] += 1
msg = ("The copy of the crypttext_root_hash we received from"
" %s was bad" % bucket)
raise BadCrypttextHashValue(msg)
ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
ct_hashes = dict(list(enumerate(proposal)))
try:
ct_hashtree.set_hashes(ct_hashes)
except hashtree.BadHashError:
self._fetch_failures["crypttext_hashtree"] += 1
raise
ct_hashtree.set_hashes(ct_hashes)
self._crypttext_hashtree = ct_hashtree
d = self._obtain_validated_thing(None,
self._uri_extension_sources,
"crypttext_hashes",
"get_crypttext_hashes", (),
_validate_crypttext_hashtree)
return d
def _setup_hashtrees(self, res):
self._output.setup_hashtrees(self._plaintext_hashtree,
self._crypttext_hashtree)
def _create_validated_buckets(self, ignored=None):
self._share_vbuckets = {}
for sharenum, bucket in self._share_buckets:
vbucket = ValidatedBucket(sharenum, bucket,
self._share_hashtree,
self._roothash,
self._total_segments)
s = self._share_vbuckets.setdefault(sharenum, set())
s.add(vbucket)
def _activate_enough_buckets(self):
"""either return a mapping from shnum to a ValidatedBucket that can
provide data for that share, or raise NotEnoughPeersError"""
while len(self.active_buckets) < self._num_needed_shares:
# need some more
handled_shnums = set(self.active_buckets.keys())
available_shnums = set(self._share_vbuckets.keys())
potential_shnums = list(available_shnums - handled_shnums)
if not potential_shnums:
raise NotEnoughPeersError
# choose a random share
shnum = random.choice(potential_shnums)
# and a random bucket that will provide it
validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
self.active_buckets[shnum] = validated_bucket
return self.active_buckets
def _download_all_segments(self, res):
# the promise: upon entry to this function, self._share_vbuckets
# contains enough buckets to complete the download, and some extra
# ones to tolerate some buckets dropping out or having errors.
# self._share_vbuckets is a dictionary that maps from shnum to a set
# of ValidatedBuckets, which themselves are wrappers around
# RIBucketReader references.
self.active_buckets = {} # k: shnum, v: ValidatedBucket instance
d = defer.succeed(None)
for segnum in range(self._total_segments-1):
d.addCallback(self._download_segment, segnum)
d.addCallback(self._download_tail_segment, self._total_segments-1)
return d
def _download_segment(self, res, segnum):
# memory footprint: when the SegmentDownloader finishes pulling down
# all shares, we have 1*segment_size of usage.
segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
d = segmentdler.start()
# while the codec does its job, we hit 2*segment_size
d.addCallback(lambda (shares, shareids):
self._codec.decode(shares, shareids))
# once the codec is done, we drop back to 1*segment_size, because
# 'shares' goes out of scope. The memory usage is all in the
# plaintext now, spread out into a bunch of tiny buffers.
def _done(buffers):
# we start by joining all these buffers together into a single
# string. This makes Output.write easier, since it wants to hash
# data one segment at a time anyways, and doesn't impact our
# memory footprint since we're already peaking at 2*segment_size
# inside the codec a moment ago.
segment = "".join(buffers)
del buffers
# we're down to 1*segment_size right now, but write_segment()
# will decrypt a copy of the segment internally, which will push
# us up to 2*segment_size while it runs.
self._output.write_segment(segment)
d.addCallback(_done)
return d
def _download_tail_segment(self, res, segnum):
segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
d = segmentdler.start()
d.addCallback(lambda (shares, shareids):
self._tail_codec.decode(shares, shareids))
def _done(buffers):
# trim off any padding added by the upload side
segment = "".join(buffers)
del buffers
# we never send empty segments. If the data was an exact multiple
# of the segment size, the last segment will be full.
pad_size = mathutil.pad_size(self._size, self._segment_size)
tail_size = self._segment_size - pad_size
segment = segment[:tail_size]
self._output.write_segment(segment)
d.addCallback(_done)
return d
def _done(self, res):
self._output.close()
log.msg("computed CRYPTTEXT_HASH: %s" %
idlib.b2a(self._output.crypttext_hash))
log.msg("computed PLAINTEXT_HASH: %s" %
idlib.b2a(self._output.plaintext_hash))
if self.check_crypttext_hash:
_assert(self._crypttext_hash == self._output.crypttext_hash,
"bad crypttext_hash: computed=%s, expected=%s" %
(idlib.b2a(self._output.crypttext_hash),
idlib.b2a(self._crypttext_hash)))
if self.check_plaintext_hash:
_assert(self._plaintext_hash == self._output.plaintext_hash,
"bad plaintext_hash: computed=%s, expected=%s" %
(idlib.b2a(self._output.plaintext_hash),
idlib.b2a(self._plaintext_hash)))
_assert(self._output.length == self._size,
got=self._output.length, expected=self._size)
return self._output.finish()
class LiteralDownloader:
def __init__(self, client, u, downloadable):
self._uri = IFileURI(u)
assert isinstance(self._uri, uri.LiteralFileURI)
self._downloadable = downloadable
def start(self):
data = self._uri.data
self._downloadable.open(len(data))
self._downloadable.write(data)
self._downloadable.close()
return defer.maybeDeferred(self._downloadable.finish)
class FileName:
implements(IDownloadTarget)
def __init__(self, filename):
self._filename = filename
self.f = None
def open(self, size):
self.f = open(self._filename, "wb")
return self.f
def write(self, data):
self.f.write(data)
def close(self):
if self.f:
self.f.close()
def fail(self, why):
if self.f:
self.f.close()
os.unlink(self._filename)
def register_canceller(self, cb):
pass # we won't use it
def finish(self):
pass
class Data:
implements(IDownloadTarget)
def __init__(self):
self._data = []
def open(self, size):
pass
def write(self, data):
self._data.append(data)
def close(self):
self.data = "".join(self._data)
del self._data
def fail(self, why):
del self._data
def register_canceller(self, cb):
pass # we won't use it
def finish(self):
return self.data
class FileHandle:
"""Use me to download data to a pre-defined filehandle-like object. I
will use the target's write() method. I will *not* close the filehandle:
I leave that up to the originator of the filehandle. The download process
will return the filehandle when it completes.
"""
implements(IDownloadTarget)
def __init__(self, filehandle):
self._filehandle = filehandle
def open(self, size):
pass
def write(self, data):
self._filehandle.write(data)
def close(self):
# the originator of the filehandle reserves the right to close it
pass
def fail(self, why):
pass
def register_canceller(self, cb):
pass
def finish(self):
return self._filehandle
class Downloader(service.MultiService):
"""I am a service that allows file downloading.
"""
implements(IDownloader)
name = "downloader"
def download(self, u, t):
assert self.parent
assert self.running
u = IFileURI(u)
t = IDownloadTarget(t)
assert t.write
assert t.close
if isinstance(u, uri.LiteralFileURI):
dl = LiteralDownloader(self.parent, u, t)
elif isinstance(u, uri.CHKFileURI):
dl = FileDownloader(self.parent, u, t)
else:
raise RuntimeError("I don't know how to download a %s" % u)
d = dl.start()
return d
# utility functions
def download_to_data(self, uri):
return self.download(uri, Data())
def download_to_filename(self, uri, filename):
return self.download(uri, FileName(filename))
def download_to_filehandle(self, uri, filehandle):
return self.download(uri, FileHandle(filehandle))