Mirror of https://github.com/tahoe-lafs/tahoe-lafs.git, synced 2024-12-24 07:06:41 +00:00
add new test for doing an encode/decode round trip, and make it almost work
parent 2593ce42c3
commit 234b2f354e
@@ -7,6 +7,7 @@ from twisted.application import service
 from allmydata.util import idlib, bencode, mathutil
 from allmydata.util.deferredutil import DeferredListShouldSucceed
+from allmydata.util.assertutil import _assert
 from allmydata import codec
 from allmydata.Crypto.Cipher import AES
 from allmydata.uri import unpack_uri
@@ -27,13 +28,24 @@ class Output:
                                     counterstart="\x00"*16)
         self._verifierid_hasher = sha.new(netstring("allmydata_v1_verifierid"))
+        self._fileid_hasher = sha.new(netstring("allmydata_v1_fileid"))
+        self.length = 0
+
+    def open(self):
+        self.downloadable.open()
 
     def write(self, crypttext):
+        self.length += len(crypttext)
         self._verifierid_hasher.update(crypttext)
         plaintext = self._decryptor.decrypt(crypttext)
+        self._fileid_hasher.update(plaintext)
         self.downloadable.write(plaintext)
 
-    def finish(self):
+    def close(self):
+        self.verifierid = self._verifierid_hasher.digest()
+        self.fileid = self._fileid_hasher.digest()
         self.downloadable.close()
+
+    def finish(self):
         return self.downloadable.finish()
 
 class BlockDownloader:
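The reworked Output above keeps two running hashes while segments stream through: a verifierid over the crypttext (checkable without the decryption key) and a fileid over the recovered plaintext, each seeded with a netstring tag. A minimal sketch of that pipeline in modern Python, with a stub standing in for the AES-CTR decryptor (the stub and variable names are illustrative, not tahoe's API):

    import hashlib

    def netstring(s):
        # self-delimiting tag, mirroring the helper at the bottom of this module
        return ("%d:" % len(s)).encode() + s + b","

    class StubDecryptor:
        def decrypt(self, data):
            return data   # identity stand-in for AES-CTR, illustration only

    verifierid_hasher = hashlib.sha1(netstring(b"allmydata_v1_verifierid"))
    fileid_hasher = hashlib.sha1(netstring(b"allmydata_v1_fileid"))
    decryptor = StubDecryptor()
    length = 0

    for crypttext in [b"segment one", b"segment two"]:
        length += len(crypttext)                  # mirrors self.length
        verifierid_hasher.update(crypttext)       # hashed before decryption
        plaintext = decryptor.decrypt(crypttext)
        fileid_hasher.update(plaintext)           # hashed after decryption

    verifierid = verifierid_hasher.digest()       # 20-byte SHA-1 digest
    fileid = fileid_hasher.digest()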
@@ -51,10 +63,12 @@ class BlockDownloader:
         self.parent.hold_block(self.blocknum, data)
 
     def _got_block_error(self, f):
+        log.msg("BlockDownloader[%d] got error: %s" % (self.blocknum, f))
         self.parent.bucket_failed(self.blocknum, self.bucket)
 
 class SegmentDownloader:
-    def __init__(self, segmentnumber, needed_shares):
+    def __init__(self, parent, segmentnumber, needed_shares):
+        self.parent = parent
         self.segmentnumber = segmentnumber
         self.needed_blocks = needed_shares
         self.blocks = {} # k: blocknum, v: data
@@ -66,7 +80,14 @@ class SegmentDownloader:
         d = self._try()
         def _done(res):
             if len(self.blocks) >= self.needed_blocks:
-                return self.blocks
+                # we only need self.needed_blocks blocks
+                # we want to get the smallest blockids, because they are
+                # more likely to be fast "primary blocks"
+                blockids = sorted(self.blocks.keys())[:self.needed_blocks]
+                blocks = []
+                for blocknum in blockids:
+                    blocks.append(self.blocks[blocknum])
+                return (blocks, blockids)
             else:
                 return self._download()
         d.addCallback(_done)
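The branch added to _done implements a simple preference: once enough blocks are held, keep the lowest-numbered ones, since low share numbers correspond to the fast "primary" shares mentioned in the comment. The rule in isolation:

    blocks_held = {7: "g", 0: "a", 5: "e", 2: "b"}    # blocknum -> data
    needed_blocks = 3

    blockids = sorted(blocks_held.keys())[:needed_blocks]
    blocks = [blocks_held[b] for b in blockids]
    assert blockids == [0, 2, 5]
    assert blocks == ["a", "b", "e"]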
@@ -79,14 +100,19 @@ class SegmentDownloader:
         if not otherblocknums:
             raise NotEnoughPeersError
         blocknum = random.choice(otherblocknums)
-        self.parent.active_buckets[blocknum] = random.choice(self.parent._share_buckets[blocknum])
+        bucket = random.choice(list(self.parent._share_buckets[blocknum]))
+        self.parent.active_buckets[blocknum] = bucket
 
         # Now we have enough buckets, in self.parent.active_buckets.
-        l = []
+
+        # in test cases, bd.start might mutate active_buckets right away, so
+        # we need to put off calling start() until we've iterated all the way
+        # through it
+        downloaders = []
         for blocknum, bucket in self.parent.active_buckets.iteritems():
             bd = BlockDownloader(bucket, blocknum, self)
-            d = bd.start(self.segmentnumber)
-            l.append(d)
+            downloaders.append(bd)
+        l = [bd.start(self.segmentnumber) for bd in downloaders]
         return defer.DeferredList(l)
 
     def hold_block(self, blocknum, data):
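The new comment about putting off start() describes the classic mutate-while-iterating hazard: with the synchronous test doubles, bd.start() can fail immediately, and the failure path (bucket_failed) mutates active_buckets while the loop is still walking it. A stripped-down illustration of the failure mode and the snapshot-first fix:

    active_buckets = {0: "bucket-a", 1: "bucket-b"}

    def start(blocknum):
        # simulate a synchronous failure path that mutates the dict
        del active_buckets[blocknum]

    try:
        for blocknum in active_buckets:   # direct iteration, as before the fix
            start(blocknum)
    except RuntimeError as e:
        print("mutated during iteration:", e)

    active_buckets = {0: "bucket-a", 1: "bucket-b"}
    pending = list(active_buckets)        # snapshot first...
    for blocknum in pending:              # ...then starting is safe
        start(blocknum)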
@@ -115,7 +141,11 @@ class FileDownloader:
         self._total_segments = mathutil.div_ceil(size, segment_size)
         self._current_segnum = 0
         self._segment_size = segment_size
-        self._needed_shares = self._decoder.get_needed_shares()
+        self._size = size
+        self._num_needed_shares = self._decoder.get_needed_shares()
+
+        key = "\x00" * 16
+        self._output = Output(downloadable, key)
 
         # future:
         # self._share_hash_tree = ??
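_total_segments comes from mathutil.div_ceil, which from its use here is ceiling division. A plain-Python equivalent, checked against the sizes the round-trip test uses below (76 bytes of data in 25-byte segments):

    def div_ceil(n, d):
        # ceiling division, matching the apparent contract of
        # allmydata.util.mathutil.div_ceil
        return (n + d - 1) // d

    assert div_ceil(76, 25) == 4   # a partial final segment still counts
    assert div_ceil(75, 25) == 3   # exact multiples need no extra segment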
@@ -134,9 +164,6 @@ class FileDownloader:
         self.active_buckets = {} # k: shnum, v: bucket
         self._share_buckets = {} # k: shnum, v: set of buckets
 
-        key = "\x00" * 16
-        self._output = Output(self._downloadable, key)
-
         d = defer.maybeDeferred(self._get_all_shareholders)
         d.addCallback(self._got_all_shareholders)
         d.addCallback(self._download_all_segments)
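start() is a plain callback pipeline; defer.maybeDeferred wraps the possibly-synchronous _get_all_shareholders call so the chain behaves the same whether it returns a value, returns a Deferred, or raises. The shape, reduced to a runnable Twisted sketch (the two function bodies are placeholders):

    from twisted.internet import defer

    def get_all_shareholders():
        return {0: "bucket-a", 1: "bucket-b"}   # stand-in for the peer queries

    def got_all_shareholders(res):
        print("have %d share buckets" % len(res))

    d = defer.maybeDeferred(get_all_shareholders)
    d.addCallback(got_all_shareholders)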
@@ -160,11 +187,12 @@ class FileDownloader:
         self._client.log("Somebody failed. -- %s" % (f,))
 
     def _got_all_shareholders(self, res):
-        if len(self._share_buckets) < self._needed_shares:
+        if len(self._share_buckets) < self._num_needed_shares:
             raise NotEnoughPeersError
 
         self.active_buckets = {}
+        self._output.open()
 
     def _download_all_segments(self):
         d = self._download_segment(self._current_segnum)
         def _done(res):
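_download_all_segments and its _done drive the segments strictly in sequence: each _download_segment Deferred fires, _done checks the segment counter, and only then is the next segment requested, so at most one segment is in flight at a time. The recursion pattern, sketched with a stand-in fetch:

    from twisted.internet import defer

    TOTAL_SEGMENTS = 3

    def download_segment(segnum):
        return defer.succeed("segment-%d" % segnum)   # stand-in for the real fetch

    def download_all(segnum=0):
        d = download_segment(segnum)
        def _done(res):
            if segnum + 1 == TOTAL_SEGMENTS:
                return None
            return download_all(segnum + 1)   # chain the next segment
        d.addCallback(_done)
        return d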
@@ -175,74 +203,33 @@ class FileDownloader:
         return d
 
     def _download_segment(self, segnum):
-        segmentdler = SegmentDownloader(segnum, self._needed_shares)
+        segmentdler = SegmentDownloader(self, segnum, self._num_needed_shares)
         d = segmentdler.start()
-        d.addCallback(self._decoder.decode)
+        d.addCallback(lambda (shares, shareids):
+                      self._decoder.decode(shares, shareids))
         def _done(res):
             self._current_segnum += 1
             if self._current_segnum == self._total_segments:
                 data = ''.join(res)
                 padsize = mathutil.pad_size(self._size, self._segment_size)
                 data = data[:-padsize]
-                self.output.write(data)
+                self._output.write(data)
             else:
                 for buf in res:
-                    self.output.write(buf)
+                    self._output.write(buf)
         d.addCallback(_done)
         return d
 
     def _done(self, res):
         self._output.close()
+        #print "VERIFIERID: %s" % idlib.b2a(self._output.verifierid)
+        #print "FILEID: %s" % idlib.b2a(self._output.fileid)
+        #assert self._verifierid == self._output.verifierid
+        #assert self._fileid == self._output.fileid
+        _assert(self._output.length == self._size,
+                got=self._output.length, expected=self._size)
         return self._output.finish()
 
-    def _write_data(self, data):
-        self._verifierid_hasher.update(data)
-
-
-    # old stuff
-    def _got_all_peers(self, res):
-        all_buckets = []
-        for peerid, buckets in self.landlords:
-            all_buckets.extend(buckets)
-        # TODO: try to avoid pulling multiple shares from the same peer
-        all_buckets = all_buckets[:self.needed_shares]
-        # retrieve all shares
-        dl = []
-        shares = []
-        shareids = []
-        for (bucket_num, bucket) in all_buckets:
-            d0 = bucket.callRemote("get_metadata")
-            d1 = bucket.callRemote("read")
-            d2 = DeferredListShouldSucceed([d0, d1])
-            def _got(res):
-                shareid_s, sharedata = res
-                shareid = bencode.bdecode(shareid_s)
-                shares.append(sharedata)
-                shareids.append(shareid)
-            d2.addCallback(_got)
-            dl.append(d2)
-        d = DeferredListShouldSucceed(dl)
-
-        d.addCallback(lambda res: self._decoder.decode(shares, shareids))
-
-        def _write(decoded_shares):
-            data = "".join(decoded_shares)
-            self._target.open()
-            hasher = sha.new(netstring("allmydata_v1_verifierid"))
-            hasher.update(data)
-            vid = hasher.digest()
-            assert self._verifierid == vid, "%s != %s" % (idlib.b2a(self._verifierid), idlib.b2a(vid))
-            self._target.write(data)
-        d.addCallback(_write)
-
-        def _done(res):
-            self._target.close()
-            return self._target.finish()
-        def _fail(res):
-            self._target.fail()
-            return res
-        d.addCallbacks(_done, _fail)
-        return d
 
 def netstring(s):
     return "%d:%s," % (len(s), s)
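One caveat in the final-segment path above: if mathutil.pad_size returns the padding needed to round _size up to a segment boundary (so 0 when the size is already aligned — an assumed definition, sketched below), then data[:-padsize] with padsize == 0 slices to the empty string and silently drops the whole last segment. The test below sidesteps this (76 is not a multiple of 25), which fits the commit message's "almost work":

    def pad_size(n, k):
        # assumed semantics: bytes of padding to round n up to a multiple of k
        return (k - n % k) % k

    assert pad_size(76, 25) == 24    # data[:-24] keeps the 76 real bytes
    assert pad_size(75, 25) == 0     # but data[:-0] == '' would drop everything

The netstring helper that closes the module makes the hash tags self-delimiting:

    assert netstring("allmydata_v1_verifierid") == "23:allmydata_v1_verifierid,"
    assert netstring("") == "0:,"

The hunks that follow apply to the twisted.trial-based unit tests rather than to the downloader module.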
@@ -2,7 +2,8 @@
 
 from twisted.trial import unittest
 from twisted.internet import defer
-from allmydata import encode_new
+from allmydata import encode_new, download
+from allmydata.uri import pack_uri
 from cStringIO import StringIO
 
 class MyEncoder(encode_new.Encoder):
@@ -24,8 +25,8 @@ class Encode(unittest.TestCase):
 class FakePeer:
     def __init__(self):
         self.blocks = {}
-        self.blockhashes = None
-        self.sharehashes = None
+        self.block_hashes = None
+        self.share_hashes = None
         self.closed = False
 
     def callRemote(self, methname, *args, **kwargs):
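callRemote's body sits outside this hunk; a plausible reconstruction (hypothetical, for orientation only) is a getattr dispatch wrapped in a Deferred, so the fake behaves like a remote reference:

    from twisted.internet import defer

    def callRemote(self, methname, *args, **kwargs):
        # dispatch "put_block_hashes" etc. to the local method of the same
        # name, returning an already-fired Deferred as a remote call would
        method = getattr(self, methname)
        return defer.maybeDeferred(method, *args, **kwargs)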
@@ -41,19 +42,29 @@ class FakePeer:
 
     def put_block_hashes(self, blockhashes):
         assert not self.closed
-        assert self.blockhashes is None
-        self.blockhashes = blockhashes
+        assert self.block_hashes is None
+        self.block_hashes = blockhashes
 
     def put_share_hashes(self, sharehashes):
         assert not self.closed
-        assert self.sharehashes is None
-        self.sharehashes = sharehashes
+        assert self.share_hashes is None
+        self.share_hashes = sharehashes
 
     def close(self):
         assert not self.closed
         self.closed = True
 
+    def get_block(self, blocknum):
+        assert isinstance(blocknum, int)
+        return self.blocks[blocknum]
+
+    def get_block_hashes(self):
+        return self.block_hashes
+    def get_share_hashes(self):
+        return self.share_hashes
+
+
 class UpDown(unittest.TestCase):
     def test_send(self):
         e = encode_new.Encoder()
@@ -79,16 +90,60 @@ class UpDown(unittest.TestCase):
         for i,peer in enumerate(all_shareholders):
             self.failUnless(peer.closed)
             self.failUnlessEqual(len(peer.blocks), NUM_SEGMENTS)
-            #self.failUnlessEqual(len(peer.blockhashes), NUM_SEGMENTS)
+            #self.failUnlessEqual(len(peer.block_hashes), NUM_SEGMENTS)
             # that isn't true: each peer gets a full tree, so it's more
             # like 2n-1 but with rounding to a power of two
-            for h in peer.blockhashes:
+            for h in peer.block_hashes:
                 self.failUnlessEqual(len(h), 32)
-            #self.failUnlessEqual(len(peer.sharehashes), NUM_SHARES)
+            #self.failUnlessEqual(len(peer.share_hashes), NUM_SHARES)
             # that isn't true: each peer only gets the chain they need
-            for (hashnum, h) in peer.sharehashes:
+            for (hashnum, h) in peer.share_hashes:
                 self.failUnless(isinstance(hashnum, int))
                 self.failUnlessEqual(len(h), 32)
         d.addCallback(_check)
 
         return d
 
+    def test_send_and_recover(self):
+        e = encode_new.Encoder()
+        data = "happy happy joy joy" * 4
+        e.setup(StringIO(data))
+        NUM_SHARES = 100
+        assert e.num_shares == NUM_SHARES # else we'll be completely confused
+        e.segment_size = 25 # force use of multiple segments
+        e.setup_codec() # need to rebuild the codec for that change
+        NUM_SEGMENTS = 4
+        assert (NUM_SEGMENTS-1)*e.segment_size < len(data) <= NUM_SEGMENTS*e.segment_size
+        shareholders = {}
+        all_shareholders = []
+        for shnum in range(NUM_SHARES):
+            peer = FakePeer()
+            shareholders[shnum] = peer
+            all_shareholders.append(peer)
+        e.set_shareholders(shareholders)
+        d = e.start()
+        def _uploaded(roothash):
+            URI = pack_uri(e._codec.get_encoder_type(),
+                           e._codec.get_serialized_params(),
+                           "V" * 20,
+                           roothash,
+                           e.required_shares,
+                           e.num_shares,
+                           e.file_size,
+                           e.segment_size)
+            client = None
+            target = download.Data()
+            fd = download.FileDownloader(client, URI, target)
+            fd._share_buckets = {}
+            for shnum in range(NUM_SHARES):
+                fd._share_buckets[shnum] = set([all_shareholders[shnum]])
+            fd._got_all_shareholders(None)
+            d2 = fd._download_all_segments()
+            d2.addCallback(fd._done)
+            return d2
+        d.addCallback(_uploaded)
+        def _downloaded(newdata):
+            self.failUnless(newdata == data)
+        d.addCallback(_downloaded)
+
+        return d
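Two notes on the new test. First, the corrected comments in _check point out that each peer holds a full Merkle tree over the block hashes, not one hash per segment: n leaves are rounded up to a power of two, giving 2n-1 nodes. A quick check of that count:

    def tree_node_count(nleaves):
        npow = 1
        while npow < nleaves:          # round the leaf count up to a power of two
            npow *= 2
        return 2 * npow - 1            # complete binary tree: leaves plus internals

    assert tree_node_count(4) == 7     # NUM_SEGMENTS = 4 -> 7 block hashes per peer
    assert tree_node_count(5) == 15    # 5 rounds up to 8 leaves -> 15 nodes

Second, test_send_and_recover uses download.Data() as its target. Its implementation is not shown in this diff, but given the Downloadable interface that Output drives (open/write/close/finish) and the fact that _downloaded receives the recovered bytes, a plausible in-memory shape is (hypothetical reconstruction, not tahoe's actual code):

    class Data:
        # accumulate plaintext in memory and hand it back from finish()
        def __init__(self):
            self._chunks = []
        def open(self):
            pass
        def write(self, data):
            self._chunks.append(data)
        def close(self):
            pass
        def finish(self):
            return "".join(self._chunks)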