Change API to a listener-style, with helper

This commit is contained in:
meejah 2016-04-26 11:44:58 -06:00
parent 55898941da
commit b834b71dac
5 changed files with 71 additions and 47 deletions

View File

@ -373,12 +373,12 @@ class Client(node.Node, pollmixin.PollMixin):
ps = self.get_config("client", "peers.preferred", "").split(",")
preferred_peers = tuple([p.strip() for p in ps if p != ""])
sb = storage_client.StorageFarmBroker(self.tub, permute_peers=True, preferred_peers=preferred_peers)
self.storage_broker = sb
connection_threshold = min(self.encoding_params["k"],
self.encoding_params["happy"] + 1)
self.storage_broker = sb
self.upload_ready_d = sb.when_connected_to(connection_threshold)
helper = storage_client.ConnectedEnough(sb, connection_threshold)
self.upload_ready_d = helper.when_connected_enough()
# load static server specifications from tahoe.cfg, if any.
# Not quite ready yet.

View File

@ -31,11 +31,12 @@ the foolscap-based server implemented in src/allmydata/storage/*.py .
import re, time
from zope.interface import implements
from twisted.internet import defer
from foolscap.api import eventually
from allmydata.interfaces import IStorageBroker, IDisplayableServer, IServer
from allmydata.util import log, base32
from allmydata.util.assertutil import precondition
from allmydata.util.observer import OneShotObserverList
from allmydata.util.observer import OneShotObserverList, ObserverList
from allmydata.util.rrefutil import add_version_to_remote_reference
from allmydata.util.hashutil import sha1
@ -56,6 +57,41 @@ from allmydata.util.hashutil import sha1
# don't pass signatures: only pass validated blessed-objects
class ConnectedEnough(object):
def __init__(self, storage_farm_broker, threshold):
self._broker = storage_farm_broker
self._threshold = int(threshold)
if self._threshold <= 0:
raise ValueError("threshold must be positive")
self._threshold_passed = False
self._observers = OneShotObserverList()
self._broker.on_servers_changed(self._check_enough_connected)
def when_connected_enough(self):
"""
:returns: a Deferred that fires if/when our high water mark for
number of connected servers becomes (or ever was) above
"threshold".
"""
if self._threshold_passed:
return defer.succeed(None)
return self._observers.when_fired()
def _check_enough_connected(self):
"""
internal helper
"""
if self._threshold_passed:
return
num_servers = len(self._broker.get_connected_servers())
if num_servers >= self._threshold:
self._threshold_passed = True
self._observers.fire(None)
class StorageFarmBroker:
implements(IStorageBroker)
"""I live on the client, and know about storage servers. For each server
@ -75,59 +111,37 @@ class StorageFarmBroker:
# them for it.
self.servers = {}
self.introducer_client = None
# the most servers we've connected to at once
self._highest_connections = 0
# maps int -> OneShotObserverList, where the int is the threshold
self._connected_observers = dict()
self._server_listeners = ObserverList()
def when_connected_to(self, threshold):
"""
:returns: a Deferred that fires if/when our high water mark for
number of connected servers becomes (or ever was) above
"threshold".
"""
threshold = int(threshold)
if threshold <= 0:
raise ValueError("threshold must be positive")
if threshold <= self._highest_connections:
return defer.succeed(None)
try:
return self._connected_observers[threshold].when_fired()
except KeyError:
self._connected_observers[threshold] = OneShotObserverList()
return self._connected_observers[threshold].when_fired()
def check_enough_connected(self):
"""
internal helper
"""
num_servers = len(self.get_connected_servers())
self._highest_connections = max(num_servers, self._highest_connections)
try:
self._connected_observers[num_servers].fire_if_not_fired(None)
except KeyError:
pass
def on_servers_changed(self, callback):
self._server_listeners.subscribe(callback)
# these two are used in unit tests
def test_add_rref(self, serverid, rref, ann):
s = NativeStorageServer(serverid, ann.copy(), self)
s = NativeStorageServer(serverid, ann.copy())
s.rref = rref
s._is_connected = True
self.servers[serverid] = s
def test_add_server(self, serverid, s):
s.on_status_changed(lambda _: self._got_connection())
self.servers[serverid] = s
def use_introducer(self, introducer_client):
self.introducer_client = ic = introducer_client
ic.subscribe_to("storage", self._got_announcement)
def _got_connection(self):
# this is called by NativeStorageClient when it is connected
self._server_listeners.notify()
def _got_announcement(self, key_s, ann):
if key_s is not None:
precondition(isinstance(key_s, str), key_s)
precondition(key_s.startswith("v0-"), key_s)
assert ann["service-name"] == "storage"
s = NativeStorageServer(key_s, ann, self)
s = NativeStorageServer(key_s, ann)
s.on_status_changed(lambda _: self._got_connection())
serverid = s.get_serverid()
old = self.servers.get(serverid)
if old:
@ -136,7 +150,7 @@ class StorageFarmBroker:
# replacement
del self.servers[serverid]
old.stop_connecting()
# now we forget about them and start using the new one
# now we forget about them and start using the new one
self.servers[serverid] = s
s.start_connecting(self.tub, self._trigger_connections)
# the descriptor will manage their own Reconnector, and each time we
@ -224,10 +238,9 @@ class NativeStorageServer:
"application-version": "unknown: no get_version()",
}
def __init__(self, key_s, ann, broker):
def __init__(self, key_s, ann):
self.key_s = key_s
self.announcement = ann
self.broker = broker
assert "anonymous-storage-FURL" in ann, ann
furl = str(ann["anonymous-storage-FURL"])
@ -257,6 +270,14 @@ class NativeStorageServer:
self._is_connected = False
self._reconnector = None
self._trigger_cb = None
self._on_status_changed = ObserverList()
def on_status_changed(self, status_changed):
"""
:param status_changed: a callable taking a single arg (the
NativeStorageServer) that is notified when we become connected
"""
return self._on_status_changed.subscribe(status_changed)
# Special methods used by copy.copy() and copy.deepcopy(). When those are
# used in allmydata.immutable.filenode to copy CheckResults during
@ -330,7 +351,7 @@ class NativeStorageServer:
default = self.VERSION_DEFAULTS
d = add_version_to_remote_reference(rref, default)
d.addCallback(self._got_versioned_service, lp)
d.addCallback(lambda ign: self.broker.check_enough_connected())
d.addCallback(lambda ign: self._on_status_changed.notify(self))
d.addErrback(log.err, format="storageclient._got_connection",
name=self.get_name(), umid="Sdq3pg")

View File

@ -41,7 +41,7 @@ class WebResultsRendering(unittest.TestCase, WebRenderingMixin):
"my-version": "ver",
"oldest-supported": "oldest",
}
s = NativeStorageServer(key_s, ann, sb)
s = NativeStorageServer(key_s, ann)
sb.test_add_server(peerid, s) # XXX: maybe use key_s?
c = FakeClient()
c.storage_broker = sb

View File

@ -1,12 +1,11 @@
import os
from mock import Mock, patch
from mock import Mock
from allmydata.util import base32
from twisted.trial import unittest
from twisted.internet.defer import Deferred, succeed, inlineCallbacks
from twisted.internet.defer import succeed, inlineCallbacks
from allmydata.storage_client import NativeStorageServer
from allmydata.storage_client import StorageFarmBroker
from allmydata.storage_client import StorageFarmBroker, ConnectedEnough
class NativeStorageServerWithVersion(NativeStorageServer):
@ -42,7 +41,7 @@ class TestStorageFarmBroker(unittest.TestCase):
tub = Mock()
introducer = Mock()
broker = StorageFarmBroker(tub, True)
done = broker.when_connected_to(5)
done = ConnectedEnough(broker, 5).when_connected_enough()
broker.use_introducer(introducer)
# subscribes to "storage" to learn of new storage nodes
subscribe = introducer.mock_calls[0]

View File

@ -184,6 +184,8 @@ class FakeDisplayableServer(StubServer):
self.last_loss_time = last_loss_time
self.last_rx_time = last_rx_time
self.last_connect_time = last_connect_time
def on_status_changed(self, cb):
cb(self)
def is_connected(self):
return self.connected
def get_permutation_seed(self):
@ -234,6 +236,8 @@ class FakeStorageServer(service.MultiService):
self.lease_checker = FakeLeaseChecker()
def get_stats(self):
return {"storage_server.accepting_immutable_shares": False}
def on_status_changed(self, cb):
cb(self)
class FakeClient(Client):
def __init__(self):