Add got_static_announcement and unit test

This commit is contained in:
David Stainton 2016-08-24 21:11:58 +00:00
parent 61eb839843
commit de61cd260c
2 changed files with 39 additions and 11 deletions

View File

@ -42,6 +42,12 @@ from allmydata.util.observer import ObserverList
from allmydata.util.rrefutil import add_version_to_remote_reference
from allmydata.util.hashutil import sha1
def get_serverid_from_furl(furl):
m = re.match(r'pb://(\w+)@', furl)
assert m, furl
id = m.group(1).lower()
return base32.a2b(id)
# who is responsible for de-duplication?
# both?
# IC remembers the unpacked announcements it receives, to provide for late
@ -80,6 +86,7 @@ class StorageFarmBroker(service.MultiService):
# own Reconnector, and will give us a RemoteReference when we ask
# them for it.
self.servers = {}
self.static_servers = []
self.introducer_client = None
self._threshold_listeners = [] # tuples of (threshold, Deferred)
self._connected_high_water_mark = 0
@ -97,7 +104,7 @@ class StorageFarmBroker(service.MultiService):
# these two are used in unit tests
def test_add_rref(self, serverid, rref, ann):
s = NativeStorageServer(serverid, ann.copy())
s = NativeStorageServer(serverid, ann.copy(), self._tub_options, self._tub_handlers)
s.rref = rref
s._is_connected = True
self.servers[serverid] = s
@ -127,20 +134,25 @@ class StorageFarmBroker(service.MultiService):
remaining.append( (threshold, d) )
self._threshold_listeners = remaining
def got_static_announcement(self, key_s, ann):
server_id = get_serverid_from_furl(ann["anonymous-storage-FURL"])
assert server_id not in self.static_servers # XXX
self.static_servers.append(server_id)
self._got_announcement(key_s, ann)
def _got_announcement(self, key_s, ann):
if key_s is not None:
precondition(isinstance(key_s, str), key_s)
precondition(key_s.startswith("v0-"), key_s)
assert ann["service-name"] == "storage"
precondition(isinstance(key_s, str), key_s)
precondition(key_s.startswith("v0-"), key_s)
precondition(ann["service-name"] == "storage", ann["service-name"])
s = NativeStorageServer(key_s, ann, self._tub_options, self._tub_handlers)
s.on_status_changed(lambda _: self._got_connection())
serverid = s.get_serverid()
old = self.servers.get(serverid)
if old:
server_id = s.get_serverid()
old = self.servers.get(server_id)
if old and server_id not in self.static_servers:
if old.get_announcement() == ann:
return # duplicate
# replacement
del self.servers[serverid]
del self.servers[server_id]
old.stop_connecting()
old.disownServiceParent()
# NOTE: this disownServiceParent() returns a Deferred that
@ -156,8 +168,8 @@ class StorageFarmBroker(service.MultiService):
# cycles around when they fire earlier than that, which will
# almost always be the case for normal runtime).
# now we forget about them and start using the new one
self.servers[serverid] = s
s.setServiceParent(self)
self.servers[server_id] = s
s.start_connecting(self._trigger_connections)
# the descriptor will manage their own Reconnector, and each time we
# need servers, we'll ask them if they're connected or not.
@ -346,11 +358,13 @@ class NativeStorageServer(service.MultiService):
available_space = protocol_v1_version.get('maximum-immutable-share-size', None)
return available_space
def start_connecting(self, trigger_cb):
self._tub = Tub()
for (name, value) in self._tub_options.items():
self._tub.setOption(name, value)
# XXX todo: do stuff with the handlers
# XXX todo: set tub handlers
self._tub.setServiceParent(self)
furl = str(self.announcement["anonymous-storage-FURL"])
self._trigger_cb = trigger_cb

View File

@ -38,6 +38,20 @@ class TestNativeStorageServer(unittest.TestCase):
class TestStorageFarmBroker(unittest.TestCase):
def test_static_announcement(self):
broker = StorageFarmBroker(True)
key_s = 'v0-1234-{}'.format(1)
ann = {
"service-name": "storage",
"anonymous-storage-FURL": "pb://{}@nowhere/fake".format(base32.b2a(str(1))),
"permutation-seed-base32": "aaaaaaaaaaaaaaaaaaaaaaaa",
}
broker.got_static_announcement(key_s, ann)
self.failUnlessEqual(len(broker.static_servers), 1)
self.failUnlessEqual(broker.servers['1'].announcement, ann)
self.failUnlessEqual(broker.servers['1'].key_s, key_s)
@inlineCallbacks
def test_threshold_reached(self):
introducer = Mock()