Factor some SSK "signature" key handling code into a more reusable shape

This gives the test suite access to the derivation function so it can
re-derive certain values to use as expected results to compare against actual
results.
This commit is contained in:
Jean-Paul Calderone 2023-01-03 10:38:18 -05:00
parent 5bad92cfc5
commit c7bb190290
4 changed files with 53 additions and 27 deletions

View File

@ -10,6 +10,9 @@ MODE_WRITE = "MODE_WRITE" # replace all shares, probably.. not for initial
MODE_READ = "MODE_READ"
MODE_REPAIR = "MODE_REPAIR" # query all peers, get the privkey
from allmydata.crypto import aes, rsa
from allmydata.util import hashutil
class NotWriteableError(Exception):
pass
@ -61,3 +64,33 @@ class CorruptShareError(BadShareError):
class UnknownVersionError(BadShareError):
"""The share we received was of a version we don't recognize."""
def encrypt_privkey(writekey: bytes, privkey: rsa.PrivateKey) -> bytes:
"""
For SSK, encrypt a private ("signature") key using the writekey.
"""
encryptor = aes.create_encryptor(writekey)
crypttext = aes.encrypt_data(encryptor, privkey)
return crypttext
def decrypt_privkey(writekey: bytes, enc_privkey: bytes) -> rsa.PrivateKey:
"""
The inverse of ``encrypt_privkey``.
"""
decryptor = aes.create_decryptor(writekey)
privkey = aes.decrypt_data(decryptor, enc_privkey)
return privkey
def derive_mutable_keys(keypair: tuple[rsa.PublicKey, rsa.PrivateKey]) -> tuple[bytes, bytes, bytes]:
"""
Derive the SSK writekey, encrypted writekey, and fingerprint from the
public/private ("verification" / "signature") keypair.
"""
pubkey, privkey = keypair
pubkey_s = rsa.der_string_from_verifying_key(pubkey)
privkey_s = rsa.der_string_from_signing_key(privkey)
writekey = hashutil.ssk_writekey_hash(privkey_s)
encprivkey = encrypt_privkey(writekey, privkey_s)
fingerprint = hashutil.ssk_pubkey_fingerprint_hash(pubkey_s)
return writekey, encprivkey, fingerprint

View File

@ -9,8 +9,6 @@ from zope.interface import implementer
from twisted.internet import defer, reactor
from foolscap.api import eventually
from allmydata.crypto import aes
from allmydata.crypto import rsa
from allmydata.interfaces import IMutableFileNode, ICheckable, ICheckResults, \
NotEnoughSharesError, MDMF_VERSION, SDMF_VERSION, IMutableUploadable, \
IMutableFileVersion, IWriteable
@ -21,8 +19,14 @@ from allmydata.uri import WriteableSSKFileURI, ReadonlySSKFileURI, \
from allmydata.monitor import Monitor
from allmydata.mutable.publish import Publish, MutableData,\
TransformingUploadable
from allmydata.mutable.common import MODE_READ, MODE_WRITE, MODE_CHECK, UnrecoverableFileError, \
UncoordinatedWriteError
from allmydata.mutable.common import (
MODE_READ,
MODE_WRITE,
MODE_CHECK,
UnrecoverableFileError,
UncoordinatedWriteError,
derive_mutable_keys,
)
from allmydata.mutable.servermap import ServerMap, ServermapUpdater
from allmydata.mutable.retrieve import Retrieve
from allmydata.mutable.checker import MutableChecker, MutableCheckAndRepairer
@ -132,13 +136,10 @@ class MutableFileNode(object):
Deferred that fires (with the MutableFileNode instance you should
use) when it completes.
"""
(pubkey, privkey) = keypair
self._pubkey, self._privkey = pubkey, privkey
pubkey_s = rsa.der_string_from_verifying_key(self._pubkey)
privkey_s = rsa.der_string_from_signing_key(self._privkey)
self._writekey = hashutil.ssk_writekey_hash(privkey_s)
self._encprivkey = self._encrypt_privkey(self._writekey, privkey_s)
self._fingerprint = hashutil.ssk_pubkey_fingerprint_hash(pubkey_s)
self._pubkey, self._privkey = keypair
self._writekey, self._encprivkey, self._fingerprint = derive_mutable_keys(
keypair,
)
if version == MDMF_VERSION:
self._uri = WriteableMDMFFileURI(self._writekey, self._fingerprint)
self._protocol_version = version
@ -164,16 +165,6 @@ class MutableFileNode(object):
(contents, type(contents))
return contents(self)
def _encrypt_privkey(self, writekey, privkey):
encryptor = aes.create_encryptor(writekey)
crypttext = aes.encrypt_data(encryptor, privkey)
return crypttext
def _decrypt_privkey(self, enc_privkey):
decryptor = aes.create_decryptor(self._writekey)
privkey = aes.decrypt_data(decryptor, enc_privkey)
return privkey
def _populate_pubkey(self, pubkey):
self._pubkey = pubkey
def _populate_required_shares(self, required_shares):

View File

@ -24,7 +24,7 @@ from allmydata import hashtree, codec
from allmydata.storage.server import si_b2a
from allmydata.mutable.common import CorruptShareError, BadShareError, \
UncoordinatedWriteError
UncoordinatedWriteError, decrypt_privkey
from allmydata.mutable.layout import MDMFSlotReadProxy
@implementer(IRetrieveStatus)
@ -923,9 +923,10 @@ class Retrieve(object):
def _try_to_validate_privkey(self, enc_privkey, reader, server):
alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
node_writekey = self._node.get_writekey()
alleged_privkey_s = decrypt_privkey(node_writekey, enc_privkey)
alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
if alleged_writekey != self._node.get_writekey():
if alleged_writekey != node_writekey:
self.log("invalid privkey from %s shnum %d" %
(reader, reader.shnum),
level=log.WEIRD, umid="YIw4tA")

View File

@ -21,7 +21,7 @@ from allmydata.storage.server import si_b2a
from allmydata.interfaces import IServermapUpdaterStatus
from allmydata.mutable.common import MODE_CHECK, MODE_ANYTHING, MODE_WRITE, \
MODE_READ, MODE_REPAIR, CorruptShareError
MODE_READ, MODE_REPAIR, CorruptShareError, decrypt_privkey
from allmydata.mutable.layout import SIGNED_PREFIX_LENGTH, MDMFSlotReadProxy
@implementer(IServermapUpdaterStatus)
@ -943,9 +943,10 @@ class ServermapUpdater(object):
writekey stored in my node. If it is valid, then I set the
privkey and encprivkey properties of the node.
"""
alleged_privkey_s = self._node._decrypt_privkey(enc_privkey)
node_writekey = self._node.get_writekey()
alleged_privkey_s = decrypt_privkey(node_writekey, enc_privkey)
alleged_writekey = hashutil.ssk_writekey_hash(alleged_privkey_s)
if alleged_writekey != self._node.get_writekey():
if alleged_writekey != node_writekey:
self.log("invalid privkey from %r shnum %d" %
(server.get_name(), shnum),
parent=lp, level=log.WEIRD, umid="aJVccw")