tahoe-lafs/src/allmydata/download.py

420 lines
15 KiB
Python
Raw Normal View History

import os, random, sha
from zope.interface import implements
from twisted.python import log
from twisted.internet import defer
from twisted.application import service
from allmydata.util import idlib, mathutil, hashutil
from allmydata.util.assertutil import _assert
2007-04-12 20:13:25 +00:00
from allmydata import codec, hashtree
from allmydata.Crypto.Cipher import AES
2007-01-17 04:29:59 +00:00
from allmydata.uri import unpack_uri
from allmydata.interfaces import IDownloadTarget, IDownloader
class NotEnoughPeersError(Exception):
pass
class HaveAllPeersError(Exception):
# we use this to jump out of the loop
pass
class Output:
def __init__(self, downloadable, key):
self.downloadable = downloadable
self._decryptor = AES.new(key=key, mode=AES.MODE_CTR,
counterstart="\x00"*16)
self._verifierid_hasher = sha.new(netstring("allmydata_v1_verifierid"))
self._fileid_hasher = sha.new(netstring("allmydata_v1_fileid"))
self.length = 0
def open(self):
self.downloadable.open()
def write(self, crypttext):
self.length += len(crypttext)
self._verifierid_hasher.update(crypttext)
plaintext = self._decryptor.decrypt(crypttext)
self._fileid_hasher.update(plaintext)
self.downloadable.write(plaintext)
def close(self):
self.verifierid = self._verifierid_hasher.digest()
self.fileid = self._fileid_hasher.digest()
self.downloadable.close()
def finish(self):
return self.downloadable.finish()
class ValidatedBucket:
def __init__(self, sharenum, bucket,
share_hash_tree, roothash,
num_blocks):
self.sharenum = sharenum
self.bucket = bucket
self._share_hash = None # None means not validated yet
self.share_hash_tree = share_hash_tree
self._roothash = roothash
2007-04-12 20:13:25 +00:00
self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
def get_block(self, blocknum):
# the first time we use this bucket, we need to fetch enough elements
# of the share hash tree to validate it from our share hash up to the
# hashroot.
if not self._share_hash:
d1 = self.bucket.callRemote('get_share_hashes')
else:
d1 = defer.succeed(None)
# we might need to grab some elements of our block hash tree, to
# validate the requested block up to the share hash
needed = self.block_hash_tree.needed_hashes(blocknum)
if needed:
# TODO: get fewer hashes, callRemote('get_block_hashes', needed)
d2 = self.bucket.callRemote('get_block_hashes')
else:
d2 = defer.succeed([])
d3 = self.bucket.callRemote('get_block', blocknum)
d = defer.gatherResults([d1, d2, d3])
d.addCallback(self._got_data, blocknum)
return d
def _got_data(self, res, blocknum):
sharehashes, blockhashes, blockdata = res
if not self._share_hash:
sh = dict(sharehashes)
sh[0] = self._roothash # always use our own root, from the URI
if self.share_hash_tree.get_leaf_index(self.sharenum) not in sh:
raise hashutil.NotEnoughHashesError
self.share_hash_tree.set_hashes(sh)
self._share_hash = self.share_hash_tree.get_leaf(self.sharenum)
blockhash = hashutil.tagged_hash("encoded subshare", blockdata)
# we always validate the blockhash
bh = dict(enumerate(blockhashes))
bh[0] = self._share_hash # replace blockhash root with validated value
self.block_hash_tree.set_hashes(bh, {blocknum: blockhash})
# If we made it here, the block is good. If the hash trees didn't
# like what they saw, they would have raised a BadHashError, causing
# our caller to see a Failure and thus ignore this block (as well as
# dropping this bucket).
return blockdata
class BlockDownloader:
def __init__(self, vbucket, blocknum, parent):
self.vbucket = vbucket
self.blocknum = blocknum
self.parent = parent
def start(self, segnum):
d = self.vbucket.get_block(segnum)
d.addCallbacks(self._hold_block, self._got_block_error)
return d
def _hold_block(self, data):
self.parent.hold_block(self.blocknum, data)
def _got_block_error(self, f):
log.msg("BlockDownloader[%d] got error: %s" % (self.blocknum, f))
self.parent.bucket_failed(self.blocknum, self.vbucket)
class SegmentDownloader:
def __init__(self, parent, segmentnumber, needed_shares):
self.parent = parent
self.segmentnumber = segmentnumber
self.needed_blocks = needed_shares
self.blocks = {} # k: blocknum, v: data
def start(self):
return self._download()
def _download(self):
d = self._try()
def _done(res):
if len(self.blocks) >= self.needed_blocks:
# we only need self.needed_blocks blocks
# we want to get the smallest blockids, because they are
# more likely to be fast "primary blocks"
blockids = sorted(self.blocks.keys())[:self.needed_blocks]
blocks = []
for blocknum in blockids:
blocks.append(self.blocks[blocknum])
return (blocks, blockids)
else:
return self._download()
d.addCallback(_done)
return d
def _try(self):
# fill our set of active buckets, maybe raising NotEnoughPeersError
active_buckets = self.parent._activate_enough_buckets()
# Now we have enough buckets, in self.parent.active_buckets.
# in test cases, bd.start might mutate active_buckets right away, so
# we need to put off calling start() until we've iterated all the way
# through it.
downloaders = []
for blocknum, vbucket in active_buckets.iteritems():
bd = BlockDownloader(vbucket, blocknum, self)
downloaders.append(bd)
l = [bd.start(self.segmentnumber) for bd in downloaders]
return defer.DeferredList(l, fireOnOneErrback=True)
def hold_block(self, blocknum, data):
self.blocks[blocknum] = data
def bucket_failed(self, shnum, vbucket):
del self.parent.active_buckets[shnum]
s = self.parent._share_buckets[shnum]
# s is a set of ValidatedBucket instances
s.remove(vbucket)
# ... which might now be empty
if not s:
# there are no more buckets which can provide this share, so
# remove the key. This may prompt us to use a different share.
del self.parent._share_buckets[shnum]
class FileDownloader:
debug = False
def __init__(self, client, uri, downloadable):
self._client = client
self._downloadable = downloadable
(codec_name, codec_params, tail_codec_params, verifierid, roothash, needed_shares, total_shares, size, segment_size) = unpack_uri(uri)
assert isinstance(verifierid, str)
assert len(verifierid) == 20
self._verifierid = verifierid
self._roothash = roothash
self._codec = codec.get_decoder_by_name(codec_name)
self._codec.set_serialized_params(codec_params)
self._tail_codec = codec.get_decoder_by_name(codec_name)
self._tail_codec.set_serialized_params(tail_codec_params)
self._total_segments = mathutil.div_ceil(size, segment_size)
self._current_segnum = 0
self._segment_size = segment_size
self._size = size
self._num_needed_shares = self._codec.get_needed_shares()
key = "\x00" * 16
self._output = Output(downloadable, key)
2007-04-12 20:13:25 +00:00
self._share_hashtree = hashtree.IncompleteHashTree(total_shares)
self._share_hashtree.set_hashes({0: roothash})
self.active_buckets = {} # k: shnum, v: bucket
self._share_buckets = {} # k: shnum, v: set of buckets
def start(self):
log.msg("starting download [%s]" % (idlib.b2a(self._verifierid),))
if self.debug:
print "starting download"
# first step: who should we download from?
d = defer.maybeDeferred(self._get_all_shareholders)
d.addCallback(self._got_all_shareholders)
# once we know that, we can download blocks from them
d.addCallback(self._download_all_segments)
d.addCallback(self._done)
return d
def _get_all_shareholders(self):
dl = []
for (permutedpeerid, peerid, connection) in self._client.get_permuted_peers(self._verifierid):
d = connection.callRemote("get_service", "storageserver")
d.addCallback(lambda ss: ss.callRemote("get_buckets",
self._verifierid))
d.addCallbacks(self._got_response, self._got_error,
callbackArgs=(connection,))
dl.append(d)
return defer.DeferredList(dl)
def _got_response(self, buckets, connection):
_assert(isinstance(buckets, dict), buckets) # soon foolscap will check this for us with its DictOf schema constraint
for sharenum, bucket in buckets.iteritems():
self.add_share_bucket(sharenum, bucket)
def add_share_bucket(self, sharenum, bucket):
vbucket = ValidatedBucket(sharenum, bucket,
self._share_hashtree,
self._roothash,
self._total_segments)
self._share_buckets.setdefault(sharenum, set()).add(vbucket)
def _got_error(self, f):
self._client.log("Somebody failed. -- %s" % (f,))
def _got_all_shareholders(self, res):
if len(self._share_buckets) < self._num_needed_shares:
raise NotEnoughPeersError
for s in self._share_buckets.values():
for vb in s:
assert isinstance(vb, ValidatedBucket), \
"vb is %s but should be a ValidatedBucket" % (vb,)
def _activate_enough_buckets(self):
"""either return a mapping from shnum to a ValidatedBucket that can
provide data for that share, or raise NotEnoughPeersError"""
while len(self.active_buckets) < self._num_needed_shares:
# need some more
handled_shnums = set(self.active_buckets.keys())
available_shnums = set(self._share_buckets.keys())
potential_shnums = list(available_shnums - handled_shnums)
if not potential_shnums:
raise NotEnoughPeersError
# choose a random share
shnum = random.choice(potential_shnums)
# and a random bucket that will provide it
validated_bucket = random.choice(list(self._share_buckets[shnum]))
self.active_buckets[shnum] = validated_bucket
return self.active_buckets
def _download_all_segments(self, res):
# the promise: upon entry to this function, self._share_buckets
# contains enough buckets to complete the download, and some extra
# ones to tolerate some buckets dropping out or having errors.
# self._share_buckets is a dictionary that maps from shnum to a set
# of ValidatedBuckets, which themselves are wrappers around
# RIBucketReader references.
self.active_buckets = {} # k: shnum, v: ValidatedBucket instance
self._output.open()
d = defer.succeed(None)
for segnum in range(self._total_segments-1):
d.addCallback(self._download_segment, segnum)
d.addCallback(self._download_tail_segment, self._total_segments-1)
return d
def _download_segment(self, res, segnum):
segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
d = segmentdler.start()
d.addCallback(lambda (shares, shareids):
self._codec.decode(shares, shareids))
def _done(res):
for buf in res:
self._output.write(buf)
d.addCallback(_done)
return d
def _download_tail_segment(self, res, segnum):
segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
d = segmentdler.start()
d.addCallback(lambda (shares, shareids):
self._tail_codec.decode(shares, shareids))
def _done(res):
# trim off any padding added by the upload side
data = ''.join(res)
tail_size = self._size % self._segment_size
self._output.write(data[:tail_size])
d.addCallback(_done)
return d
def _done(self, res):
self._output.close()
#print "VERIFIERID: %s" % idlib.b2a(self._output.verifierid)
#print "FILEID: %s" % idlib.b2a(self._output.fileid)
#assert self._verifierid == self._output.verifierid
#assert self._fileid = self._output.fileid
_assert(self._output.length == self._size,
got=self._output.length, expected=self._size)
return self._output.finish()
def netstring(s):
return "%d:%s," % (len(s), s)
class FileName:
implements(IDownloadTarget)
def __init__(self, filename):
self._filename = filename
def open(self):
self.f = open(self._filename, "wb")
return self.f
def write(self, data):
self.f.write(data)
def close(self):
self.f.close()
def fail(self):
self.f.close()
os.unlink(self._filename)
def register_canceller(self, cb):
pass # we won't use it
def finish(self):
pass
class Data:
implements(IDownloadTarget)
def __init__(self):
self._data = []
def open(self):
pass
def write(self, data):
self._data.append(data)
def close(self):
self.data = "".join(self._data)
del self._data
def fail(self):
del self._data
def register_canceller(self, cb):
pass # we won't use it
def finish(self):
return self.data
class FileHandle:
implements(IDownloadTarget)
def __init__(self, filehandle):
self._filehandle = filehandle
def open(self):
pass
def write(self, data):
self._filehandle.write(data)
def close(self):
# the originator of the filehandle reserves the right to close it
pass
def fail(self):
pass
def register_canceller(self, cb):
pass
def finish(self):
pass
class Downloader(service.MultiService):
"""I am a service that allows file downloading.
"""
implements(IDownloader)
name = "downloader"
debug = False
def download(self, uri, t):
assert self.parent
assert self.running
t = IDownloadTarget(t)
assert t.write
assert t.close
dl = FileDownloader(self.parent, uri, t)
if self.debug:
dl.debug = True
d = dl.start()
return d
# utility functions
def download_to_data(self, uri):
return self.download(uri, Data())
def download_to_filename(self, uri, filename):
return self.download(uri, FileName(filename))
def download_to_filehandle(self, uri, filehandle):
return self.download(uri, FileHandle(filehandle))