Retrieve: implement/test stopProducing

This commit is contained in:
Brian Warner 2011-09-05 12:02:42 -07:00
parent 748e419a9b
commit a15ce96846
2 changed files with 64 additions and 45 deletions

View File

@ -7,7 +7,7 @@ from twisted.python import failure
from twisted.internet.interfaces import IPushProducer, IConsumer from twisted.internet.interfaces import IPushProducer, IConsumer
from foolscap.api import eventually, fireEventually from foolscap.api import eventually, fireEventually
from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \ from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
MDMF_VERSION, SDMF_VERSION DownloadStopped, MDMF_VERSION, SDMF_VERSION
from allmydata.util import hashutil, log, mathutil from allmydata.util import hashutil, log, mathutil
from allmydata.util.dictutil import DictOfSets from allmydata.util.dictutil import DictOfSets
from allmydata import hashtree, codec from allmydata import hashtree, codec
@ -143,6 +143,7 @@ class Retrieve:
self._status.set_size(datalength) self._status.set_size(datalength)
self._status.set_encoding(k, N) self._status.set_encoding(k, N)
self.readers = {} self.readers = {}
self._stopped = False
self._pause_deferred = None self._pause_deferred = None
self._offset = None self._offset = None
self._read_length = None self._read_length = None
@ -196,6 +197,10 @@ class Retrieve:
eventually(p.callback, None) eventually(p.callback, None)
def stopProducing(self):
self._stopped = True
self.resumeProducing()
def _check_for_paused(self, res): def _check_for_paused(self, res):
""" """
@ -205,6 +210,8 @@ class Retrieve:
the Deferred fires immediately. Otherwise, the Deferred fires the Deferred fires immediately. Otherwise, the Deferred fires
when the downloader is unpaused. when the downloader is unpaused.
""" """
if self._stopped:
raise DownloadStopped("our Consumer called stopProducing()")
if self._pause_deferred is not None: if self._pause_deferred is not None:
d = defer.Deferred() d = defer.Deferred()
self._pause_deferred.addCallback(lambda ignored: d.callback(res)) self._pause_deferred.addCallback(lambda ignored: d.callback(res))

View File

@ -3,16 +3,15 @@ import os, re, base64
from cStringIO import StringIO from cStringIO import StringIO
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.interfaces import IConsumer
from zope.interface import implements
from allmydata import uri, client from allmydata import uri, client
from allmydata.nodemaker import NodeMaker from allmydata.nodemaker import NodeMaker
from allmydata.util import base32, consumer, fileutil, mathutil from allmydata.util import base32, consumer, fileutil, mathutil
from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \ from allmydata.util.hashutil import tagged_hash, ssk_writekey_hash, \
ssk_pubkey_fingerprint_hash ssk_pubkey_fingerprint_hash
from allmydata.util.consumer import MemoryConsumer
from allmydata.util.deferredutil import gatherResults from allmydata.util.deferredutil import gatherResults
from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \ from allmydata.interfaces import IRepairResults, ICheckAndRepairResults, \
NotEnoughSharesError, SDMF_VERSION, MDMF_VERSION NotEnoughSharesError, SDMF_VERSION, MDMF_VERSION, DownloadStopped
from allmydata.monitor import Monitor from allmydata.monitor import Monitor
from allmydata.test.common import ShouldFailMixin from allmydata.test.common import ShouldFailMixin
from allmydata.test.no_network import GridTestMixin from allmydata.test.no_network import GridTestMixin
@ -37,6 +36,9 @@ from allmydata.mutable.repairer import MustForceRepairError
import allmydata.test.common_util as testutil import allmydata.test.common_util as testutil
from allmydata.test.common import TEST_RSA_KEY_SIZE from allmydata.test.common import TEST_RSA_KEY_SIZE
from allmydata.test.test_download import PausingConsumer, \
PausingAndStoppingConsumer, StoppingConsumer, \
ImmediatelyStoppingConsumer
# this "FakeStorage" exists to put the share data in RAM and avoid using real # this "FakeStorage" exists to put the share data in RAM and avoid using real
@ -544,26 +546,60 @@ class Filenode(unittest.TestCase, testutil.ShouldFailMixin):
return d return d
def test_retrieve_pause(self): def test_retrieve_producer_mdmf(self):
# We should make sure that the retriever is able to pause # We should make sure that the retriever is able to pause and stop
# correctly. # correctly.
d = self.nodemaker.create_mutable_file(version=MDMF_VERSION) data = "contents1" * 100000
def _created(node): d = self.nodemaker.create_mutable_file(MutableData(data),
self.node = node version=MDMF_VERSION)
d.addCallback(lambda node: node.get_best_mutable_version())
return node.overwrite(MutableData("contents1" * 100000)) d.addCallback(self._test_retrieve_producer, "MDMF", data)
d.addCallback(_created)
# Now we'll retrieve it into a pausing consumer.
d.addCallback(lambda ignored:
self.node.get_best_mutable_version())
def _got_version(version):
self.c = PausingConsumer()
return version.read(self.c)
d.addCallback(_got_version)
d.addCallback(lambda ignored:
self.failUnlessEqual(self.c.data, "contents1" * 100000))
return d return d
# note: SDMF has only one big segment, so we can't use the usual
# after-the-first-write() trick to pause or stop the download.
# Disabled until we find a better approach.
def OFF_test_retrieve_producer_sdmf(self):
data = "contents1" * 100000
d = self.nodemaker.create_mutable_file(MutableData(data),
version=SDMF_VERSION)
d.addCallback(lambda node: node.get_best_mutable_version())
d.addCallback(self._test_retrieve_producer, "SDMF", data)
return d
def _test_retrieve_producer(self, version, kind, data):
# Now we'll retrieve it into a pausing consumer.
c = PausingConsumer()
d = version.read(c)
d.addCallback(lambda ign: self.failUnlessEqual(c.size, len(data)))
c2 = PausingAndStoppingConsumer()
d.addCallback(lambda ign:
self.shouldFail(DownloadStopped, kind+"_pause_stop",
"our Consumer called stopProducing()",
version.read, c2))
c3 = StoppingConsumer()
d.addCallback(lambda ign:
self.shouldFail(DownloadStopped, kind+"_stop",
"our Consumer called stopProducing()",
version.read, c3))
c4 = ImmediatelyStoppingConsumer()
d.addCallback(lambda ign:
self.shouldFail(DownloadStopped, kind+"_stop_imm",
"our Consumer called stopProducing()",
version.read, c4))
def _then(ign):
c5 = MemoryConsumer()
d1 = version.read(c5)
c5.producer.stopProducing()
return self.shouldFail(DownloadStopped, kind+"_stop_imm2",
"our Consumer called stopProducing()",
lambda: d1)
d.addCallback(_then)
return d
def test_download_from_mdmf_cap(self): def test_download_from_mdmf_cap(self):
# We should be able to download an MDMF file given its cap # We should be able to download an MDMF file given its cap
@ -1048,30 +1084,6 @@ class PublishMixin:
index = versionmap[shnum] index = versionmap[shnum]
shares[peerid][shnum] = oldshares[index][peerid][shnum] shares[peerid][shnum] = oldshares[index][peerid][shnum]
class PausingConsumer:
implements(IConsumer)
def __init__(self):
self.data = ""
self.already_paused = False
def registerProducer(self, producer, streaming):
self.producer = producer
self.producer.resumeProducing()
def unregisterProducer(self):
self.producer = None
def _unpause(self, ignored):
self.producer.resumeProducing()
def write(self, data):
self.data += data
if not self.already_paused:
self.producer.pauseProducing()
self.already_paused = True
reactor.callLater(15, self._unpause, None)
class Servermap(unittest.TestCase, PublishMixin): class Servermap(unittest.TestCase, PublishMixin):
def setUp(self): def setUp(self):
return self.publish_one() return self.publish_one()