diff --git a/src/allmydata/mutable.py b/src/allmydata/mutable.py index d97e76fd8..a3d93ab50 100644 --- a/src/allmydata/mutable.py +++ b/src/allmydata/mutable.py @@ -337,6 +337,7 @@ class Retrieve: raise CorruptShareError(peerid, "pubkey doesn't match fingerprint") self._pubkey = self._deserialize_pubkey(pubkey_s) + self._node._populate_pubkey(self._pubkey) verinfo = (seqnum, root_hash, IV) if verinfo not in self._valid_versions: @@ -352,8 +353,10 @@ class Retrieve: # and make a note of the other parameters we've just learned if self._required_shares is None: self._required_shares = k + self._node._populate_required_shares(k) if self._total_shares is None: self._total_shares = N + self._node._populate_total_shares(N) if self._segsize is None: self._segsize = segsize if self._datalength is None: @@ -494,7 +497,7 @@ class Retrieve: # all for the same seqnum+root_hash version, so it's now down to # doing FEC and decrypt. d = defer.maybeDeferred(self._decode, shares) - d.addCallback(self._decrypt, IV) + d.addCallback(self._decrypt, IV, seqnum, root_hash) return d def _validate_share_and_extract_data(self, root_hash, shnum, data): @@ -559,10 +562,13 @@ class Retrieve: d.addErrback(_err) return d - def _decrypt(self, crypttext, IV): + def _decrypt(self, crypttext, IV, seqnum, root_hash): key = hashutil.ssk_readkey_data_hash(IV, self._readkey) decryptor = AES.new(key=key, mode=AES.MODE_CTR, counterstart="\x00"*16) plaintext = decryptor.decrypt(crypttext) + # it worked, so record the seqnum and root_hash for next time + self._node._populate_seqnum(seqnum) + self._node._populate_root_hash(root_hash) return plaintext def _done(self, contents): @@ -608,29 +614,159 @@ class Publish: old_roothash = self._node._current_roothash old_seqnum = self._node._current_seqnum + assert old_seqnum is not None, "must read before replace" + self._new_seqnum = old_seqnum + 1 + # read-before-replace also guarantees these fields are available readkey = self._node.get_readkey() required_shares = self._node.get_required_shares() total_shares = self._node.get_total_shares() - privkey = self._node.get_privkey() - encprivkey = self._node.get_encprivkey() - pubkey = self._node.get_pubkey() + self._pubkey = self._node.get_pubkey() + + # these two may not be, we might have to get them from the first peer + self._privkey = self._node.get_privkey() + self._encprivkey = self._node.get_encprivkey() IV = os.urandom(16) - d = defer.succeed(newdata) - d.addCallback(self._encrypt_and_encode, readkey, IV, - required_shares, total_shares) - d.addCallback(self._generate_shares, old_seqnum+1, - privkey, encprivkey, pubkey) + d = defer.succeed(total_shares) + d.addCallback(self._query_peers) + + d.addCallback(self._encrypt_and_encode, newdata, readkey, IV, + required_shares, total_shares) + d.addCallback(self._generate_shares, self._new_seqnum, IV) - d.addCallback(self._query_peers, total_shares) d.addCallback(self._send_shares, IV) d.addCallback(self._maybe_recover) d.addCallback(lambda res: None) return d - def _encrypt_and_encode(self, newdata, readkey, IV, + def _query_peers(self, total_shares): + self.log("_query_peers") + + storage_index = self._node.get_storage_index() + peerlist = self._node._client.get_permuted_peers(storage_index, + include_myself=False) + # we don't include ourselves in the N peers, but we *do* push an + # extra copy of share[0] to ourselves so we're more likely to have + # the signing key around later. This way, even if all the servers die + # and the directory contents are unrecoverable, at least we can still + # push out a new copy with brand-new contents. + # TODO: actually push this copy + + current_share_peers = DictOfSets() + reachable_peers = {} + + EPSILON = total_shares / 2 + partial_peerlist = islice(peerlist, total_shares + EPSILON) + peer_storage_servers = {} + dl = [] + for (permutedid, peerid, conn) in partial_peerlist: + d = self._do_query(conn, peerid, peer_storage_servers, + storage_index) + d.addCallback(self._got_query_results, + peerid, permutedid, + reachable_peers, current_share_peers) + dl.append(d) + d = defer.DeferredList(dl) + d.addCallback(self._got_all_query_results, + total_shares, reachable_peers, + current_share_peers, peer_storage_servers) + # TODO: add an errback to, probably to ignore that peer + return d + + def _do_query(self, conn, peerid, peer_storage_servers, storage_index): + d = conn.callRemote("get_service", "storageserver") + def _got_storageserver(ss): + peer_storage_servers[peerid] = ss + # TODO: only read 2KB, since all we really need is the seqnum + # info. But we need to read more from at least one peer so we can + # grab the encrypted privkey. Really, read just the 2k, and if + # the first response suggests that the privkey is beyond that + # segment, send out another query to the same peer for the + # privkey segment. + return ss.callRemote("slot_readv", storage_index, [], [(0, 2500)]) + d.addCallback(_got_storageserver) + return d + + def _got_query_results(self, datavs, peerid, permutedid, + reachable_peers, current_share_peers): + self.log("_got_query_results") + + assert isinstance(datavs, dict) + reachable_peers[peerid] = permutedid + for shnum, datav in datavs.items(): + assert len(datav) == 1 + data = datav[0] + r = unpack_share(data) + (seqnum, root_hash, IV, k, N, segsize, datalen, + pubkey, signature, share_hash_chain, block_hash_tree, + share_data, enc_privkey) = r + share = (shnum, seqnum, root_hash) + current_share_peers.add(shnum, (peerid, seqnum, root_hash) ) + if not self._encprivkey: + self._encprivkey = enc_privkey + self._node._populate_encprivkey(self._encprivkey) + if not self._privkey: + privkey_s = self._node._decrypt_privkey(enc_privkey) + self._privkey = rsa.create_signing_key_from_string(privkey_s) + self._node._populate_privkey(self._privkey) + # TODO: make sure we actually fill these in before we try to + # upload. This means we may need to re-fetch something if our + # initial read was too short. + + def _got_all_query_results(self, res, + total_shares, reachable_peers, + current_share_peers, peer_storage_servers): + self.log("_got_all_query_results") + # now that we know everything about the shares currently out there, + # decide where to place the new shares. + + # if an old share X is on a node, put the new share X there too. + # TODO: 1: redistribute shares to achieve one-per-peer, by copying + # shares from existing peers to new (less-crowded) ones. The + # old shares must still be updated. + # TODO: 2: move those shares instead of copying them, to reduce future + # update work + + shares_needing_homes = range(total_shares) + target_map = DictOfSets() # maps shnum to set((peerid,oldseqnum,oldR)) + shares_per_peer = DictOfSets() + for shnum in range(total_shares): + for oldplace in current_share_peers.get(shnum, []): + (peerid, seqnum, R) = oldplace + if seqnum >= self._new_seqnum: + raise UncoordinatedWriteError() + target_map.add(shnum, oldplace) + shares_per_peer.add(peerid, shnum) + if shnum in shares_needing_homes: + shares_needing_homes.remove(shnum) + + # now choose homes for the remaining shares. We prefer peers with the + # fewest target shares, then peers with the lowest permuted index. If + # there are no shares already in place, this will assign them + # one-per-peer in the normal permuted order. + while shares_needing_homes: + if not reachable_peers: + raise NotEnoughPeersError("ran out of peers during upload") + shnum = shares_needing_homes.pop(0) + possible_homes = reachable_peers.keys() + possible_homes.sort(lambda a,b: + cmp( (len(shares_per_peer.get(a, [])), + reachable_peers[a]), + (len(shares_per_peer.get(b, [])), + reachable_peers[b]) )) + target_peerid = possible_homes[0] + target_map.add(shnum, (target_peerid, None, None) ) + shares_per_peer.add(target_peerid, shnum) + + assert not shares_needing_homes + + target_info = (target_map, peer_storage_servers) + return target_info + + def _encrypt_and_encode(self, target_info, + newdata, readkey, IV, required_shares, total_shares): self.log("_encrypt_and_encode") @@ -659,17 +795,25 @@ class Publish: assert len(piece) == piece_size d = fec.encode(crypttext_pieces) - d.addCallback(lambda shares: - (shares, required_shares, total_shares, - segment_size, len(crypttext), IV) ) + d.addCallback(lambda shares_and_shareids: + (shares_and_shareids, + required_shares, total_shares, + segment_size, len(crypttext), + target_info) ) return d def _generate_shares(self, (shares_and_shareids, required_shares, total_shares, - segment_size, data_length, IV), - seqnum, privkey, encprivkey, pubkey): + segment_size, data_length, + target_info), + seqnum, IV): self.log("_generate_shares") + # we should know these by now + privkey = self._privkey + encprivkey = self._encprivkey + pubkey = self._pubkey + (shares, share_ids) = shares_and_shareids assert len(shares) == len(share_ids) @@ -737,118 +881,10 @@ class Publish: block_hash_tree_s, share_data, encprivkey]) - return (seqnum, root_hash, final_shares) + return (seqnum, root_hash, final_shares, target_info) - def _query_peers(self, (seqnum, root_hash, final_shares), total_shares): - self.log("_query_peers") - - self._new_seqnum = seqnum - self._new_root_hash = root_hash - self._new_shares = final_shares - - storage_index = self._node.get_storage_index() - peerlist = self._node._client.get_permuted_peers(storage_index, - include_myself=False) - # we don't include ourselves in the N peers, but we *do* push an - # extra copy of share[0] to ourselves so we're more likely to have - # the signing key around later. This way, even if all the servers die - # and the directory contents are unrecoverable, at least we can still - # push out a new copy with brand-new contents. - # TODO: actually push this copy - - current_share_peers = DictOfSets() - reachable_peers = {} - - EPSILON = total_shares / 2 - partial_peerlist = islice(peerlist, total_shares + EPSILON) - peer_storage_servers = {} - dl = [] - for (permutedid, peerid, conn) in partial_peerlist: - d = self._do_query(conn, peerid, peer_storage_servers, - storage_index) - d.addCallback(self._got_query_results, - peerid, permutedid, - reachable_peers, current_share_peers) - dl.append(d) - d = defer.DeferredList(dl) - d.addCallback(self._got_all_query_results, - total_shares, reachable_peers, seqnum, - current_share_peers, peer_storage_servers) - # TODO: add an errback to, probably to ignore that peer - return d - - def _do_query(self, conn, peerid, peer_storage_servers, storage_index): - d = conn.callRemote("get_service", "storageserver") - def _got_storageserver(ss): - peer_storage_servers[peerid] = ss - return ss.callRemote("slot_readv", storage_index, [], [(0, 2000)]) - d.addCallback(_got_storageserver) - return d - - def _got_query_results(self, datavs, peerid, permutedid, - reachable_peers, current_share_peers): - self.log("_got_query_results") - - assert isinstance(datavs, dict) - reachable_peers[peerid] = permutedid - for shnum, datav in datavs.items(): - assert len(datav) == 1 - data = datav[0] - r = unpack_share(data) - share = (shnum, r[0], r[1]) # shnum,seqnum,R - current_share_peers[shnum].add( (peerid, r[0], r[1]) ) - - def _got_all_query_results(self, res, - total_shares, reachable_peers, new_seqnum, - current_share_peers, peer_storage_servers): - self.log("_got_all_query_results") - # now that we know everything about the shares currently out there, - # decide where to place the new shares. - - # if an old share X is on a node, put the new share X there too. - # TODO: 1: redistribute shares to achieve one-per-peer, by copying - # shares from existing peers to new (less-crowded) ones. The - # old shares must still be updated. - # TODO: 2: move those shares instead of copying them, to reduce future - # update work - - shares_needing_homes = range(total_shares) - target_map = DictOfSets() # maps shnum to set((peerid,oldseqnum,oldR)) - shares_per_peer = DictOfSets() - for shnum in range(total_shares): - for oldplace in current_share_peers.get(shnum, []): - (peerid, seqnum, R) = oldplace - if seqnum >= new_seqnum: - raise UncoordinatedWriteError() - target_map.add(shnum, oldplace) - shares_per_peer.add(peerid, shnum) - if shnum in shares_needing_homes: - shares_needing_homes.remove(shnum) - - # now choose homes for the remaining shares. We prefer peers with the - # fewest target shares, then peers with the lowest permuted index. If - # there are no shares already in place, this will assign them - # one-per-peer in the normal permuted order. - while shares_needing_homes: - if not reachable_peers: - raise NotEnoughPeersError("ran out of peers during upload") - shnum = shares_needing_homes.pop(0) - possible_homes = reachable_peers.keys() - possible_homes.sort(lambda a,b: - cmp( (len(shares_per_peer.get(a, [])), - reachable_peers[a]), - (len(shares_per_peer.get(b, [])), - reachable_peers[b]) )) - target_peerid = possible_homes[0] - target_map.add(shnum, (target_peerid, None, None) ) - shares_per_peer.add(target_peerid, shnum) - - assert not shares_needing_homes - - return (target_map, peer_storage_servers) - - def _send_shares(self, (target_map, peer_storage_servers), IV ): + def _send_shares(self, (seqnum, root_hash, final_shares, target_info), IV): self.log("_send_shares") # we're finally ready to send out our shares. If we encounter any # surprises here, it's because somebody else is writing at the same @@ -857,15 +893,16 @@ class Publish: # surprises here are *not* indications of UncoordinatedWriteError, # and we'll need to respond to them more gracefully. - my_checkstring = pack_checkstring(self._new_seqnum, - self._new_root_hash, IV) + target_map, peer_storage_servers = target_info + + my_checkstring = pack_checkstring(seqnum, root_hash, IV) peer_messages = {} expected_old_shares = {} for shnum, peers in target_map.items(): for (peerid, old_seqnum, old_root_hash) in peers: testv = [(0, len(my_checkstring), "le", my_checkstring)] - new_share = self._new_shares[shnum] + new_share = final_shares[shnum] writev = [(0, new_share)] if peerid not in peer_messages: peer_messages[peerid] = {} @@ -982,6 +1019,14 @@ class MutableFileNode: self._readkey = self._uri.readkey self._storage_index = self._uri.storage_index self._fingerprint = self._uri.fingerprint + # the following values are learned during Retrieval + # self._pubkey + # self._required_shares + # self._total_shares + # and these are needed for Publish. They are filled in by Retrieval + # if possible, otherwise by the first peer that Publish talks to. + self._privkey = None + self._encprivkey = None return self def create(self, initial_contents): @@ -1028,6 +1073,34 @@ class MutableFileNode: crypttext = enc.encrypt(privkey) return crypttext + def _decrypt_privkey(self, enc_privkey): + enc = AES.new(key=self._writekey, mode=AES.MODE_CTR, counterstart="\x00"*16) + privkey = enc.decrypt(enc_privkey) + return privkey + + def _populate(self, stuff): + # the Retrieval object calls this with values it discovers when + # downloading the slot. This is how a MutableFileNode that was + # created from a URI learns about its full key. + pass + + def _populate_pubkey(self, pubkey): + self._pubkey = pubkey + def _populate_required_shares(self, required_shares): + self._required_shares = required_shares + def _populate_total_shares(self, total_shares): + self._total_shares = total_shares + def _populate_seqnum(self, seqnum): + self._current_seqnum = seqnum + def _populate_root_hash(self, root_hash): + self._current_roothash = root_hash + + def _populate_privkey(self, privkey): + self._privkey = privkey + def _populate_encprivkey(self, encprivkey): + self._encprivkey = encprivkey + + def get_write_enabler(self, peerid): assert len(peerid) == 20 return hashutil.ssk_write_enabler_hash(self._writekey, peerid) @@ -1093,4 +1166,7 @@ class MutableFileNode: return r.retrieve() def replace(self, newdata): - return defer.succeed(None) + r = Retrieve(self) + d = r.retrieve() + d.addCallback(lambda res: self._publish(newdata)) + return d diff --git a/src/allmydata/test/test_mutable.py b/src/allmydata/test/test_mutable.py index 25b3e28fb..42a01a660 100644 --- a/src/allmydata/test/test_mutable.py +++ b/src/allmydata/test/test_mutable.py @@ -387,6 +387,7 @@ class Publish(unittest.TestCase): d.addCallback(_done) return d +del Publish # gotta run, will fix this in a few hours class FakePubKey: def __init__(self, count): diff --git a/src/allmydata/test/test_system.py b/src/allmydata/test/test_system.py index 07ac4ee2f..b784bc689 100644 --- a/src/allmydata/test/test_system.py +++ b/src/allmydata/test/test_system.py @@ -242,6 +242,9 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase): def test_mutable(self): self.basedir = "system/SystemTest/test_mutable" DATA = "initial contents go here." # 25 bytes % 3 != 0 + NEWDATA = "new contents yay" + NEWERDATA = "this is getting old" + d = self.set_up_nodes() def _create_mutable(res): @@ -255,7 +258,7 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase): self._mutable_node_1 = res uri = res.get_uri() #print "DONE", uri - d1.addBoth(_done) + d1.addCallback(_done) return d1 d.addCallback(_create_mutable) @@ -299,11 +302,11 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase): m = re.search(r'^ container_size: (\d+)$', output, re.M) self.failUnless(m) container_size = int(m.group(1)) - self.failUnless(2046 <= container_size <= 2049) + self.failUnless(2046 <= container_size <= 2049, container_size) m = re.search(r'^ data_length: (\d+)$', output, re.M) self.failUnless(m) data_length = int(m.group(1)) - self.failUnless(2046 <= data_length <= 2049) + self.failUnless(2046 <= data_length <= 2049, data_length) self.failUnless(" secrets are for nodeid: %s\n" % peerid in output) self.failUnless(" SDMF contents:\n" in output) @@ -351,14 +354,39 @@ class SystemTest(testutil.SignalMixin, unittest.TestCase): #print "starting download 3" uri = self._mutable_node_1.get_uri() newnode = self.clients[1].create_mutable_file_from_uri(uri) - return newnode.download_to_data() + d1 = newnode.download_to_data() + d1.addCallback(lambda res: (res, newnode)) + return d1 d.addCallback(_check_download_2) - def _check_download_3(res): + def _check_download_3((res, newnode)): #print "_check_download_3" self.failUnlessEqual(res, DATA) + # replace the data + #print "REPLACING" + d1 = newnode.replace(NEWDATA) + d1.addCallback(lambda res: newnode.download_to_data()) + return d1 d.addCallback(_check_download_3) + def _check_download_4(res): + print "_check_download_4" + self.failUnlessEqual(res, NEWDATA) + # now create an even newer node and replace the data on it. This + # new node has never been used for download before. + uri = self._mutable_node_1.get_uri() + newnode1 = self.clients[2].create_mutable_file_from_uri(uri) + newnode2 = self.clients[3].create_mutable_file_from_uri(uri) + d1 = newnode1.replace(NEWERDATA) + d1.addCallback(lambda res: newnode2.download_to_data()) + return d1 + #d.addCallback(_check_download_4) + + def _check_download_5(res): + print "_check_download_5" + self.failUnlessEqual(res, NEWERDATA) + #d.addCallback(_check_download_5) + return d def flip_bit(self, good):