Change API to a listener-style, with helper

This commit is contained in:
meejah 2016-04-26 11:44:58 -06:00
parent 55898941da
commit b834b71dac
5 changed files with 71 additions and 47 deletions

@ -373,12 +373,12 @@ class Client(node.Node, pollmixin.PollMixin):
ps = self.get_config("client", "peers.preferred", "").split(",") ps = self.get_config("client", "peers.preferred", "").split(",")
preferred_peers = tuple([p.strip() for p in ps if p != ""]) preferred_peers = tuple([p.strip() for p in ps if p != ""])
sb = storage_client.StorageFarmBroker(self.tub, permute_peers=True, preferred_peers=preferred_peers) sb = storage_client.StorageFarmBroker(self.tub, permute_peers=True, preferred_peers=preferred_peers)
self.storage_broker = sb
connection_threshold = min(self.encoding_params["k"], connection_threshold = min(self.encoding_params["k"],
self.encoding_params["happy"] + 1) self.encoding_params["happy"] + 1)
helper = storage_client.ConnectedEnough(sb, connection_threshold)
self.storage_broker = sb self.upload_ready_d = helper.when_connected_enough()
self.upload_ready_d = sb.when_connected_to(connection_threshold)
# load static server specifications from tahoe.cfg, if any. # load static server specifications from tahoe.cfg, if any.
# Not quite ready yet. # Not quite ready yet.

@ -31,11 +31,12 @@ the foolscap-based server implemented in src/allmydata/storage/*.py .
import re, time import re, time
from zope.interface import implements from zope.interface import implements
from twisted.internet import defer
from foolscap.api import eventually from foolscap.api import eventually
from allmydata.interfaces import IStorageBroker, IDisplayableServer, IServer from allmydata.interfaces import IStorageBroker, IDisplayableServer, IServer
from allmydata.util import log, base32 from allmydata.util import log, base32
from allmydata.util.assertutil import precondition from allmydata.util.assertutil import precondition
from allmydata.util.observer import OneShotObserverList from allmydata.util.observer import OneShotObserverList, ObserverList
from allmydata.util.rrefutil import add_version_to_remote_reference from allmydata.util.rrefutil import add_version_to_remote_reference
from allmydata.util.hashutil import sha1 from allmydata.util.hashutil import sha1
@ -56,6 +57,41 @@ from allmydata.util.hashutil import sha1
# don't pass signatures: only pass validated blessed-objects # don't pass signatures: only pass validated blessed-objects
class ConnectedEnough(object):
def __init__(self, storage_farm_broker, threshold):
self._broker = storage_farm_broker
self._threshold = int(threshold)
if self._threshold <= 0:
raise ValueError("threshold must be positive")
self._threshold_passed = False
self._observers = OneShotObserverList()
self._broker.on_servers_changed(self._check_enough_connected)
def when_connected_enough(self):
"""
:returns: a Deferred that fires if/when our high water mark for
number of connected servers becomes (or ever was) above
"threshold".
"""
if self._threshold_passed:
return defer.succeed(None)
return self._observers.when_fired()
def _check_enough_connected(self):
"""
internal helper
"""
if self._threshold_passed:
return
num_servers = len(self._broker.get_connected_servers())
if num_servers >= self._threshold:
self._threshold_passed = True
self._observers.fire(None)
class StorageFarmBroker: class StorageFarmBroker:
implements(IStorageBroker) implements(IStorageBroker)
"""I live on the client, and know about storage servers. For each server """I live on the client, and know about storage servers. For each server
@ -75,59 +111,37 @@ class StorageFarmBroker:
# them for it. # them for it.
self.servers = {} self.servers = {}
self.introducer_client = None self.introducer_client = None
# the most servers we've connected to at once self._server_listeners = ObserverList()
self._highest_connections = 0
# maps int -> OneShotObserverList, where the int is the threshold
self._connected_observers = dict()
def when_connected_to(self, threshold): def on_servers_changed(self, callback):
""" self._server_listeners.subscribe(callback)
:returns: a Deferred that fires if/when our high water mark for
number of connected servers becomes (or ever was) above
"threshold".
"""
threshold = int(threshold)
if threshold <= 0:
raise ValueError("threshold must be positive")
if threshold <= self._highest_connections:
return defer.succeed(None)
try:
return self._connected_observers[threshold].when_fired()
except KeyError:
self._connected_observers[threshold] = OneShotObserverList()
return self._connected_observers[threshold].when_fired()
def check_enough_connected(self):
"""
internal helper
"""
num_servers = len(self.get_connected_servers())
self._highest_connections = max(num_servers, self._highest_connections)
try:
self._connected_observers[num_servers].fire_if_not_fired(None)
except KeyError:
pass
# these two are used in unit tests # these two are used in unit tests
def test_add_rref(self, serverid, rref, ann): def test_add_rref(self, serverid, rref, ann):
s = NativeStorageServer(serverid, ann.copy(), self) s = NativeStorageServer(serverid, ann.copy())
s.rref = rref s.rref = rref
s._is_connected = True s._is_connected = True
self.servers[serverid] = s self.servers[serverid] = s
def test_add_server(self, serverid, s): def test_add_server(self, serverid, s):
s.on_status_changed(lambda _: self._got_connection())
self.servers[serverid] = s self.servers[serverid] = s
def use_introducer(self, introducer_client): def use_introducer(self, introducer_client):
self.introducer_client = ic = introducer_client self.introducer_client = ic = introducer_client
ic.subscribe_to("storage", self._got_announcement) ic.subscribe_to("storage", self._got_announcement)
def _got_connection(self):
# this is called by NativeStorageClient when it is connected
self._server_listeners.notify()
def _got_announcement(self, key_s, ann): def _got_announcement(self, key_s, ann):
if key_s is not None: if key_s is not None:
precondition(isinstance(key_s, str), key_s) precondition(isinstance(key_s, str), key_s)
precondition(key_s.startswith("v0-"), key_s) precondition(key_s.startswith("v0-"), key_s)
assert ann["service-name"] == "storage" assert ann["service-name"] == "storage"
s = NativeStorageServer(key_s, ann, self) s = NativeStorageServer(key_s, ann)
s.on_status_changed(lambda _: self._got_connection())
serverid = s.get_serverid() serverid = s.get_serverid()
old = self.servers.get(serverid) old = self.servers.get(serverid)
if old: if old:
@ -136,7 +150,7 @@ class StorageFarmBroker:
# replacement # replacement
del self.servers[serverid] del self.servers[serverid]
old.stop_connecting() old.stop_connecting()
# now we forget about them and start using the new one # now we forget about them and start using the new one
self.servers[serverid] = s self.servers[serverid] = s
s.start_connecting(self.tub, self._trigger_connections) s.start_connecting(self.tub, self._trigger_connections)
# the descriptor will manage their own Reconnector, and each time we # the descriptor will manage their own Reconnector, and each time we
@ -224,10 +238,9 @@ class NativeStorageServer:
"application-version": "unknown: no get_version()", "application-version": "unknown: no get_version()",
} }
def __init__(self, key_s, ann, broker): def __init__(self, key_s, ann):
self.key_s = key_s self.key_s = key_s
self.announcement = ann self.announcement = ann
self.broker = broker
assert "anonymous-storage-FURL" in ann, ann assert "anonymous-storage-FURL" in ann, ann
furl = str(ann["anonymous-storage-FURL"]) furl = str(ann["anonymous-storage-FURL"])
@ -257,6 +270,14 @@ class NativeStorageServer:
self._is_connected = False self._is_connected = False
self._reconnector = None self._reconnector = None
self._trigger_cb = None self._trigger_cb = None
self._on_status_changed = ObserverList()
def on_status_changed(self, status_changed):
"""
:param status_changed: a callable taking a single arg (the
NativeStorageServer) that is notified when we become connected
"""
return self._on_status_changed.subscribe(status_changed)
# Special methods used by copy.copy() and copy.deepcopy(). When those are # Special methods used by copy.copy() and copy.deepcopy(). When those are
# used in allmydata.immutable.filenode to copy CheckResults during # used in allmydata.immutable.filenode to copy CheckResults during
@ -330,7 +351,7 @@ class NativeStorageServer:
default = self.VERSION_DEFAULTS default = self.VERSION_DEFAULTS
d = add_version_to_remote_reference(rref, default) d = add_version_to_remote_reference(rref, default)
d.addCallback(self._got_versioned_service, lp) d.addCallback(self._got_versioned_service, lp)
d.addCallback(lambda ign: self.broker.check_enough_connected()) d.addCallback(lambda ign: self._on_status_changed.notify(self))
d.addErrback(log.err, format="storageclient._got_connection", d.addErrback(log.err, format="storageclient._got_connection",
name=self.get_name(), umid="Sdq3pg") name=self.get_name(), umid="Sdq3pg")

@ -41,7 +41,7 @@ class WebResultsRendering(unittest.TestCase, WebRenderingMixin):
"my-version": "ver", "my-version": "ver",
"oldest-supported": "oldest", "oldest-supported": "oldest",
} }
s = NativeStorageServer(key_s, ann, sb) s = NativeStorageServer(key_s, ann)
sb.test_add_server(peerid, s) # XXX: maybe use key_s? sb.test_add_server(peerid, s) # XXX: maybe use key_s?
c = FakeClient() c = FakeClient()
c.storage_broker = sb c.storage_broker = sb

@ -1,12 +1,11 @@
import os from mock import Mock
from mock import Mock, patch
from allmydata.util import base32 from allmydata.util import base32
from twisted.trial import unittest from twisted.trial import unittest
from twisted.internet.defer import Deferred, succeed, inlineCallbacks from twisted.internet.defer import succeed, inlineCallbacks
from allmydata.storage_client import NativeStorageServer from allmydata.storage_client import NativeStorageServer
from allmydata.storage_client import StorageFarmBroker from allmydata.storage_client import StorageFarmBroker, ConnectedEnough
class NativeStorageServerWithVersion(NativeStorageServer): class NativeStorageServerWithVersion(NativeStorageServer):
@ -42,7 +41,7 @@ class TestStorageFarmBroker(unittest.TestCase):
tub = Mock() tub = Mock()
introducer = Mock() introducer = Mock()
broker = StorageFarmBroker(tub, True) broker = StorageFarmBroker(tub, True)
done = broker.when_connected_to(5) done = ConnectedEnough(broker, 5).when_connected_enough()
broker.use_introducer(introducer) broker.use_introducer(introducer)
# subscribes to "storage" to learn of new storage nodes # subscribes to "storage" to learn of new storage nodes
subscribe = introducer.mock_calls[0] subscribe = introducer.mock_calls[0]

@ -184,6 +184,8 @@ class FakeDisplayableServer(StubServer):
self.last_loss_time = last_loss_time self.last_loss_time = last_loss_time
self.last_rx_time = last_rx_time self.last_rx_time = last_rx_time
self.last_connect_time = last_connect_time self.last_connect_time = last_connect_time
def on_status_changed(self, cb):
cb(self)
def is_connected(self): def is_connected(self):
return self.connected return self.connected
def get_permutation_seed(self): def get_permutation_seed(self):
@ -234,6 +236,8 @@ class FakeStorageServer(service.MultiService):
self.lease_checker = FakeLeaseChecker() self.lease_checker = FakeLeaseChecker()
def get_stats(self): def get_stats(self):
return {"storage_server.accepting_immutable_shares": False} return {"storage_server.accepting_immutable_shares": False}
def on_status_changed(self, cb):
cb(self)
class FakeClient(Client): class FakeClient(Client):
def __init__(self): def __init__(self):