# Mirror of https://github.com/tahoe-lafs/tahoe-lafs.git
# (synced 2025-01-13 08:19:45 +00:00; 820 lines, 32 KiB, Python)
|
|
import os, random
|
|
from zope.interface import implements
|
|
from twisted.internet import defer
|
|
from twisted.internet.interfaces import IPushProducer, IConsumer
|
|
from twisted.application import service
|
|
from foolscap.eventual import eventually
|
|
|
|
from allmydata.util import idlib, mathutil, hashutil, log
|
|
from allmydata.util.assertutil import _assert
|
|
from allmydata import codec, hashtree, storage, uri
|
|
from allmydata.interfaces import IDownloadTarget, IDownloader, IFileURI
|
|
from allmydata.encode import NotEnoughPeersError
|
|
from pycryptopp.cipher.aes import AES
|
|
|
|
class HaveAllPeersError(Exception):
    """Control-flow exception: per the original note, raised to jump out of
    a loop (no raiser is visible in this part of the file)."""
|
|
|
|
class BadURIExtensionHashValue(Exception):
    """The uri_extension block received from a bucket does not hash to the
    value recorded in the file's URI."""
|
|
class BadPlaintextHashValue(Exception):
    """The plaintext hash list received from a bucket does not start with
    the plaintext_root_hash from the uri_extension block."""
|
|
class BadCrypttextHashValue(Exception):
    """The crypttext hash list received from a bucket does not start with
    the crypttext_root_hash from the uri_extension block."""
|
|
|
|
class DownloadStopped(Exception):
    """The download was abandoned because the consumer called
    stopProducing() on the FileDownloader."""
|
|
|
|
class Output:
    """Decrypt downloaded segments, hash them, and pass the plaintext on to
    a download target.

    write_segment() is called once per segment of crypttext, in order. It
    feeds the whole-file crypttext and plaintext hashers, checks the
    per-segment leaf hashes against the hash trees handed to
    setup_hashtrees() (when they are set), and writes the decrypted bytes to
    'downloadable'. close() publishes the whole-file digests as
    self.crypttext_hash and self.plaintext_hash.
    """

    def __init__(self, downloadable, key, total_length, log_parent):
        # where the plaintext goes, and the log entry we hang messages off
        self.downloadable = downloadable
        self.total_length = total_length
        self._log_parent = log_parent

        self._decryptor = AES(key)
        # whole-file hashers; their digests are read out in close()
        self._crypttext_hasher = hashutil.crypttext_hasher()
        self._plaintext_hasher = hashutil.plaintext_hasher()
        # per-segment hash trees, installed later by setup_hashtrees()
        self._plaintext_hash_tree = None
        self._crypttext_hash_tree = None

        self.length = 0           # bytes of crypttext seen so far
        self._segment_number = 0  # index of the next segment to be written
        self._opened = False      # has downloadable.open() been called yet?

    def log(self, *args, **kwargs):
        """Forward to log.msg(), defaulting 'parent' to our log parent and
        'facility' to "download.output"."""
        kwargs.setdefault("parent", self._log_parent)
        kwargs.setdefault("facility", "download.output")
        return log.msg(*args, **kwargs)

    def setup_hashtrees(self, plaintext_hashtree, crypttext_hashtree):
        """Install the hash trees used to check each segment's leaf hash."""
        self._plaintext_hash_tree = plaintext_hashtree
        self._crypttext_hash_tree = crypttext_hashtree

    def _check_leaf(self, tree, hasher, segment, fmt):
        """Hash 'segment' with 'hasher', log the leaf hash using the log
        format 'fmt', and record it in 'tree' as the leaf for the current
        segment number."""
        hasher.update(segment)
        leaf = hasher.digest()
        self.log(format=fmt,
                 bytes=len(segment),
                 segnum=self._segment_number, hash=idlib.b2a(leaf),
                 level=log.NOISY)
        tree.set_hashes(leaves={self._segment_number: leaf})

    def write_segment(self, crypttext):
        """Hash, check, decrypt and deliver one segment of crypttext."""
        self.length += len(crypttext)

        # memory footprint: 'crypttext' is the only segment_size usage
        # outstanding. While we decrypt it into 'plaintext', we hit
        # 2*segment_size.
        self._crypttext_hasher.update(crypttext)
        if self._crypttext_hash_tree:
            self._check_leaf(
                self._crypttext_hash_tree,
                hashutil.crypttext_segment_hasher(),
                crypttext,
                "crypttext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s")

        plaintext = self._decryptor.process(crypttext)
        del crypttext

        # now we're back down to 1*segment_size.

        self._plaintext_hasher.update(plaintext)
        if self._plaintext_hash_tree:
            self._check_leaf(
                self._plaintext_hash_tree,
                hashutil.plaintext_segment_hasher(),
                plaintext,
                "plaintext leaf hash (%(bytes)sB) [%(segnum)d] is %(hash)s")

        self._segment_number += 1
        # We're still at 1*segment_size. The Downloadable is responsible for
        # any memory usage beyond this.
        if not self._opened:
            # open the target lazily, just before the first write
            self._opened = True
            self.downloadable.open(self.total_length)
        self.downloadable.write(plaintext)

    def fail(self, why):
        """Log the failure 'why' and report it to the download target."""
        # this is really unusual, and deserves maximum forensics
        self.log("download failed!", failure=why, level=log.SCARY)
        self.downloadable.fail(why)

    def close(self):
        """Record the whole-file digests and close the download target."""
        self.crypttext_hash = self._crypttext_hasher.digest()
        self.plaintext_hash = self._plaintext_hasher.digest()
        self.log("download finished, closing IDownloadable", level=log.NOISY)
        self.downloadable.close()

    def finish(self):
        """Return whatever the download target's finish() returns."""
        return self.downloadable.finish()
|
|
|
|
class ValidatedBucket:
    """I am a front-end for a remote storage bucket, responsible for
    retrieving and validating data from that bucket.

    My get_block() method is used by BlockDownloaders.
    """

    def __init__(self, sharenum, bucket,
                 share_hash_tree, roothash,
                 num_blocks):
        self.sharenum = sharenum
        self.bucket = bucket
        self.share_hash_tree = share_hash_tree
        self._roothash = roothash
        self._share_hash = None # None means not validated yet
        self.block_hash_tree = hashtree.IncompleteHashTree(num_blocks)
        self.started = False

    def get_block(self, blocknum):
        """Fetch block 'blocknum' plus whatever hashes are still needed to
        validate it. Returns a Deferred that fires with the block data once
        _got_data() has accepted it."""
        if not self.started:
            # start the bucket first, then come back through this method
            def _retry_once_started(ignored):
                self.started = True
                return self.get_block(blocknum)
            starting = self.bucket.start()
            starting.addCallback(_retry_once_started)
            return starting

        # the first time we use this bucket, we need to fetch enough elements
        # of the share hash tree to validate it from our share hash up to the
        # hashroot.
        if self._share_hash:
            share_hashes_d = defer.succeed([])
        else:
            share_hashes_d = self.bucket.get_share_hashes()

        # we might need to grab some elements of our block hash tree, to
        # validate the requested block up to the share hash
        if self.block_hash_tree.needed_hashes(blocknum):
            # TODO: get fewer hashes, use get_block_hashes(needed)
            block_hashes_d = self.bucket.get_block_hashes()
        else:
            block_hashes_d = defer.succeed([])

        block_d = self.bucket.get_block(blocknum)

        fetched = defer.gatherResults([share_hashes_d, block_hashes_d,
                                       block_d])
        fetched.addCallback(self._got_data, blocknum)
        return fetched

    def _got_data(self, res, blocknum):
        """Validate the fetched (sharehashes, blockhashes, blockdata) triple
        and return blockdata. Hash-tree errors are logged in detail and then
        re-raised."""
        sharehashes, blockhashes, blockdata = res
        blockhash = None # to make logging it safe

        try:
            if not self._share_hash:
                candidates = dict(sharehashes)
                # always use our own root, from the URI
                candidates[0] = self._roothash
                tree = self.share_hash_tree
                if tree.get_leaf_index(self.sharenum) not in candidates:
                    raise hashtree.NotEnoughHashesError
                tree.set_hashes(candidates)
                self._share_hash = tree.get_leaf(self.sharenum)

            blockhash = hashutil.block_hash(blockdata)

            # we always validate the blockhash
            block_hashes = dict(enumerate(blockhashes))
            # replace blockhash root with validated value
            block_hashes[0] = self._share_hash
            self.block_hash_tree.set_hashes(block_hashes,
                                            {blocknum: blockhash})

        except (hashtree.BadHashError, hashtree.NotEnoughHashesError):
            self._log_hash_failure(blocknum, sharehashes, blockhashes,
                                   blockdata, blockhash)
            raise

        # If we made it here, the block is good. If the hash trees didn't
        # like what they saw, they would have raised a BadHashError, causing
        # our caller to see a Failure and thus ignore this block (as well as
        # dropping this bucket).
        return blockdata

    def _log_hash_failure(self, blocknum, sharehashes, blockhashes,
                          blockdata, blockhash):
        """Dump everything we know about a failed hash validation to the
        log. 'blockhash' is None if the failure happened before the block
        was hashed."""
        # log.WEIRD: indicates undetected disk/network error, or more
        # likely a programming error
        log.msg("hash failure in block=%d, shnum=%d on %s" %
                (blocknum, self.sharenum, self.bucket))
        if self._share_hash:
            log.msg(" failure occurred when checking the block_hash_tree.\n"
                    "This suggests that either the block data was bad, or that the\n"
                    "block hashes we received along with it were bad.")
        else:
            log.msg(" the failure probably occurred when checking the\n"
                    "share_hash_tree, which suggests that the share hashes we\n"
                    "received from the remote peer were bad.")
        log.msg(" have self._share_hash: %s" % bool(self._share_hash))
        log.msg(" block length: %d" % len(blockdata))
        log.msg(" block hash: %s" % idlib.b2a_or_none(blockhash))
        if len(blockdata) < 100:
            log.msg(" block data: %r" % (blockdata,))
        else:
            log.msg(" block data start/end: %r .. %r" %
                    (blockdata[:50], blockdata[-50:]))
        log.msg(" root hash: %s" % idlib.b2a(self._roothash))
        log.msg(" share hash tree:\n" + self.share_hash_tree.dump())
        log.msg(" block hash tree:\n" + self.block_hash_tree.dump())
        share_lines = ["%3d: %s" % (i, idlib.b2a_or_none(h))
                       for i, h in sorted(sharehashes)]
        log.msg(" sharehashes:\n" + "\n".join(share_lines) + "\n")
        block_lines = ["%3d: %s" % (i, idlib.b2a_or_none(h))
                       for i, h in enumerate(blockhashes)]
        log.msg(" blockhashes:\n" + "\n".join(block_lines) + "\n")
|
|
|
|
|
|
|
|
class BlockDownloader:
    """I am responsible for downloading a single block (from a single bucket)
    for a single segment.

    I am a child of the SegmentDownloader.
    """

    def __init__(self, vbucket, blocknum, parent):
        self.vbucket = vbucket
        self.blocknum = blocknum
        self.parent = parent
        # our own messages are logged as children of this entry
        self._log_number = parent.log("starting block %d" % blocknum)

    def log(self, msg, parent=None):
        """Log 'msg' through our parent, under 'parent' if given and under
        our own "starting block" entry otherwise."""
        log_parent = self._log_number if parent is None else parent
        return self.parent.log(msg, parent=log_parent)

    def start(self, segnum):
        """Ask our bucket for block 'segnum'. Returns the Deferred from
        vbucket.get_block(), with our success/error handlers attached."""
        lognum = self.log("get_block(segnum=%d)" % segnum)
        fetching = self.vbucket.get_block(segnum)
        fetching.addCallbacks(self._hold_block, self._got_block_error,
                              callbackArgs=(lognum,), errbackArgs=(lognum,))
        return fetching

    def _hold_block(self, data, lognum):
        """Success path: hand the block data to our parent."""
        self.log("got block", parent=lognum)
        self.parent.hold_block(self.blocknum, data)

    def _got_block_error(self, f, lognum):
        """Failure path: log the failure and tell our parent that this
        bucket failed. Nothing is returned, so the failure is consumed
        here."""
        message = "BlockDownloader[%d] got error: %s" % (self.blocknum, f)
        self.log(message, parent=lognum)
        self.parent.bucket_failed(self.vbucket)
|
|
|
|
class SegmentDownloader:
    """I am responsible for downloading all the blocks for a single segment
    of data.

    I am a child of the FileDownloader.
    """

    def __init__(self, parent, segmentnumber, needed_shares):
        self.parent = parent
        self.segmentnumber = segmentnumber
        self.needed_blocks = needed_shares
        self.blocks = {} # k: blocknum, v: data
        # our own messages are logged as children of this entry
        self._log_number = parent.log("starting segment %d" % segmentnumber)

    def log(self, msg, parent=None):
        """Log 'msg' through our parent, under 'parent' if given and under
        our own "starting segment" entry otherwise."""
        log_parent = self._log_number if parent is None else parent
        return self.parent.log(msg, parent=log_parent)

    def start(self):
        """Begin downloading. The returned Deferred fires with a
        (blocks, blockids) tuple of two parallel lists."""
        return self._download()

    def _download(self):
        """Run one round of block fetches, and keep running further rounds
        until at least self.needed_blocks blocks are held."""
        def _collect(ignored):
            if len(self.blocks) < self.needed_blocks:
                # not enough yet: go around again
                return self._download()
            # we only need self.needed_blocks blocks
            # we want to get the smallest blockids, because they are
            # more likely to be fast "primary blocks"
            chosen = sorted(self.blocks.keys())[:self.needed_blocks]
            return ([self.blocks[blocknum] for blocknum in chosen], chosen)
        attempt = self._try()
        attempt.addCallback(_collect)
        return attempt

    def _try(self):
        """Start one BlockDownloader per active bucket and return a
        DeferredList covering all of them."""
        # fill our set of active buckets, maybe raising NotEnoughPeersError
        active_buckets = self.parent._activate_enough_buckets()
        # Now we have enough buckets, in self.parent.active_buckets.

        # in test cases, bd.start might mutate active_buckets right away, so
        # we need to put off calling start() until we've iterated all the way
        # through it.
        downloaders = [BlockDownloader(vbucket, blocknum, self)
                       for blocknum, vbucket in active_buckets.iteritems()]
        started = [bd.start(self.segmentnumber) for bd in downloaders]
        return defer.DeferredList(started, fireOnOneErrback=True)

    def hold_block(self, blocknum, data):
        """Called by a BlockDownloader to store a fetched block."""
        self.blocks[blocknum] = data

    def bucket_failed(self, vbucket):
        """Called by a BlockDownloader on error; passed up to our parent."""
        self.parent.bucket_failed(vbucket)
|
|
|
|
class FileDownloader:
    """I download one file, identified by an IFileURI, into a download
    target.

    start() runs the whole pipeline: ask every permuted storage peer for its
    buckets, fetch and validate the uri_extension block, fetch and validate
    the plaintext/crypttext hash trees, wrap each bucket in a
    ValidatedBucket, then download, decode and write each segment in order.

    I declare IPushProducer: if the target provides IConsumer I register
    myself with it, and pauseProducing() / resumeProducing() /
    stopProducing() are honored at the _check_for_pause() points between
    pipeline steps.
    """
    implements(IPushProducer)
    check_crypttext_hash = True
    check_plaintext_hash = True

    def __init__(self, client, u, downloadable):
        self._client = client

        u = IFileURI(u)
        self._storage_index = u.storage_index
        self._uri_extension_hash = u.uri_extension_hash
        self._total_shares = u.total_shares
        self._size = u.size
        self._num_needed_shares = u.needed_shares

        # sets self._log_number, which Output needs below
        self.init_logging()

        if IConsumer.providedBy(downloadable):
            downloadable.registerProducer(self, True)
        self._downloadable = downloadable
        self._output = Output(downloadable, u.key, self._size, self._log_number)
        # false when running; a Deferred (fired on resume) while paused
        self._paused = False
        self._stopped = False

        self.active_buckets = {} # k: shnum, v: bucket
        self._share_buckets = [] # list of (sharenum, bucket) tuples
        self._share_vbuckets = {} # k: shnum, v: set of ValidatedBuckets
        self._uri_extension_sources = []

        self._uri_extension_data = None

        # counts of bad responses, by kind of thing that was being fetched
        self._fetch_failures = {"uri_extension": 0,
                                "plaintext_hashroot": 0,
                                "plaintext_hashtree": 0,
                                "crypttext_hashroot": 0,
                                "crypttext_hashtree": 0,
                                }

    def init_logging(self):
        """Emit the "starting" log entry and remember its number as the
        parent for all of our later log messages."""
        self._log_prefix = idlib.b2a(self._storage_index)[:6]
        num = self._client.log(format="FileDownloader(%(si)s): starting",
                               si=idlib.b2a(self._storage_index))
        self._log_number = num

    def log(self, *args, **kwargs):
        """Forward to log.msg(), defaulting 'parent' to our log number and
        'facility' to "tahoe.download"."""
        kwargs.setdefault("parent", self._log_number)
        kwargs.setdefault("facility", "tahoe.download")
        return log.msg(*args, **kwargs)

    def pauseProducing(self):
        """IPushProducer: park the download at the next pause check."""
        if self._paused:
            return
        self._paused = defer.Deferred()

    def resumeProducing(self):
        """IPushProducer: release a paused download (on a later reactor
        turn, via eventually())."""
        if self._paused:
            p = self._paused
            self._paused = None
            eventually(p.callback, None)

    def stopProducing(self):
        """IPushProducer: abandon the download. The next pause check raises
        DownloadStopped."""
        self.log("Download.stopProducing")
        self._stopped = True
        # If we are currently paused, the pipeline is parked on the
        # self._paused Deferred and would otherwise never run again. Wake it
        # up so that it can notice self._stopped and fail.
        self.resumeProducing()

    def start(self):
        """Run the download. Returns a Deferred that fires with the result
        of the target's finish(), or with a Failure."""
        self.log("starting download")

        # first step: who should we download from?
        d = defer.maybeDeferred(self._get_all_shareholders)
        d.addCallback(self._got_all_shareholders)
        # now get the uri_extension block from somebody and validate it
        d.addCallback(self._obtain_uri_extension)
        d.addCallback(self._got_uri_extension)
        d.addCallback(self._get_hashtrees)
        d.addCallback(self._create_validated_buckets)
        # once we know that, we can download blocks from everybody
        d.addCallback(self._download_all_segments)
        def _finished(res):
            # runs on success and on failure alike
            if IConsumer.providedBy(self._downloadable):
                self._downloadable.unregisterProducer()
            return res
        d.addBoth(_finished)
        def _failed(why):
            # tell the target, then keep propagating the failure
            self._output.fail(why)
            return why
        d.addErrback(_failed)
        d.addCallback(self._done)
        return d

    def _get_all_shareholders(self):
        """Send get_buckets to every permuted storage peer. Returns a
        DeferredList covering all of the queries."""
        dl = []
        for (peerid,ss) in self._client.get_permuted_peers("storage",
                                                           self._storage_index):
            d = ss.callRemote("get_buckets", self._storage_index)
            d.addCallbacks(self._got_response, self._got_error)
            dl.append(d)
        return defer.DeferredList(dl)

    def _got_response(self, buckets):
        """Record every bucket offered by one peer, both as a share source
        and as a source for the uri_extension block and hash trees."""
        for sharenum, bucket in buckets.iteritems():
            b = storage.ReadBucketProxy(bucket)
            self.add_share_bucket(sharenum, b)
            self._uri_extension_sources.append(b)

    def add_share_bucket(self, sharenum, bucket):
        """Remember 'bucket' as a provider of share 'sharenum'."""
        # this is split out for the benefit of test_encode.py
        self._share_buckets.append( (sharenum, bucket) )

    def _got_error(self, f):
        """One peer's get_buckets query failed: log it and carry on."""
        self._client.log("Somebody failed. -- %s" % (f,))

    def bucket_failed(self, vbucket):
        """Stop using 'vbucket': drop it from the active set and from the
        set of buckets known to provide its share."""
        shnum = vbucket.sharenum
        del self.active_buckets[shnum]
        s = self._share_vbuckets[shnum]
        # s is a set of ValidatedBucket instances
        s.remove(vbucket)
        # ... which might now be empty
        if not s:
            # there are no more buckets which can provide this share, so
            # remove the key. This may prompt us to use a different share.
            del self._share_vbuckets[shnum]

    def _got_all_shareholders(self, res):
        """Raise NotEnoughPeersError if fewer buckets were found than the
        number of shares needed."""
        if len(self._share_buckets) < self._num_needed_shares:
            raise NotEnoughPeersError

    def _unpack_uri_extension_data(self, data):
        """Parse a validated uri_extension block with uri.unpack_extension()."""
        return uri.unpack_extension(data)

    def _obtain_uri_extension(self, ignored):
        """Fetch the uri_extension block from the first source whose copy
        hashes to the value in our URI."""
        # all shareholders are supposed to have a copy of uri_extension, and
        # all are supposed to be identical. We compute the hash of the data
        # that comes back, and compare it against the version in our URI. If
        # they don't match, ignore their data and try someone else.
        def _validate(proposal, bucket):
            h = hashutil.uri_extension_hash(proposal)
            if h != self._uri_extension_hash:
                self._fetch_failures["uri_extension"] += 1
                msg = ("The copy of uri_extension we received from "
                       "%s was bad" % bucket)
                raise BadURIExtensionHashValue(msg)
            return self._unpack_uri_extension_data(proposal)
        return self._obtain_validated_thing(None,
                                            self._uri_extension_sources,
                                            "uri_extension",
                                            "get_uri_extension", (), _validate)

    def _obtain_validated_thing(self, ignored, sources, name, methname, args,
                                validatorfunc):
        """Call method 'methname'(*args) on the first bucket in 'sources'
        and pass the answer to validatorfunc(answer, bucket). On any failure
        the remaining sources are tried in order.

        Returns a Deferred firing with validatorfunc's result. Raises
        NotEnoughPeersError if 'sources' is empty; the Deferred fails with
        NotEnoughPeersError once every source has failed.
        """
        if not sources:
            raise NotEnoughPeersError("started with zero peers while fetching "
                                      "%s" % name)
        bucket = sources[0]
        sources = sources[1:]
        d = bucket.startIfNecessary()
        d.addCallback(lambda res: getattr(bucket, methname)(*args))
        d.addCallback(validatorfunc, bucket)
        def _bad(f):
            self.log("WEIRD: %s from vbucket %s failed: %s" % (name, bucket, f))
            if not sources:
                raise NotEnoughPeersError("ran out of peers, last error was %s"
                                          % (f,))
            # try again with a different one
            return self._obtain_validated_thing(None, sources, name,
                                                methname, args, validatorfunc)
        d.addErrback(_bad)
        return d

    def _got_uri_extension(self, uri_extension_data):
        """Set up codecs, expected hashes, segment geometry and the share
        hash tree from the validated uri_extension data."""
        d = self._uri_extension_data = uri_extension_data

        self._codec = codec.get_decoder_by_name(d['codec_name'])
        self._codec.set_serialized_params(d['codec_params'])
        # the last segment is decoded with its own codec parameters
        self._tail_codec = codec.get_decoder_by_name(d['codec_name'])
        self._tail_codec.set_serialized_params(d['tail_codec_params'])

        crypttext_hash = d['crypttext_hash']
        assert isinstance(crypttext_hash, str)
        assert len(crypttext_hash) == 32
        self._crypttext_hash = crypttext_hash
        self._plaintext_hash = d['plaintext_hash']
        self._roothash = d['share_root_hash']

        self._segment_size = segment_size = d['segment_size']
        self._total_segments = mathutil.div_ceil(self._size, segment_size)
        self._current_segnum = 0

        self._share_hashtree = hashtree.IncompleteHashTree(d['total_shares'])
        self._share_hashtree.set_hashes({0: self._roothash})

    def _get_hashtrees(self, res):
        """Fetch the plaintext hash tree, then the crypttext hash tree, then
        give both to our Output."""
        d = self._get_plaintext_hashtrees()
        d.addCallback(self._get_crypttext_hashtrees)
        d.addCallback(self._setup_hashtrees)
        return d

    def _get_plaintext_hashtrees(self):
        """Fetch a plaintext hash list whose root matches the uri_extension
        block and which is self-consistent; store the resulting tree in
        self._plaintext_hashtree."""
        def _validate_plaintext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['plaintext_root_hash']:
                self._fetch_failures["plaintext_hashroot"] += 1
                msg = ("The copy of the plaintext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadPlaintextHashValue(msg)
            pt_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            pt_hashes = dict(enumerate(proposal))
            try:
                pt_hashtree.set_hashes(pt_hashes)
            except hashtree.BadHashError:
                # the hashes they gave us were not self-consistent, even
                # though the root matched what we saw in the uri_extension
                # block
                self._fetch_failures["plaintext_hashtree"] += 1
                raise
            self._plaintext_hashtree = pt_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "plaintext_hashes",
                                         "get_plaintext_hashes", (),
                                         _validate_plaintext_hashtree)
        return d

    def _get_crypttext_hashtrees(self, res):
        """Fetch a crypttext hash list whose root matches the uri_extension
        block and which is self-consistent; store the resulting tree in
        self._crypttext_hashtree."""
        def _validate_crypttext_hashtree(proposal, bucket):
            if proposal[0] != self._uri_extension_data['crypttext_root_hash']:
                self._fetch_failures["crypttext_hashroot"] += 1
                msg = ("The copy of the crypttext_root_hash we received from"
                       " %s was bad" % bucket)
                raise BadCrypttextHashValue(msg)
            ct_hashtree = hashtree.IncompleteHashTree(self._total_segments)
            ct_hashes = dict(enumerate(proposal))
            try:
                # a single set_hashes() call both loads and checks the
                # hashes, exactly as for the plaintext tree
                ct_hashtree.set_hashes(ct_hashes)
            except hashtree.BadHashError:
                # the hashes were not self-consistent, even though the root
                # matched the uri_extension block
                self._fetch_failures["crypttext_hashtree"] += 1
                raise
            self._crypttext_hashtree = ct_hashtree
        d = self._obtain_validated_thing(None,
                                         self._uri_extension_sources,
                                         "crypttext_hashes",
                                         "get_crypttext_hashes", (),
                                         _validate_crypttext_hashtree)
        return d

    def _setup_hashtrees(self, res):
        """Give both validated hash trees to our Output."""
        self._output.setup_hashtrees(self._plaintext_hashtree,
                                     self._crypttext_hashtree)

    def _create_validated_buckets(self, ignored=None):
        """Wrap every known (sharenum, bucket) pair in a ValidatedBucket,
        grouped by share number in self._share_vbuckets."""
        self._share_vbuckets = {}
        for sharenum, bucket in self._share_buckets:
            vbucket = ValidatedBucket(sharenum, bucket,
                                      self._share_hashtree,
                                      self._roothash,
                                      self._total_segments)
            s = self._share_vbuckets.setdefault(sharenum, set())
            s.add(vbucket)

    def _activate_enough_buckets(self):
        """either return a mapping from shnum to a ValidatedBucket that can
        provide data for that share, or raise NotEnoughPeersError"""

        while len(self.active_buckets) < self._num_needed_shares:
            # need some more
            handled_shnums = set(self.active_buckets.keys())
            available_shnums = set(self._share_vbuckets.keys())
            potential_shnums = list(available_shnums - handled_shnums)
            if not potential_shnums:
                raise NotEnoughPeersError
            # choose a random share
            shnum = random.choice(potential_shnums)
            # and a random bucket that will provide it
            validated_bucket = random.choice(list(self._share_vbuckets[shnum]))
            self.active_buckets[shnum] = validated_bucket
        return self.active_buckets

    def _download_all_segments(self, res):
        """Chain the download of every segment, in order, with the final
        segment handled by _download_tail_segment()."""
        # the promise: upon entry to this function, self._share_vbuckets
        # contains enough buckets to complete the download, and some extra
        # ones to tolerate some buckets dropping out or having errors.
        # self._share_vbuckets is a dictionary that maps from shnum to a set
        # of ValidatedBuckets, which themselves are wrappers around
        # RIBucketReader references.
        self.active_buckets = {} # k: shnum, v: ValidatedBucket instance

        d = defer.succeed(None)
        for segnum in range(self._total_segments-1):
            d.addCallback(self._download_segment, segnum)
            # this pause, at the end of write, prevents pre-fetch from
            # happening until the consumer is ready for more data.
            d.addCallback(self._check_for_pause)
        d.addCallback(self._download_tail_segment, self._total_segments-1)
        return d

    def _check_for_pause(self, res):
        """Pass 'res' through, but first wait while we are paused and raise
        DownloadStopped if our consumer has called stopProducing()."""
        if self._paused:
            d = defer.Deferred()
            self._paused.addCallback(lambda ignored: d.callback(res))
            # When we are woken up, look again: stopProducing() (or another
            # pauseProducing()) may have been called while we were parked.
            d.addCallback(self._check_for_pause)
            return d
        if self._stopped:
            raise DownloadStopped("our Consumer called stopProducing()")
        return res

    def _download_segment(self, res, segnum):
        """Download, decode and write one non-final segment."""
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        # memory footprint: when the SegmentDownloader finishes pulling down
        # all shares, we have 1*segment_size of usage.
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
        d = segmentdler.start()
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        # while the codec does its job, we hit 2*segment_size
        def _decode(shares_and_shareids):
            shares, shareids = shares_and_shareids
            return self._codec.decode(shares, shareids)
        d.addCallback(_decode)
        # once the codec is done, we drop back to 1*segment_size, because
        # 'shares' goes out of scope. The memory usage is all in the
        # plaintext now, spread out into a bunch of tiny buffers.

        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # we start by joining all these buffers together into a single
            # string. This makes Output.write easier, since it wants to hash
            # data one segment at a time anyways, and doesn't impact our
            # memory footprint since we're already peaking at 2*segment_size
            # inside the codec a moment ago.
            segment = "".join(buffers)
            del buffers
            # we're down to 1*segment_size right now, but write_segment()
            # will decrypt a copy of the segment internally, which will push
            # us up to 2*segment_size while it runs.
            self._output.write_segment(segment)
        d.addCallback(_done)
        return d

    def _download_tail_segment(self, res, segnum):
        """Download, decode, trim and write the final segment, using the
        tail codec."""
        self.log("downloading seg#%d of %d (%d%%)"
                 % (segnum, self._total_segments,
                    100.0 * segnum / self._total_segments))
        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
        d = segmentdler.start()
        # pause before using more memory
        d.addCallback(self._check_for_pause)
        def _decode(shares_and_shareids):
            shares, shareids = shares_and_shareids
            return self._tail_codec.decode(shares, shareids)
        d.addCallback(_decode)
        # pause/check-for-stop just before writing, to honor stopProducing
        d.addCallback(self._check_for_pause)
        def _done(buffers):
            # trim off any padding added by the upload side
            segment = "".join(buffers)
            del buffers
            # we never send empty segments. If the data was an exact multiple
            # of the segment size, the last segment will be full.
            pad_size = mathutil.pad_size(self._size, self._segment_size)
            tail_size = self._segment_size - pad_size
            segment = segment[:tail_size]
            self._output.write_segment(segment)
        d.addCallback(_done)
        return d

    def _done(self, res):
        """Close the output, check the whole-file hashes (when enabled) and
        the total length, and return the target's finish() result."""
        self.log("download done")
        self._output.close()
        if self.check_crypttext_hash:
            _assert(self._crypttext_hash == self._output.crypttext_hash,
                    "bad crypttext_hash: computed=%s, expected=%s" %
                    (idlib.b2a(self._output.crypttext_hash),
                     idlib.b2a(self._crypttext_hash)))
        if self.check_plaintext_hash:
            _assert(self._plaintext_hash == self._output.plaintext_hash,
                    "bad plaintext_hash: computed=%s, expected=%s" %
                    (idlib.b2a(self._output.plaintext_hash),
                     idlib.b2a(self._plaintext_hash)))
        _assert(self._output.length == self._size,
                got=self._output.length, expected=self._size)
        return self._output.finish()
|
|
|
|
class LiteralDownloader:
    """Downloader for literal URIs: the file's data is read straight from
    the URI object, so nothing is fetched from storage peers."""

    def __init__(self, client, u, downloadable):
        literal_uri = IFileURI(u)
        self._uri = literal_uri
        assert isinstance(self._uri, uri.LiteralFileURI)
        self._downloadable = downloadable

    def start(self):
        """Open the target, write all the data to it, close it, and return
        a Deferred for the target's finish() result."""
        target = self._downloadable
        data = self._uri.data
        target.open(len(data))
        target.write(data)
        target.close()
        return defer.maybeDeferred(target.finish)
|
|
|
|
|
|
class FileName:
    """Download target that writes the data to a named file. The file is
    created by open(); on failure, a file that was opened is closed and
    removed."""
    implements(IDownloadTarget)

    def __init__(self, filename):
        self._filename = filename
        self.f = None  # set by open()

    def open(self, size):
        self.f = open(self._filename, "wb")
        return self.f

    def write(self, data):
        self.f.write(data)

    def close(self):
        if self.f:
            self.f.close()

    def fail(self, why):
        # only clean up a file that we actually created
        if self.f:
            self.f.close()
            os.unlink(self._filename)

    def register_canceller(self, cb):
        pass # we won't use it

    def finish(self):
        pass
|
|
|
|
class Data:
    """Download target that collects the data in memory. close() joins the
    written chunks into self.data, which finish() returns."""
    implements(IDownloadTarget)

    def __init__(self):
        self._data = []  # chunks, in the order they were written

    def open(self, size):
        pass

    def write(self, data):
        self._data.append(data)

    def close(self):
        self.data = "".join(self._data)
        # the chunk list is no longer needed once joined
        del self._data

    def fail(self, why):
        # throw away whatever was collected so far
        del self._data

    def register_canceller(self, cb):
        pass # we won't use it

    def finish(self):
        return self.data
|
|
|
|
class FileHandle:
    """Use me to download data to a pre-defined filehandle-like object. I
    will use the target's write() method. I will *not* close the filehandle:
    I leave that up to the originator of the filehandle. The download process
    will return the filehandle when it completes.
    """
    implements(IDownloadTarget)

    def __init__(self, filehandle):
        self._filehandle = filehandle

    def open(self, size):
        # nothing to do: the filehandle was handed to us ready for writing
        pass

    def write(self, data):
        self._filehandle.write(data)

    def close(self):
        # the originator of the filehandle reserves the right to close it
        pass

    def fail(self, why):
        # nothing to clean up: the filehandle is not ours
        pass

    def register_canceller(self, cb):
        pass

    def finish(self):
        return self._filehandle
|
|
|
|
class Downloader(service.MultiService):
    """I am a service that allows file downloading.
    """
    implements(IDownloader)
    name = "downloader"

    def download(self, u, t):
        """Download the file named by URI 'u' into download target 't'.

        Returns whatever the chosen downloader's start() returns. Raises
        RuntimeError if 'u' is neither a LiteralFileURI nor a CHKFileURI.
        """
        assert self.parent
        assert self.running
        u = IFileURI(u)
        t = IDownloadTarget(t)
        assert t.write
        assert t.close
        # pick the downloader that matches the kind of URI
        if isinstance(u, uri.LiteralFileURI):
            downloader_class = LiteralDownloader
        elif isinstance(u, uri.CHKFileURI):
            downloader_class = FileDownloader
        else:
            raise RuntimeError("I don't know how to download a %s" % u)
        return downloader_class(self.parent, u, t).start()

    # utility functions
    def download_to_data(self, uri):
        """Download into memory, using a Data target."""
        return self.download(uri, Data())

    def download_to_filename(self, uri, filename):
        """Download into the file named 'filename', using a FileName
        target."""
        return self.download(uri, FileName(filename))

    def download_to_filehandle(self, uri, filehandle):
        """Download by writing to 'filehandle', using a FileHandle target."""
        return self.download(uri, FileHandle(filehandle))
|
|
|
|
|