mirror of
https://github.com/tahoe-lafs/tahoe-lafs.git
synced 2024-12-20 05:28:04 +00:00
Merge branch 'sdmf-partial-2'
This cleans up the mutable-retrieve code a bit, and should fix some corner cases where an offset/size combination that reads the last byte of the file (but not the first) could cause an assert to fire, making the download hang. Should address ticket:2459 and ticket:2462.
This commit is contained in:
commit
8f615c8551
@ -7,8 +7,10 @@ from twisted.python import failure
|
||||
from twisted.internet.interfaces import IPushProducer, IConsumer
|
||||
from foolscap.api import eventually, fireEventually, DeadReferenceError, \
|
||||
RemoteException
|
||||
|
||||
from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
|
||||
DownloadStopped, MDMF_VERSION, SDMF_VERSION
|
||||
from allmydata.util.assertutil import _assert, precondition
|
||||
from allmydata.util import hashutil, log, mathutil, deferredutil
|
||||
from allmydata.util.dictutil import DictOfSets
|
||||
from allmydata import hashtree, codec
|
||||
@ -115,6 +117,10 @@ class Retrieve:
|
||||
self.servermap = servermap
|
||||
assert self._node.get_pubkey()
|
||||
self.verinfo = verinfo
|
||||
# TODO: make it possible to use self.verinfo.datalength instead
|
||||
(seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
|
||||
offsets_tuple) = self.verinfo
|
||||
self._data_length = datalength
|
||||
# during repair, we may be called upon to grab the private key, since
|
||||
# it wasn't picked up during a verify=False checker run, and we'll
|
||||
# need it for repair to generate a new version.
|
||||
@ -145,8 +151,6 @@ class Retrieve:
|
||||
self._status.set_helper(False)
|
||||
self._status.set_progress(0.0)
|
||||
self._status.set_active(True)
|
||||
(seqnum, root_hash, IV, segsize, datalength, k, N, prefix,
|
||||
offsets_tuple) = self.verinfo
|
||||
self._status.set_size(datalength)
|
||||
self._status.set_encoding(k, N)
|
||||
self.readers = {}
|
||||
@ -230,21 +234,37 @@ class Retrieve:
|
||||
|
||||
|
||||
def download(self, consumer=None, offset=0, size=None):
|
||||
assert IConsumer.providedBy(consumer) or self._verify
|
||||
|
||||
precondition(self._verify or IConsumer.providedBy(consumer))
|
||||
if size is None:
|
||||
size = self._data_length - offset
|
||||
if self._verify:
|
||||
_assert(size == self._data_length, (size, self._data_length))
|
||||
self.log("starting download")
|
||||
self._done_deferred = defer.Deferred()
|
||||
if consumer:
|
||||
self._consumer = consumer
|
||||
# we provide IPushProducer, so streaming=True, per
|
||||
# IConsumer.
|
||||
# we provide IPushProducer, so streaming=True, per IConsumer.
|
||||
self._consumer.registerProducer(self, streaming=True)
|
||||
self._started = time.time()
|
||||
self._started_fetching = time.time()
|
||||
if size == 0:
|
||||
# short-circuit the rest of the process
|
||||
self._done()
|
||||
else:
|
||||
self._start_download(consumer, offset, size)
|
||||
return self._done_deferred
|
||||
|
||||
def _start_download(self, consumer, offset, size):
|
||||
precondition((0 <= offset < self._data_length)
|
||||
and (size > 0)
|
||||
and (offset+size <= self._data_length),
|
||||
(offset, size, self._data_length))
|
||||
|
||||
self._done_deferred = defer.Deferred()
|
||||
self._offset = offset
|
||||
self._read_length = size
|
||||
self._setup_encoding_parameters()
|
||||
self._setup_download()
|
||||
self.log("starting download")
|
||||
self._started_fetching = time.time()
|
||||
|
||||
# The download process beyond this is a state machine.
|
||||
# _add_active_servers will select the servers that we want to use
|
||||
# for the download, and then attempt to start downloading. After
|
||||
@ -254,7 +274,6 @@ class Retrieve:
|
||||
# will errback. Otherwise, it will eventually callback with the
|
||||
# contents of the mutable file.
|
||||
self.loop()
|
||||
return self._done_deferred
|
||||
|
||||
def loop(self):
|
||||
d = fireEventually(None) # avoid #237 recursion limit problem
|
||||
@ -265,7 +284,6 @@ class Retrieve:
|
||||
d.addErrback(self._error)
|
||||
|
||||
def _setup_download(self):
|
||||
self._started = time.time()
|
||||
self._status.set_status("Retrieving Shares")
|
||||
|
||||
# how many shares do we need?
|
||||
@ -325,6 +343,8 @@ class Retrieve:
|
||||
"""
|
||||
# We don't need the block hash trees in this case.
|
||||
self._block_hash_trees = None
|
||||
self._offset = 0
|
||||
self._read_length = self._data_length
|
||||
self._setup_encoding_parameters()
|
||||
|
||||
# _decode_blocks() expects the output of a gatherResults that
|
||||
@ -352,7 +372,7 @@ class Retrieve:
|
||||
self._required_shares = k
|
||||
self._total_shares = n
|
||||
self._segment_size = segsize
|
||||
self._data_length = datalength
|
||||
#self._data_length = datalength # set during __init__()
|
||||
|
||||
if not IV:
|
||||
self._version = MDMF_VERSION
|
||||
@ -408,34 +428,29 @@ class Retrieve:
|
||||
# offset we were given.
|
||||
start = self._offset // self._segment_size
|
||||
|
||||
assert start < self._num_segments
|
||||
_assert(start <= self._num_segments,
|
||||
start=start, num_segments=self._num_segments,
|
||||
offset=self._offset, segment_size=self._segment_size)
|
||||
self._start_segment = start
|
||||
self.log("got start segment: %d" % self._start_segment)
|
||||
else:
|
||||
self._start_segment = 0
|
||||
|
||||
# We might want to read only part of the file, and need to figure out
|
||||
# where to stop reading. Our end segment is the last segment
|
||||
# containing part of the segment that we were asked to read.
|
||||
_assert(self._read_length > 0, self._read_length)
|
||||
end_data = self._offset + self._read_length
|
||||
|
||||
# If self._read_length is None, then we want to read the whole
|
||||
# file. Otherwise, we want to read only part of the file, and
|
||||
# need to figure out where to stop reading.
|
||||
if self._read_length is not None:
|
||||
# our end segment is the last segment containing part of the
|
||||
# segment that we were asked to read.
|
||||
self.log("got read length %d" % self._read_length)
|
||||
if self._read_length != 0:
|
||||
end_data = self._offset + self._read_length
|
||||
|
||||
# We don't actually need to read the byte at end_data,
|
||||
# but the one before it.
|
||||
end = (end_data - 1) // self._segment_size
|
||||
|
||||
assert end < self._num_segments
|
||||
self._last_segment = end
|
||||
else:
|
||||
self._last_segment = self._start_segment
|
||||
self.log("got end segment: %d" % self._last_segment)
|
||||
else:
|
||||
self._last_segment = self._num_segments - 1
|
||||
# We don't actually need to read the byte at end_data, but the one
|
||||
# before it.
|
||||
end = (end_data - 1) // self._segment_size
|
||||
_assert(0 <= end < self._num_segments,
|
||||
end=end, num_segments=self._num_segments,
|
||||
end_data=end_data, offset=self._offset,
|
||||
read_length=self._read_length, segment_size=self._segment_size)
|
||||
self._last_segment = end
|
||||
self.log("got end segment: %d" % self._last_segment)
|
||||
|
||||
self._current_segment = self._start_segment
|
||||
|
||||
@ -568,6 +583,7 @@ class Retrieve:
|
||||
I download, validate, decode, decrypt, and assemble the segment
|
||||
that this Retrieve is currently responsible for downloading.
|
||||
"""
|
||||
|
||||
if self._current_segment > self._last_segment:
|
||||
# No more segments to download, we're done.
|
||||
self.log("got plaintext, done")
|
||||
@ -654,33 +670,27 @@ class Retrieve:
|
||||
target that is handling the file download.
|
||||
"""
|
||||
self.log("got plaintext for segment %d" % self._current_segment)
|
||||
if self._current_segment == self._start_segment:
|
||||
# We're on the first segment. It's possible that we want
|
||||
# only some part of the end of this segment, and that we
|
||||
# just downloaded the whole thing to get that part. If so,
|
||||
# we need to account for that and give the reader just the
|
||||
# data that they want.
|
||||
n = self._offset % self._segment_size
|
||||
self.log("stripping %d bytes off of the first segment" % n)
|
||||
self.log("original segment length: %d" % len(segment))
|
||||
segment = segment[n:]
|
||||
self.log("new segment length: %d" % len(segment))
|
||||
|
||||
if self._current_segment == self._last_segment and self._read_length is not None:
|
||||
# We're on the last segment. It's possible that we only want
|
||||
# part of the beginning of this segment, and that we
|
||||
# downloaded the whole thing anyway. Make sure to give the
|
||||
# caller only the portion of the segment that they want to
|
||||
# receive.
|
||||
extra = self._read_length
|
||||
if self._start_segment != self._last_segment:
|
||||
extra -= self._segment_size - \
|
||||
(self._offset % self._segment_size)
|
||||
extra %= self._segment_size
|
||||
self.log("original segment length: %d" % len(segment))
|
||||
segment = segment[:extra]
|
||||
self.log("new segment length: %d" % len(segment))
|
||||
self.log("only taking %d bytes of the last segment" % extra)
|
||||
if self._read_length == 0:
|
||||
self.log("on first+last segment, size=0, using 0 bytes")
|
||||
segment = b""
|
||||
|
||||
if self._current_segment == self._last_segment:
|
||||
# trim off the tail
|
||||
wanted = (self._offset + self._read_length) % self._segment_size
|
||||
if wanted != 0:
|
||||
self.log("on the last segment: using first %d bytes" % wanted)
|
||||
segment = segment[:wanted]
|
||||
else:
|
||||
self.log("on the last segment: using all %d bytes" %
|
||||
len(segment))
|
||||
|
||||
if self._current_segment == self._start_segment:
|
||||
# Trim off the head, if offset != 0. This should also work if
|
||||
# start==last, because we trim the tail first.
|
||||
skip = self._offset % self._segment_size
|
||||
self.log("on the first segment: skipping first %d bytes" % skip)
|
||||
segment = segment[skip:]
|
||||
|
||||
if not self._verify:
|
||||
self._consumer.write(segment)
|
||||
|
@ -17,7 +17,7 @@ from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \
|
||||
from allmydata.monitor import Monitor
|
||||
from allmydata.test.common import ShouldFailMixin
|
||||
from allmydata.test.no_network import GridTestMixin
|
||||
from foolscap.api import eventually, fireEventually
|
||||
from foolscap.api import eventually, fireEventually, flushEventualQueue
|
||||
from foolscap.logging import log
|
||||
from allmydata.storage_client import StorageFarmBroker
|
||||
from allmydata.storage.common import storage_index_to_dir
|
||||
@ -916,11 +916,13 @@ class PublishMixin:
|
||||
d.addCallback(_created)
|
||||
return d
|
||||
|
||||
def publish_mdmf(self):
|
||||
def publish_mdmf(self, data=None):
|
||||
# like publish_one, except that the result is guaranteed to be
|
||||
# an MDMF file.
|
||||
# self.CONTENTS should have more than one segment.
|
||||
self.CONTENTS = "This is an MDMF file" * 100000
|
||||
if data is None:
|
||||
data = "This is an MDMF file" * 100000
|
||||
self.CONTENTS = data
|
||||
self.uploadable = MutableData(self.CONTENTS)
|
||||
self._storage = FakeStorage()
|
||||
self._nodemaker = make_nodemaker(self._storage)
|
||||
@ -933,10 +935,12 @@ class PublishMixin:
|
||||
return d
|
||||
|
||||
|
||||
def publish_sdmf(self):
|
||||
def publish_sdmf(self, data=None):
|
||||
# like publish_one, except that the result is guaranteed to be
|
||||
# an SDMF file
|
||||
self.CONTENTS = "This is an SDMF file" * 1000
|
||||
if data is None:
|
||||
data = "This is an SDMF file" * 1000
|
||||
self.CONTENTS = data
|
||||
self.uploadable = MutableData(self.CONTENTS)
|
||||
self._storage = FakeStorage()
|
||||
self._nodemaker = make_nodemaker(self._storage)
|
||||
@ -948,20 +952,6 @@ class PublishMixin:
|
||||
d.addCallback(_created)
|
||||
return d
|
||||
|
||||
def publish_empty_sdmf(self):
|
||||
self.CONTENTS = ""
|
||||
self.uploadable = MutableData(self.CONTENTS)
|
||||
self._storage = FakeStorage()
|
||||
self._nodemaker = make_nodemaker(self._storage, keysize=None)
|
||||
self._storage_broker = self._nodemaker.storage_broker
|
||||
d = self._nodemaker.create_mutable_file(self.uploadable,
|
||||
version=SDMF_VERSION)
|
||||
def _created(node):
|
||||
self._fn = node
|
||||
self._fn2 = self._nodemaker.create_from_cap(node.get_uri())
|
||||
d.addCallback(_created)
|
||||
return d
|
||||
|
||||
|
||||
def publish_multiple(self, version=0):
|
||||
self.CONTENTS = ["Contents 0",
|
||||
@ -1903,6 +1893,19 @@ class Checker(unittest.TestCase, CheckerMixin, PublishMixin):
|
||||
"test_verify_mdmf_bad_encprivkey_uncheckable")
|
||||
return d
|
||||
|
||||
def test_verify_sdmf_empty(self):
    """Run a verifying check against a zero-length SDMF file.

    Publishes empty contents with publish_sdmf(""), calls check() on the
    resulting node with verify=True, and hands the results to check_good.
    """
    d = self.publish_sdmf("")
    d.addCallback(lambda ignored: self._fn.check(Monitor(), verify=True))
    # Label the check with this test's own name (as the other verify tests
    # do) so a failure is attributed to the empty-file case.
    d.addCallback(self.check_good, "test_verify_sdmf_empty")
    # NOTE(review): presumably drains foolscap's eventual-send queue so no
    # work is left pending when the test finishes -- confirm.
    d.addCallback(flushEventualQueue)
    return d
|
||||
|
||||
def test_verify_mdmf_empty(self):
    """Run a verifying check against a zero-length MDMF file.

    Publishes empty contents with publish_mdmf(""), calls check() on the
    resulting node with verify=True, and hands the results to check_good.
    """
    d = self.publish_mdmf("")
    d.addCallback(lambda ignored: self._fn.check(Monitor(), verify=True))
    # Label the check with this test's own name (as the other verify tests
    # do) so a failure is attributed to the empty-file case.
    d.addCallback(self.check_good, "test_verify_mdmf_empty")
    # NOTE(review): presumably drains foolscap's eventual-send queue so no
    # work is left pending when the test finishes -- confirm.
    d.addCallback(flushEventualQueue)
    return d
|
||||
|
||||
class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
|
||||
|
||||
@ -2155,7 +2158,7 @@ class Repair(unittest.TestCase, PublishMixin, ShouldFailMixin):
|
||||
# In the buggy version, the check that precedes the retrieve+publish
|
||||
# cycle uses MODE_READ, instead of MODE_REPAIR, and fails to get the
|
||||
# privkey that repair needs.
|
||||
d = self.publish_empty_sdmf()
|
||||
d = self.publish_sdmf("")
|
||||
def _delete_one_share(ign):
|
||||
shares = self._storage._peers
|
||||
for peerid in shares:
|
||||
@ -3068,11 +3071,13 @@ class Version(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin, \
|
||||
self.c = self.g.clients[0]
|
||||
self.nm = self.c.nodemaker
|
||||
self.data = "test data" * 100000 # about 900 KiB; MDMF
|
||||
self.small_data = "test data" * 10 # about 90 B; SDMF
|
||||
self.small_data = "test data" * 10 # 90 B; SDMF
|
||||
|
||||
|
||||
def do_upload_mdmf(self):
|
||||
d = self.nm.create_mutable_file(MutableData(self.data),
|
||||
def do_upload_mdmf(self, data=None):
|
||||
if data is None:
|
||||
data = self.data
|
||||
d = self.nm.create_mutable_file(MutableData(data),
|
||||
version=MDMF_VERSION)
|
||||
def _then(n):
|
||||
assert isinstance(n, MutableFileNode)
|
||||
@ -3082,8 +3087,10 @@ class Version(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin, \
|
||||
d.addCallback(_then)
|
||||
return d
|
||||
|
||||
def do_upload_sdmf(self):
|
||||
d = self.nm.create_mutable_file(MutableData(self.small_data))
|
||||
def do_upload_sdmf(self, data=None):
|
||||
if data is None:
|
||||
data = self.small_data
|
||||
d = self.nm.create_mutable_file(MutableData(data))
|
||||
def _then(n):
|
||||
assert isinstance(n, MutableFileNode)
|
||||
assert n._protocol_version == SDMF_VERSION
|
||||
@ -3359,56 +3366,127 @@ class Version(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin, \
|
||||
return d
|
||||
|
||||
|
||||
def test_partial_read(self):
|
||||
d = self.do_upload_mdmf()
|
||||
d.addCallback(lambda ign: self.mdmf_node.get_best_readable_version())
|
||||
modes = [("start_on_segment_boundary",
|
||||
mathutil.next_multiple(128 * 1024, 3), 50),
|
||||
("ending_one_byte_after_segment_boundary",
|
||||
mathutil.next_multiple(128 * 1024, 3)-50, 51),
|
||||
("zero_length_at_start", 0, 0),
|
||||
("zero_length_in_middle", 50, 0),
|
||||
("zero_length_at_segment_boundary",
|
||||
mathutil.next_multiple(128 * 1024, 3), 0),
|
||||
]
|
||||
def _test_partial_read(self, node, expected, modes, step):
|
||||
d = node.get_best_readable_version()
|
||||
for (name, offset, length) in modes:
|
||||
d.addCallback(self._do_partial_read, name, offset, length)
|
||||
# then read only a few bytes at a time, and see that the results are
|
||||
# what we expect.
|
||||
d.addCallback(self._do_partial_read, name, expected, offset, length)
|
||||
# then read the whole thing, but only a few bytes at a time, and see
|
||||
# that the results are what we expect.
|
||||
def _read_data(version):
|
||||
c = consumer.MemoryConsumer()
|
||||
d2 = defer.succeed(None)
|
||||
for i in xrange(0, len(self.data), 10000):
|
||||
d2.addCallback(lambda ignored, i=i: version.read(c, i, 10000))
|
||||
for i in xrange(0, len(expected), step):
|
||||
d2.addCallback(lambda ignored, i=i: version.read(c, i, step))
|
||||
d2.addCallback(lambda ignored:
|
||||
self.failUnlessEqual(self.data, "".join(c.chunks)))
|
||||
self.failUnlessEqual(expected, "".join(c.chunks)))
|
||||
return d2
|
||||
d.addCallback(_read_data)
|
||||
return d
|
||||
def _do_partial_read(self, version, name, offset, length):
|
||||
|
||||
def _do_partial_read(self, version, name, expected, offset, length):
|
||||
c = consumer.MemoryConsumer()
|
||||
d = version.read(c, offset, length)
|
||||
expected = self.data[offset:offset+length]
|
||||
if length is None:
|
||||
expected_range = expected[offset:]
|
||||
else:
|
||||
expected_range = expected[offset:offset+length]
|
||||
d.addCallback(lambda ignored: "".join(c.chunks))
|
||||
def _check(results):
|
||||
if results != expected:
|
||||
print
|
||||
if results != expected_range:
|
||||
print "read([%d]+%s) got %d bytes, not %d" % \
|
||||
(offset, length, len(results), len(expected_range))
|
||||
print "got: %s ... %s" % (results[:20], results[-20:])
|
||||
print "exp: %s ... %s" % (expected[:20], expected[-20:])
|
||||
self.fail("results[%s] != expected" % name)
|
||||
print "exp: %s ... %s" % (expected_range[:20], expected_range[-20:])
|
||||
self.fail("results[%s] != expected_range" % name)
|
||||
return version # daisy-chained to next call
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
def test_partial_read_mdmf_0(self):
    """Exercise partial reads against an empty MDMF upload."""
    contents = ""
    # Each case is (name, offset, length), as unpacked by
    # _test_partial_read; a length of None is handed to read() unchanged.
    read_cases = [
        ("all1", 0, 0),
        ("all2", 0, None),
    ]
    upload = self.do_upload_mdmf(data=contents)
    # The final argument is the step size for the chunked whole-file read.
    upload.addCallback(self._test_partial_read, contents, read_cases, 1)
    return upload
|
||||
|
||||
def test_partial_read_mdmf_large(self):
    """Exercise partial reads against the default (self.data) MDMF upload."""
    # Offset used by the segment-boundary cases below.
    boundary = mathutil.next_multiple(128 * 1024, 3)
    # Each case is (name, offset, length), as unpacked by
    # _test_partial_read; a length of None is handed to read() unchanged.
    read_cases = [
        ("start_on_segment_boundary", boundary, 50),
        ("ending_one_byte_after_segment_boundary", boundary - 50, 51),
        ("zero_length_at_start", 0, 0),
        ("zero_length_in_middle", 50, 0),
        ("zero_length_at_segment_boundary", boundary, 0),
        ("complete_file1", 0, len(self.data)),
        ("complete_file2", 0, None),
    ]
    upload = self.do_upload_mdmf()
    # The final argument is the step size for the chunked whole-file read.
    upload.addCallback(self._test_partial_read, self.data, read_cases, 10000)
    return upload
|
||||
|
||||
def test_partial_read_sdmf_0(self):
    """Exercise partial reads against an empty SDMF upload."""
    contents = ""
    # Each case is (name, offset, length), as unpacked by
    # _test_partial_read; a length of None is handed to read() unchanged.
    read_cases = [
        ("all1", 0, 0),
        ("all2", 0, None),
    ]
    upload = self.do_upload_sdmf(data=contents)
    # The final argument is the step size for the chunked whole-file read.
    upload.addCallback(self._test_partial_read, contents, read_cases, 1)
    return upload
|
||||
|
||||
def test_partial_read_sdmf_2(self):
    """Exercise partial reads against a two-byte SDMF upload."""
    contents = "hi"
    # Each case is (name, offset, length), as unpacked by
    # _test_partial_read; a length of None is handed to read() unchanged.
    read_cases = [
        ("one_byte", 0, 1),
        ("last_byte", 1, 1),
        ("last_byte2", 1, None),
        ("complete_file", 0, 2),
        ("complete_file2", 0, None),
    ]
    upload = self.do_upload_sdmf(data=contents)
    # The final argument is the step size for the chunked whole-file read.
    upload.addCallback(self._test_partial_read, contents, read_cases, 1)
    return upload
|
||||
|
||||
def test_partial_read_sdmf_90(self):
    """Exercise partial reads against the default (self.small_data) SDMF upload."""
    # Each case is (name, offset, length), as unpacked by
    # _test_partial_read; a length of None is handed to read() unchanged.
    read_cases = [
        ("start_at_middle", 50, 40),
        ("start_at_middle2", 50, None),
        ("zero_length_at_start", 0, 0),
        ("zero_length_in_middle", 50, 0),
        ("zero_length_at_end", 90, 0),
        ("complete_file1", 0, None),
        ("complete_file2", 0, 90),
    ]
    upload = self.do_upload_sdmf()
    # The final argument is the step size for the chunked whole-file read.
    upload.addCallback(self._test_partial_read, self.small_data,
                       read_cases, 10)
    return upload
|
||||
|
||||
def test_partial_read_sdmf_100(self):
    """Exercise partial reads against a 100-byte SDMF upload."""
    contents = "test data "*10
    # Each case is (name, offset, length), as unpacked by
    # _test_partial_read; a length of None is handed to read() unchanged.
    read_cases = [
        ("start_at_middle", 50, 50),
        ("start_at_middle2", 50, None),
        ("zero_length_at_start", 0, 0),
        ("zero_length_in_middle", 50, 0),
        ("complete_file1", 0, 100),
        ("complete_file2", 0, None),
    ]
    upload = self.do_upload_sdmf(data=contents)
    # The final argument is the step size for the chunked whole-file read.
    upload.addCallback(self._test_partial_read, contents, read_cases, 10)
    return upload
|
||||
|
||||
|
||||
def _test_read_and_download(self, node, expected):
|
||||
d = node.get_best_readable_version()
|
||||
def _read_data(version):
|
||||
c = consumer.MemoryConsumer()
|
||||
c2 = consumer.MemoryConsumer()
|
||||
d2 = defer.succeed(None)
|
||||
d2.addCallback(lambda ignored: version.read(c))
|
||||
d2.addCallback(lambda ignored:
|
||||
self.failUnlessEqual(expected, "".join(c.chunks)))
|
||||
|
||||
d2.addCallback(lambda ignored: version.read(c2, offset=0,
|
||||
size=len(expected)))
|
||||
d2.addCallback(lambda ignored:
|
||||
self.failUnlessEqual(expected, "".join(c2.chunks)))
|
||||
return d2
|
||||
d.addCallback(_read_data)
|
||||
d.addCallback(lambda ignored: node.download_best_version())
|
||||
@ -3441,7 +3519,7 @@ class Update(GridTestMixin, unittest.TestCase, testutil.ShouldFailMixin):
|
||||
self.c = self.g.clients[0]
|
||||
self.nm = self.c.nodemaker
|
||||
self.data = "testdata " * 100000 # about 900 KiB; MDMF
|
||||
self.small_data = "test data" * 10 # about 90 B; SDMF
|
||||
self.small_data = "test data" * 10 # 90 B; SDMF
|
||||
|
||||
|
||||
def do_upload_sdmf(self):
|
||||
|
Loading…
Reference in New Issue
Block a user