diff --git a/src/allmydata/mutable/common.py b/src/allmydata/mutable/common.py index a2e482d3c..33a1c2731 100644 --- a/src/allmydata/mutable/common.py +++ b/src/allmydata/mutable/common.py @@ -10,6 +10,9 @@ MODE_WRITE = "MODE_WRITE" # replace all shares, probably.. not for initial MODE_READ = "MODE_READ" MODE_REPAIR = "MODE_REPAIR" # query all peers, get the privkey +from allmydata.crypto import aes, rsa +from allmydata.util import hashutil + class NotWriteableError(Exception): pass @@ -61,3 +64,33 @@ class CorruptShareError(BadShareError): class UnknownVersionError(BadShareError): """The share we received was of a version we don't recognize.""" + + +def encrypt_privkey(writekey: bytes, privkey: rsa.PrivateKey) -> bytes: + """ + For SSK, encrypt a private ("signature") key using the writekey. + """ + encryptor = aes.create_encryptor(writekey) + crypttext = aes.encrypt_data(encryptor, privkey) + return crypttext + +def decrypt_privkey(writekey: bytes, enc_privkey: bytes) -> rsa.PrivateKey: + """ + The inverse of ``encrypt_privkey``. + """ + decryptor = aes.create_decryptor(writekey) + privkey = aes.decrypt_data(decryptor, enc_privkey) + return privkey + +def derive_mutable_keys(keypair: tuple[rsa.PublicKey, rsa.PrivateKey]) -> tuple[bytes, bytes, bytes]: + """ + Derive the SSK writekey, encrypted writekey, and fingerprint from the + public/private ("verification" / "signature") keypair. + """ + pubkey, privkey = keypair + pubkey_s = rsa.der_string_from_verifying_key(pubkey) + privkey_s = rsa.der_string_from_signing_key(privkey) + writekey = hashutil.ssk_writekey_hash(privkey_s) + encprivkey = encrypt_privkey(writekey, privkey_s) + fingerprint = hashutil.ssk_pubkey_fingerprint_hash(pubkey_s) + return writekey, encprivkey, fingerprint diff --git a/src/allmydata/mutable/filenode.py b/src/allmydata/mutable/filenode.py index 99fdcc085..00b31c52b 100644 --- a/src/allmydata/mutable/filenode.py +++ b/src/allmydata/mutable/filenode.py @@ -9,8 +9,6 @@ from zope.interface import implementer from twisted.internet import defer, reactor from foolscap.api import eventually -from allmydata.crypto import aes -from allmydata.crypto import rsa from allmydata.interfaces import IMutableFileNode, ICheckable, ICheckResults, \ NotEnoughSharesError, MDMF_VERSION, SDMF_VERSION, IMutableUploadable, \ IMutableFileVersion, IWriteable @@ -21,8 +19,14 @@ from allmydata.uri import WriteableSSKFileURI, ReadonlySSKFileURI, \ from allmydata.monitor import Monitor from allmydata.mutable.publish import Publish, MutableData,\ TransformingUploadable -from allmydata.mutable.common import MODE_READ, MODE_WRITE, MODE_CHECK, UnrecoverableFileError, \ - UncoordinatedWriteError +from allmydata.mutable.common import ( + MODE_READ, + MODE_WRITE, + MODE_CHECK, + UnrecoverableFileError, + UncoordinatedWriteError, + derive_mutable_keys, +) from allmydata.mutable.servermap import ServerMap, ServermapUpdater from allmydata.mutable.retrieve import Retrieve from allmydata.mutable.checker import MutableChecker, MutableCheckAndRepairer @@ -132,13 +136,10 @@ class MutableFileNode(object): Deferred that fires (with the MutableFileNode instance you should use) when it completes. """ - (pubkey, privkey) = keypair - self._pubkey, self._privkey = pubkey, privkey - pubkey_s = rsa.der_string_from_verifying_key(self._pubkey) - privkey_s = rsa.der_string_from_signing_key(self._privkey) - self._writekey = hashutil.ssk_writekey_hash(privkey_s) - self._encprivkey = self._encrypt_privkey(self._writekey, privkey_s) - self._fingerprint = hashutil.ssk_pubkey_fingerprint_hash(pubkey_s) + self._pubkey, self._privkey = keypair + self._writekey, self._encprivkey, self._fingerprint = derive_mutable_keys( + keypair, + ) if version == MDMF_VERSION: self._uri = WriteableMDMFFileURI(self._writekey, self._fingerprint) self._protocol_version = version @@ -164,16 +165,6 @@ class MutableFileNode(object): (contents, type(contents)) return contents(self) - def _encrypt_privkey(self, writekey, privkey): - encryptor = aes.create_encryptor(writekey) - crypttext = aes.encrypt_data(encryptor, privkey) - return crypttext - - def _decrypt_privkey(self, enc_privkey): - decryptor = aes.create_decryptor(self._writekey) - privkey = aes.decrypt_data(decryptor, enc_privkey) - return privkey - def _populate_pubkey(self, pubkey): self._pubkey = pubkey def _populate_required_shares(self, required_shares): diff --git a/src/allmydata/mutable/retrieve.py b/src/allmydata/mutable/retrieve.py index efb2c0f85..64573a49a 100644 --- a/src/allmydata/mutable/retrieve.py +++ b/src/allmydata/mutable/retrieve.py @@ -24,7 +24,7 @@ from allmydata import hashtree, codec from allmydata.storage.server import si_b2a from allmydata.mutable.common import CorruptShareError, BadShareError, \ - UncoordinatedWriteError + UncoordinatedWriteError, decrypt_privkey from allmydata.mutable.layout import MDMFSlotReadProxy @implementer(IRetrieveStatus) @@ -923,9 +923,10 @@ class Retrieve(object): def _try_to_validate_privkey(self, enc_privkey, reader, server): - alleged_privkey_s = self._node._decrypt_privkey(enc_privkey) + node_writekey = self._node.get_writekey() + alleged_privkey_s = decrypt_privkey(node_writekey, enc_privkey) alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s) - if alleged_writekey != self._node.get_writekey(): + if alleged_writekey != node_writekey: self.log("invalid privkey from %s shnum %d" % (reader, reader.shnum), level=log.WEIRD, umid="YIw4tA") diff --git a/src/allmydata/mutable/servermap.py b/src/allmydata/mutable/servermap.py index cd220ce0f..99aa85d24 100644 --- a/src/allmydata/mutable/servermap.py +++ b/src/allmydata/mutable/servermap.py @@ -21,7 +21,7 @@ from allmydata.storage.server import si_b2a from allmydata.interfaces import IServermapUpdaterStatus from allmydata.mutable.common import MODE_CHECK, MODE_ANYTHING, MODE_WRITE, \ - MODE_READ, MODE_REPAIR, CorruptShareError + MODE_READ, MODE_REPAIR, CorruptShareError, decrypt_privkey from allmydata.mutable.layout import SIGNED_PREFIX_LENGTH, MDMFSlotReadProxy @implementer(IServermapUpdaterStatus) @@ -943,9 +943,10 @@ class ServermapUpdater(object): writekey stored in my node. If it is valid, then I set the privkey and encprivkey properties of the node. """ - alleged_privkey_s = self._node._decrypt_privkey(enc_privkey) + node_writekey = self._node.get_writekey() + alleged_privkey_s = decrypt_privkey(node_writekey, enc_privkey) alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s) - if alleged_writekey != self._node.get_writekey(): + if alleged_writekey != node_writekey: self.log("invalid privkey from %r shnum %d" % (server.get_name(), shnum), parent=lp, level=log.WEIRD, umid="aJVccw")