Improve SFTP error handling and remove use of IFinishableConsumer. fixes #1525

Signed-off-by: David-Sarah Hopwood <david-sarah@jacaranda.org>
Author: David-Sarah Hopwood
Date:   2013-03-19 05:37:02 +00:00
Parent: 48a2989ee1
Commit: 50c6562901
3 changed files with 176 additions and 66 deletions
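
Note: the crux of the change is that Twisted's IFinishableConsumer only offers finish(), which carries no success/failure information, so a failed download was indistinguishable from a completed one. OverwriteableFileConsumer now implements plain IConsumer and records the download outcome itself. A minimal sketch of the pattern the diff introduces (simplified; not the actual class, which routes callbacks through foolscap's eventually):

    from twisted.internet import defer
    from twisted.python.failure import Failure

    class OutcomeTracker:
        def __init__(self):
            self.done = defer.Deferred()
            self.done_status = None  # None -> not complete, Failure -> failed, str -> succeeded

        def download_done(self, res):
            assert isinstance(res, (str, Failure))
            if self.done_status is not None:
                return  # only the first outcome counts; extra calls are normal
            self.done_status = res
            # Fire self.done with None, never a Failure, so an unobserved
            # failure cannot become an 'Unhandled error in Deferred'.
            self.done.callback(None)

        def when_done(self):
            # Each caller gets a fresh Deferred; a Failure status errbacks it,
            # because callback() treats a Failure result as an error.
            d = defer.Deferred()
            self.done.addCallback(lambda ign: d.callback(self.done_status))
            return d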

src/allmydata/frontends/sftpd.py

@@ -22,7 +22,7 @@ from twisted.python.failure import Failure
 from twisted.internet.interfaces import ITransport
 from twisted.internet import defer
-from twisted.internet.interfaces import IFinishableConsumer
+from twisted.internet.interfaces import IConsumer
 from foolscap.api import eventually
 from allmydata.util import deferredutil
@@ -294,7 +294,7 @@ def _direntry_for(filenode_or_parent, childname, filenode=None):
 class OverwriteableFileConsumer(PrefixingLogMixin):
-    implements(IFinishableConsumer)
+    implements(IConsumer)
     """I act both as a consumer for the download of the original file contents, and as a
     wrapper for a temporary file that records the downloaded data and any overwrites.
     I use a priority queue to keep track of which regions of the file have been overwritten
@@ -321,12 +321,9 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
         self.milestones = []  # empty heap of (offset, d)
         self.overwrites = []  # empty heap of (start, end)
         self.is_closed = False
-        self.done = self.when_reached(download_size)  # adds a milestone
-        self.is_done = False
-        def _signal_done(ign):
-            if noisy: self.log("DONE", level=NOISY)
-            self.is_done = True
-        self.done.addCallback(_signal_done)
+
+        self.done = defer.Deferred()
+        self.done_status = None  # None -> not complete, Failure -> download failed, str -> download succeeded
         self.producer = None

     def get_file(self):
@@ -349,7 +346,7 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
         self.download_size = size
         if self.downloaded >= self.download_size:
-            self.finish()
+            self.download_done("size changed")

     def registerProducer(self, p, streaming):
         if noisy: self.log(".registerProducer(%r, streaming=%r)" % (p, streaming), level=NOISY)
@@ -362,7 +359,7 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
             p.resumeProducing()
         else:
             def _iterate():
-                if not self.is_done:
+                if self.done_status is None:
                     p.resumeProducing()
                     eventually(_iterate)
             _iterate()
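
Note: for a non-streaming (pull) producer, the consumer must keep calling resumeProducing(); the loop above now stops as soon as done_status is set, rather than relying on the old is_done flag. A toy version of that pull loop, substituting reactor.callLater(0, ...) for foolscap's eventually():

    from twisted.internet import reactor

    def drive_pull_producer(consumer, producer):
        def _iterate():
            # Stop pulling once the download has an outcome (success or
            # failure), so a broken download cannot spin forever.
            if consumer.done_status is None:
                producer.resumeProducing()
                reactor.callLater(0, _iterate)
        _iterate()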
@@ -429,13 +426,17 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
                 return
             if noisy: self.log("MILESTONE %r %r" % (next, d), level=NOISY)
             heapq.heappop(self.milestones)
-            eventually(d.callback, None)
+            eventually_callback(d)("reached")

         if milestone >= self.download_size:
-            self.finish()
+            self.download_done("reached download size")

     def overwrite(self, offset, data):
         if noisy: self.log(".overwrite(%r, <data of length %r>)" % (offset, len(data)), level=NOISY)
+        if self.is_closed:
+            self.log("overwrite called on a closed OverwriteableFileConsumer", level=WEIRD)
+            raise SFTPError(FX_BAD_MESSAGE, "cannot write to a closed file handle")
         if offset > self.current_size:
             # Normally writing at an offset beyond the current end-of-file
             # would leave a hole that appears filled with zeroes. However, an
@@ -463,6 +464,9 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
        The caller must perform no more overwrites until the Deferred has fired."""
         if noisy: self.log(".read(%r, %r), current_size = %r" % (offset, length, self.current_size), level=NOISY)
+        if self.is_closed:
+            self.log("read called on a closed OverwriteableFileConsumer", level=WEIRD)
+            raise SFTPError(FX_BAD_MESSAGE, "cannot read from a closed file handle")

         # Note that the overwrite method is synchronous. When a write request is processed
         # (e.g. a writeChunk request on the async queue of GeneralSFTPFile), overwrite will
@@ -478,48 +482,68 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
             if noisy: self.log("truncating read to %r bytes" % (length,), level=NOISY)

         needed = min(offset + length, self.download_size)
-        d = self.when_reached(needed)
-        def _reached(ign):
+        # If we fail to reach the needed number of bytes, the read request will fail.
+        d = self.when_reached_or_failed(needed)
+        def _reached_in_read(res):
             # It is not necessarily the case that self.downloaded >= needed, because
             # the file might have been truncated (thus truncating the download) and
             # then extended.

             _assert(self.current_size >= offset + length,
                     current_size=self.current_size, offset=offset, length=length)
-            if noisy: self.log("self.f = %r" % (self.f,), level=NOISY)
+            if noisy: self.log("_reached_in_read(%r), self.f = %r" % (res, self.f,), level=NOISY)
             self.f.seek(offset)
             return self.f.read(length)
-        d.addCallback(_reached)
+        d.addCallback(_reached_in_read)
         return d

-    def when_reached(self, index):
-        if noisy: self.log(".when_reached(%r)" % (index,), level=NOISY)
-        if index <= self.downloaded:  # already reached
-            if noisy: self.log("already reached %r" % (index,), level=NOISY)
-            return defer.succeed(None)
+    def when_reached_or_failed(self, index):
+        if noisy: self.log(".when_reached_or_failed(%r)" % (index,), level=NOISY)
+        def _reached(res):
+            if noisy: self.log("reached %r with result %r" % (index, res), level=NOISY)
+            return res
+        if self.done_status is not None:
+            return defer.execute(_reached, self.done_status)
+        if index <= self.downloaded:  # already reached successfully
+            if noisy: self.log("already reached %r successfully" % (index,), level=NOISY)
+            return defer.succeed("already reached successfully")
         d = defer.Deferred()
-        def _reached(ign):
-            if noisy: self.log("reached %r" % (index,), level=NOISY)
-            return ign
         d.addCallback(_reached)
         heapq.heappush(self.milestones, (index, d))
         return d

     def when_done(self):
-        return self.done
+        d = defer.Deferred()
+        self.done.addCallback(lambda ign: eventually_callback(d)(self.done_status))
+        return d

-    def finish(self):
-        """Called by the producer when it has finished producing, or when we have
-        received enough bytes, or as a result of a close. Defined by IFinishableConsumer."""
+    def download_done(self, res):
+        _assert(isinstance(res, (str, Failure)), res=res)
+        # Only the first call to download_done counts, but we log subsequent calls
+        # (multiple calls are normal).
+        if self.done_status is not None:
+            self.log("IGNORING extra call to download_done with result %r; previous result was %r"
+                     % (res, self.done_status), level=OPERATIONAL)
+            return
+
+        self.log("DONE with result %r" % (res,), level=OPERATIONAL)
+
+        # We avoid errbacking self.done so that we are not left with an 'Unhandled error in Deferred'
+        # in case when_done() is never called. Instead we stash the failure in self.done_status,
+        # from where the callback added in when_done() can retrieve it.
+        self.done_status = res
+        eventually_callback(self.done)(None)
+
         while len(self.milestones) > 0:
             (next, d) = self.milestones[0]
-            if noisy: self.log("MILESTONE FINISH %r %r" % (next, d), level=NOISY)
+            if noisy: self.log("MILESTONE FINISH %r %r %r" % (next, d, res), level=NOISY)
             heapq.heappop(self.milestones)
             # The callback means that the milestone has been reached if
             # it is ever going to be. Note that the file may have been
             # truncated to before the milestone.
-            eventually(d.callback, None)
+            eventually_callback(d)(res)

     def close(self):
         if not self.is_closed:
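
Note: the milestone heap is what turns a download failure into failed reads. A read for bytes not yet downloaded parks a Deferred keyed by offset, progress pops the milestones it reaches, and download_done flushes everything left with the final status. A self-contained toy of that bookkeeping (the real class also logs and fires callbacks via eventually):

    import heapq
    from twisted.internet import defer

    class MilestoneHeap:
        def __init__(self):
            self.downloaded = 0
            self.done_status = None   # None / str / Failure, as in the diff
            self.milestones = []      # heap of (offset, Deferred)

        def when_reached_or_failed(self, index):
            if self.done_status is not None:
                # download already over: report its outcome immediately
                return defer.execute(lambda: self.done_status)
            if index <= self.downloaded:
                return defer.succeed("already reached")
            d = defer.Deferred()
            heapq.heappush(self.milestones, (index, d))
            return d

        def progress(self, nbytes):
            # the producer delivered more bytes; fire milestones now reached
            self.downloaded += nbytes
            while self.milestones and self.milestones[0][0] <= self.downloaded:
                (_, d) = heapq.heappop(self.milestones)
                d.callback("reached")

        def download_done(self, res):
            # flush every outstanding milestone with the final outcome; a
            # Failure errbacks the waiting reads, which is the point of #1525
            if self.done_status is not None:
                return
            self.done_status = res
            while self.milestones:
                (_, d) = heapq.heappop(self.milestones)
                d.callback(res)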
@@ -528,10 +552,14 @@ class OverwriteableFileConsumer(PrefixingLogMixin):
                 self.f.close()
             except Exception, e:
                 self.log("suppressed %r from close of temporary file %r" % (e, self.f), level=WEIRD)
-        self.finish()
+        self.download_done("closed")
+        return self.done_status

     def unregisterProducer(self):
-        pass
+        # This will happen just before our client calls download_done, which will tell
+        # us the outcome of the download; we don't know the outcome at this point.
+        self.producer = None
+        self.log("producer unregistered", level=NOISY)


 SIZE_THRESHOLD = 1000
@@ -577,9 +605,9 @@ class ShortReadOnlySFTPFile(PrefixingLogMixin):
             # i.e. we respond with an EOF error iff offset is already at EOF.
             if offset >= len(data):
-                eventually(d.errback, SFTPError(FX_EOF, "read at or past end of file"))
+                eventually_errback(d)(Failure(SFTPError(FX_EOF, "read at or past end of file")))
             else:
-                eventually(d.callback, data[offset:offset+length])  # truncated if offset+length > len(data)
+                eventually_callback(d)(data[offset:offset+length])  # truncated if offset+length > len(data)
             return data
         self.async.addCallbacks(_read, eventually_errback(d))
         d.addBoth(_convert_error, request)
@@ -672,7 +700,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
         if (self.flags & FXF_TRUNC) or not filenode:
             # We're either truncating or creating the file, so we don't need the old contents.
             self.consumer = OverwriteableFileConsumer(0, tempfile_maker)
-            self.consumer.finish()
+            self.consumer.download_done("download not needed")
         else:
             self.async.addCallback(lambda ignored: filenode.get_best_readable_version())
@@ -683,10 +711,16 @@ class GeneralSFTPFile(PrefixingLogMixin):
                 self.consumer = OverwriteableFileConsumer(download_size, tempfile_maker)
-                version.read(self.consumer, 0, None)
+
+                d = version.read(self.consumer, 0, None)
+                def _finished(res):
+                    if not isinstance(res, Failure):
+                        res = "download finished"
+                    self.consumer.download_done(res)
+                d.addBoth(_finished)
+                # It is correct to drop d here.
             self.async.addCallback(_read)

-        eventually(self.async.callback, None)
+        eventually_callback(self.async)(None)

         if noisy: self.log("open done", level=NOISY)
         return self
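
Note: version.read() returns a Deferred that the old code silently dropped, so a download failure vanished. The new wiring normalizes success to a status string and feeds both outcomes into download_done; only then is the Deferred safe to drop. The same wiring with stand-in names:

    from twisted.python.failure import Failure

    def start_download(version, consumer):
        d = version.read(consumer, 0, None)
        def _finished(res):
            if not isinstance(res, Failure):
                res = "download finished"   # normalize success to a status string
            consumer.download_done(res)
        d.addBoth(_finished)
        # Dropping d is now safe: _finished consumed both outcomes, so no
        # 'Unhandled error in Deferred' can result.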
@@ -739,7 +773,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
         def _read(ign):
             if noisy: self.log("_read in readChunk(%r, %r)" % (offset, length), level=NOISY)
             d2 = self.consumer.read(offset, length)
-            d2.addCallbacks(eventually_callback(d), eventually_errback(d))
+            d2.addBoth(eventually_callback(d))
             # It is correct to drop d2 here.
             return None
         self.async.addCallbacks(_read, eventually_errback(d))
@@ -783,6 +817,24 @@ class GeneralSFTPFile(PrefixingLogMixin):
             # don't addErrback to self.async, just allow subsequent async ops to fail.
             return defer.succeed(None)

+    def _do_close(self, res, d=None):
+        if noisy: self.log("_do_close(%r)" % (res,), level=NOISY)
+        status = None
+        if self.consumer:
+            status = self.consumer.close()
+
+        # We must close_notify before re-firing self.async.
+        if self.close_notify:
+            self.close_notify(self.userpath, self.parent, self.childname, self)
+
+        if not isinstance(res, Failure) and isinstance(status, Failure):
+            res = status
+
+        if d:
+            eventually_callback(d)(res)
+        elif isinstance(res, Failure):
+            self.log("suppressing %r" % (res,), level=OPERATIONAL)
+
     def close(self):
         request = ".close()"
         self.log(request, level=OPERATIONAL)
@@ -794,10 +846,14 @@ class GeneralSFTPFile(PrefixingLogMixin):
         self.closed = True

         if not (self.flags & (FXF_WRITE | FXF_CREAT)):
-            def _readonly_close():
-                if self.consumer:
-                    self.consumer.close()
-            return defer.execute(_readonly_close)
+            # We never fail a close of a handle opened only for reading, even if the file
+            # failed to download. (We could not do so deterministically, because it would
+            # depend on whether we reached the point of failure before abandoning the
+            # download.) Any reads that depended on file content that could not be downloaded
+            # will have failed. It is important that we don't close the consumer until
+            # previous read operations have completed.
+            self.async.addBoth(self._do_close)
+            return defer.succeed(None)

         # We must capture the abandoned, parent, and childname variables synchronously
         # at the close call. This is needed by the correctness arguments in the comments
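
Note: the ordering argument in the comment relies on a basic Deferred property: callbacks appended to the same chain run in order, so adding _do_close with addBoth guarantees it runs only after every read already queued on self.async. A tiny demonstration:

    from twisted.internet import defer

    events = []
    chain = defer.succeed(None)                        # stands in for self.async
    chain.addCallback(lambda ign: events.append("read 1"))
    chain.addCallback(lambda ign: events.append("read 2"))
    chain.addBoth(lambda res: events.append("close"))  # _do_close stand-in
    assert events == ["read 1", "read 2", "close"]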
@@ -811,20 +867,11 @@ class GeneralSFTPFile(PrefixingLogMixin):
         # it is correct to optimize out the commit if it is False at the close call.
         has_changed = self.has_changed

-        def _committed(res):
-            if noisy: self.log("_committed(%r)" % (res,), level=NOISY)
-            self.consumer.close()
-
-            # We must close_notify before re-firing self.async.
-            if self.close_notify:
-                self.close_notify(self.userpath, self.parent, self.childname, self)
-            return res
-
-        def _close(ign):
+        def _commit(ign):
             d2 = self.consumer.when_done()
             if self.filenode and self.filenode.is_mutable():
-                self.log("update mutable file %r childname=%r metadata=%r" % (self.filenode, childname, self.metadata), level=OPERATIONAL)
+                self.log("update mutable file %r childname=%r metadata=%r"
+                         % (self.filenode, childname, self.metadata), level=OPERATIONAL)
                 if self.metadata.get('no-write', False) and not self.filenode.is_readonly():
                     _assert(parent and childname, parent=parent, childname=childname, metadata=self.metadata)
                     d2.addCallback(lambda ign: parent.set_metadata_for(childname, self.metadata))
@@ -836,22 +883,19 @@ class GeneralSFTPFile(PrefixingLogMixin):
                     u = FileHandle(self.consumer.get_file(), self.convergence)
                     return parent.add_file(childname, u, metadata=self.metadata)
                 d2.addCallback(_add_file)
-            d2.addBoth(_committed)
             return d2

-        d = defer.Deferred()
-
         # If the file has been abandoned, we don't want the close operation to get "stuck",
-        # even if self.async fails to re-fire. Doing the close independently of self.async
-        # in that case ensures that dropping an ssh connection is sufficient to abandon
+        # even if self.async fails to re-fire. Completing the close independently of self.async
+        # in that case should ensure that dropping an ssh connection is sufficient to abandon
         # any heisenfiles that were not explicitly closed in that connection.
         if abandoned or not has_changed:
-            d.addCallback(_committed)
+            d = defer.succeed(None)
+            self.async.addBoth(self._do_close)
         else:
-            self.async.addCallback(_close)
-            self.async.addCallbacks(eventually_callback(d), eventually_errback(d))
+            d = defer.Deferred()
+            self.async.addCallback(_commit)
+            self.async.addBoth(self._do_close, d)

         d.addBoth(_convert_error, request)
         return d
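
Note: all three close paths now funnel through _do_close, using addBoth's extra-argument form in place of the old eventually_callback/eventually_errback pair. A runnable toy of that idiom: the chain's outcome (value or Failure) arrives first, and the caller-visible Deferred, passed as an extra argument, is re-fired with it:

    from twisted.internet import defer
    from twisted.python.failure import Failure

    def _do_close(res, d=None):
        # hand the chain's outcome to whoever called close(); with no caller
        # Deferred, a Failure is merely logged rather than left unhandled
        if d is not None:
            d.callback(res)   # a Failure res errbacks d
        elif isinstance(res, Failure):
            print("suppressing %s" % (res.getErrorMessage(),))

    async_chain = defer.Deferred()   # stands in for self.async
    caller_d = defer.Deferred()      # what close() returns to the caller
    caller_d.addCallbacks(lambda res: "closed: %s" % (res,),
                          lambda f: "close failed: %s" % (f.getErrorMessage(),))

    async_chain.addCallback(lambda ign: "commit ok")  # _commit stand-in
    async_chain.addBoth(_do_close, caller_d)          # close runs after the commit
    async_chain.callback(None)                        # re-fire the chain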
@@ -873,7 +917,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
             # self.filenode might be None, but that's ok.
             attrs = _populate_attrs(self.filenode, self.metadata, size=self.consumer.get_current_size())
-            eventually(d.callback, attrs)
+            eventually_callback(d)(attrs)
             return None
         self.async.addCallbacks(_get, eventually_errback(d))
         d.addBoth(_convert_error, request)
@@ -911,7 +955,7 @@ class GeneralSFTPFile(PrefixingLogMixin):
             # TODO: should we refuse to truncate a file opened with FXF_APPEND?
             # <http://allmydata.org/trac/tahoe-lafs/ticket/1037#comment:20>
             self.consumer.set_current_size(size)
-            eventually(d.callback, None)
+            eventually_callback(d)(None)
             return None
         self.async.addCallbacks(_set, eventually_errback(d))
         d.addBoth(_convert_error, request)

src/allmydata/test/no_network.py

@@ -326,6 +326,13 @@ class NoNetworkGrid(service.MultiService):
             ss.hung_until.callback(None)
             ss.hung_until = None

+    def nuke_from_orbit(self):
+        """ Empty all share directories in this grid. It's the only way to be sure ;-) """
+        for server in self.servers_by_number.values():
+            for prefixdir in os.listdir(server.sharedir):
+                if prefixdir != 'incoming':
+                    fileutil.rm_dir(os.path.join(server.sharedir, prefixdir))
+

 class GridTestMixin:
     def setUp(self):

src/allmydata/test/test_sftp.py

@@ -36,6 +36,10 @@ from allmydata.test.common_util import ReallyEqualMixin
 timeout = 240

+#defer.setDebugging(True)
+#from twisted.internet import base
+#base.DelayedCall.debug = True
+

 class Handler(GridTestMixin, ShouldFailMixin, ReallyEqualMixin, unittest.TestCase):
     """This is a no-network unit test of the SFTPUserHandler and the abstractions it uses."""
@@ -519,6 +523,43 @@ class Handler(GridTestMixin, ShouldFailMixin, ReallyEqualMixin, unittest.TestCase):
             return d2
         d.addCallback(_read_short)

+        # check that failed downloads cause failed reads
+        d.addCallback(lambda ign: self.handler.openFile("uri/"+self.gross_uri, sftp.FXF_READ, {}))
+        def _read_broken(rf):
+            d2 = defer.succeed(None)
+            d2.addCallback(lambda ign: self.g.nuke_from_orbit())
+            d2.addCallback(lambda ign:
+                           self.shouldFailWithSFTPError(sftp.FX_FAILURE, "read broken",
+                                                        rf.readChunk, 0, 100))
+            # close shouldn't fail
+            d2.addCallback(lambda ign: rf.close())
+            d2.addCallback(lambda res: self.failUnlessReallyEqual(res, None))
+            return d2
+        d.addCallback(_read_broken)
+
+        d.addCallback(lambda ign: self.failUnlessEqual(sftpd.all_heisenfiles, {}))
+        d.addCallback(lambda ign: self.failUnlessEqual(self.handler._heisenfiles, {}))
+        return d
+
+    def test_openFile_read_error(self):
+        # The check at the end of openFile_read tested this for large files, but it trashed
+        # the grid in the process, so this needs to be a separate test.
+        small = upload.Data("0123456789"*10, None)
+        d = self._set_up("openFile_read_error")
+        d.addCallback(lambda ign: self.root.add_file(u"small", small))
+        d.addCallback(lambda n: self.handler.openFile("/uri/"+n.get_uri(), sftp.FXF_READ, {}))
+        def _read_broken(rf):
+            d2 = defer.succeed(None)
+            d2.addCallback(lambda ign: self.g.nuke_from_orbit())
+            d2.addCallback(lambda ign:
+                           self.shouldFailWithSFTPError(sftp.FX_FAILURE, "read broken",
+                                                        rf.readChunk, 0, 100))
+            # close shouldn't fail
+            d2.addCallback(lambda ign: rf.close())
+            d2.addCallback(lambda res: self.failUnlessReallyEqual(res, None))
+            return d2
+        d.addCallback(_read_broken)
+
         d.addCallback(lambda ign: self.failUnlessEqual(sftpd.all_heisenfiles, {}))
         d.addCallback(lambda ign: self.failUnlessEqual(self.handler._heisenfiles, {}))
         return d
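
Note: these tests lean on a shouldFailWithSFTPError helper that is not shown in this diff. Its assumed shape, inferred from the call sites above (the real helper in test_sftp.py may differ in detail):

    from twisted.internet import defer
    from twisted.conch.ssh import filetransfer as sftp

    def shouldFailWithSFTPError(self, expected_code, which, callable, *args, **kwargs):
        # assumed reconstruction, not the actual helper from test_sftp.py
        d = defer.maybeDeferred(callable, *args, **kwargs)
        def _unexpected_success(res):
            self.fail("%s was supposed to fail with SFTPError(%r), not get %r"
                      % (which, expected_code, res))
        def _check_failure(f):
            f.trap(sftp.SFTPError)
            self.failUnlessReallyEqual(f.value.code, expected_code)
        d.addCallbacks(_unexpected_success, _check_failure)
        return d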
@@ -982,6 +1023,24 @@ class Handler(GridTestMixin, ShouldFailMixin, ReallyEqualMixin, unittest.TestCase):
                               self.shouldFail(NoSuchChildError, "rename new while open", "new",
                                               self.root.get, u"new"))

+        # check that failed downloads cause failed reads and failed close, when open for writing
+        gross = u"gro\u00DF".encode("utf-8")
+        d.addCallback(lambda ign: self.handler.openFile(gross, sftp.FXF_READ | sftp.FXF_WRITE, {}))
+        def _read_write_broken(rwf):
+            d2 = rwf.writeChunk(0, "abcdefghij")
+            d2.addCallback(lambda ign: self.g.nuke_from_orbit())
+
+            # reading should fail (reliably if we read past the written chunk)
+            d2.addCallback(lambda ign:
+                           self.shouldFailWithSFTPError(sftp.FX_FAILURE, "read/write broken",
+                                                        rwf.readChunk, 0, 100))
+            # close should fail in this case
+            d2.addCallback(lambda ign:
+                           self.shouldFailWithSFTPError(sftp.FX_FAILURE, "read/write broken close",
+                                                        rwf.close))
+            return d2
+        d.addCallback(_read_write_broken)
+
         d.addCallback(lambda ign: self.failUnlessEqual(sftpd.all_heisenfiles, {}))
         d.addCallback(lambda ign: self.failUnlessEqual(self.handler._heisenfiles, {}))
         return d