move AES to a helper-function style

This commit is contained in:
meejah
2019-06-17 15:54:46 -06:00
parent 47ccdb0177
commit 310fb60247
10 changed files with 138 additions and 76 deletions

View File

@ -1,32 +1,88 @@
import six import six
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import (
Cipher,
algorithms,
modes,
CipherContext,
)
class AES:
__DEFAULT_IV = '\x00' * 16 DEFAULT_IV = '\x00' * 16
def __init__(self, key, iv=None):
# validate the key
if not isinstance(key, six.binary_type):
raise TypeError('Key was not bytes')
if len(key) not in (16, 32):
raise ValueError('Key was not 16 or 32 bytes long')
# validate the IV def create_encryptor(key, iv=None):
if iv is None: """
iv = self.__DEFAULT_IV Create and return a new object which can do AES encryptions with
if not isinstance(iv, six.binary_type): the given key and initialization vector (IV). The default IV is 16
raise TypeError('IV was not bytes') zero-bytes.
if len(iv) != 16:
raise ValueError('IV was not 16 bytes long')
self._cipher = Cipher(algorithms.AES(key), modes.CTR(iv), backend=default_backend()) The returned object is suitable for use with `encrypt_data`
self._encryptor = self._cipher.encryptor() """
key = _validate_key(key)
iv = _validate_iv(iv)
cipher = Cipher(
algorithms.AES(key),
modes.CTR(iv),
backend=default_backend()
)
return cipher.encryptor()
def process(self, plaintext):
if not isinstance(plaintext, six.binary_type):
raise TypeError('Plaintext was not bytes')
return self._encryptor.update(plaintext) def encrypt_data(encryptor, plaintext):
"""
AES-encrypt `plaintext` with the given `encryptor`.
:param encryptor: an instance previously returned from `create_encryptor`
:param bytes plaintext: the data to encrypt
:returns: ciphertext
"""
_validate_encryptor(encryptor)
if not isinstance(plaintext, six.binary_type):
raise ValueError('Plaintext was not bytes')
return encryptor.update(plaintext)
create_decryptor = create_encryptor
decrypt_data = encrypt_data
def _validate_encryptor(encryptor):
"""
raise ValueError for `encryptor` is not a valid object
"""
if not isinstance(encryptor, CipherContext):
raise ValueError(
"'encryptor' must be a CipherContext"
)
def _validate_key(key):
"""
confirm `key` is suitable for AES encryption, or raise ValueError
"""
if not isinstance(key, six.binary_type):
raise TypeError('Key was not bytes')
if len(key) not in (16, 32):
raise ValueError('Key was not 16 or 32 bytes long')
return key
def _validate_iv(iv):
"""
confirm `iv` is a suitable initialization vector
"""
if iv is None:
return DEFAULT_IV
if not isinstance(iv, six.binary_type):
raise TypeError('IV was not bytes')
if len(iv) != 16:
raise ValueError('IV was not 16 bytes long')
return iv

View File

@ -6,7 +6,7 @@ from twisted.internet import defer
from foolscap.api import fireEventually from foolscap.api import fireEventually
import json import json
from allmydata.crypto.aes import AES from allmydata.crypto import aes
from allmydata.deep_stats import DeepStats from allmydata.deep_stats import DeepStats
from allmydata.mutable.common import NotWriteableError from allmydata.mutable.common import NotWriteableError
from allmydata.mutable.filenode import MutableFileNode from allmydata.mutable.filenode import MutableFileNode
@ -214,8 +214,8 @@ def _encrypt_rw_uri(writekey, rw_uri):
salt = hashutil.mutable_rwcap_salt_hash(rw_uri) salt = hashutil.mutable_rwcap_salt_hash(rw_uri)
key = hashutil.mutable_rwcap_key_hash(salt, writekey) key = hashutil.mutable_rwcap_key_hash(salt, writekey)
cryptor = AES(key) encryptor = aes.create_encryptor(key)
crypttext = cryptor.process(rw_uri) crypttext = aes.encrypt_data(encryptor, rw_uri)
mac = hashutil.hmac(key, salt + crypttext) mac = hashutil.hmac(key, salt + crypttext)
assert len(mac) == 32 assert len(mac) == 32
return salt + crypttext + mac return salt + crypttext + mac
@ -331,8 +331,11 @@ class DirectoryNode(object):
salt = encwrcap[:16] salt = encwrcap[:16]
crypttext = encwrcap[16:-32] crypttext = encwrcap[16:-32]
key = hashutil.mutable_rwcap_key_hash(salt, self._node.get_writekey()) key = hashutil.mutable_rwcap_key_hash(salt, self._node.get_writekey())
cryptor = AES(key) # XXX uhm, so maybe this is confusing even if it's what the
plaintext = cryptor.process(crypttext) # original code was doing; worth making a aes.decrypt_data
# just to be more clear?
encryptor = aes.create_encryptor(key)
plaintext = aes.encrypt_data(encryptor, crypttext)
return plaintext return plaintext
def _create_and_validate_node(self, rw_uri, ro_uri, name): def _create_and_validate_node(self, rw_uri, ro_uri, name):

View File

@ -7,7 +7,7 @@ from twisted.internet import defer
from allmydata import uri from allmydata import uri
from twisted.internet.interfaces import IConsumer from twisted.internet.interfaces import IConsumer
from allmydata.crypto.aes import AES from allmydata.crypto import aes
from allmydata.interfaces import IImmutableFileNode, IUploadResults from allmydata.interfaces import IImmutableFileNode, IUploadResults
from allmydata.util import consumer from allmydata.util import consumer
from allmydata.check_results import CheckResults, CheckAndRepairResults from allmydata.check_results import CheckResults, CheckAndRepairResults
@ -201,8 +201,9 @@ class DecryptingConsumer(object):
offset_big = offset // 16 offset_big = offset // 16
offset_small = offset % 16 offset_small = offset % 16
iv = binascii.unhexlify("%032x" % offset_big) iv = binascii.unhexlify("%032x" % offset_big)
self._decryptor = AES(readkey, iv=iv) self._decryptor = aes.create_decryptor(readkey, iv)
self._decryptor.process("\x00"*offset_small) # this is just to advance the counter
aes.decrypt_data(self._decryptor, "\x00" * offset_small)
def set_download_status_read_event(self, read_ev): def set_download_status_read_event(self, read_ev):
self._read_ev = read_ev self._read_ev = read_ev
@ -219,7 +220,7 @@ class DecryptingConsumer(object):
self._consumer.unregisterProducer() self._consumer.unregisterProducer()
def write(self, ciphertext): def write(self, ciphertext):
started = now() started = now()
plaintext = self._decryptor.process(ciphertext) plaintext = aes.decrypt_data(self._decryptor, ciphertext)
if self._read_ev: if self._read_ev:
elapsed = now() - started elapsed = now() - started
self._read_ev.update(0, elapsed, 0) self._read_ev.update(0, elapsed, 0)

View File

@ -5,7 +5,7 @@ from twisted.internet import defer
from twisted.application import service from twisted.application import service
from foolscap.api import Referenceable, Copyable, RemoteCopy, fireEventually from foolscap.api import Referenceable, Copyable, RemoteCopy, fireEventually
from allmydata.crypto.aes import AES from allmydata.crypto import aes
from allmydata.util.hashutil import file_renewal_secret_hash, \ from allmydata.util.hashutil import file_renewal_secret_hash, \
file_cancel_secret_hash, bucket_renewal_secret_hash, \ file_cancel_secret_hash, bucket_renewal_secret_hash, \
bucket_cancel_secret_hash, plaintext_hasher, \ bucket_cancel_secret_hash, plaintext_hasher, \
@ -946,8 +946,7 @@ class EncryptAnUploadable(object):
d = self.original.get_encryption_key() d = self.original.get_encryption_key()
def _got(key): def _got(key):
e = AES(key) self._encryptor = aes.create_encryptor(key)
self._encryptor = e
storage_index = storage_index_hash(key) storage_index = storage_index_hash(key)
assert isinstance(storage_index, str) assert isinstance(storage_index, str)
@ -957,7 +956,7 @@ class EncryptAnUploadable(object):
self._storage_index = storage_index self._storage_index = storage_index
if self._status: if self._status:
self._status.set_storage_index(storage_index) self._status.set_storage_index(storage_index)
return e return self._encryptor
d.addCallback(_got) d.addCallback(_got)
return d return d
@ -1067,8 +1066,8 @@ class EncryptAnUploadable(object):
# because the AES-CTR implementation doesn't offer a # because the AES-CTR implementation doesn't offer a
# way to change the counter value. Once it acquires # way to change the counter value. Once it acquires
# this ability, change this to simply update the counter # this ability, change this to simply update the counter
# before each call to (hash_only==False) _encryptor.process() # before each call to (hash_only==False) encrypt_data
ciphertext = self._encryptor.process(chunk) ciphertext = aes.encrypt_data(self._encryptor, chunk)
if hash_only: if hash_only:
self.log(" skipping encryption", level=log.NOISY) self.log(" skipping encryption", level=log.NOISY)
else: else:

View File

@ -4,7 +4,7 @@ from zope.interface import implementer
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from foolscap.api import eventually from foolscap.api import eventually
from allmydata.crypto.aes import AES from allmydata.crypto import aes
from allmydata.crypto import rsa from allmydata.crypto import rsa
from allmydata.interfaces import IMutableFileNode, ICheckable, ICheckResults, \ from allmydata.interfaces import IMutableFileNode, ICheckable, ICheckResults, \
NotEnoughSharesError, MDMF_VERSION, SDMF_VERSION, IMutableUploadable, \ NotEnoughSharesError, MDMF_VERSION, SDMF_VERSION, IMutableUploadable, \
@ -160,13 +160,13 @@ class MutableFileNode(object):
return contents(self) return contents(self)
def _encrypt_privkey(self, writekey, privkey): def _encrypt_privkey(self, writekey, privkey):
enc = AES(writekey) encryptor = aes.create_encryptor(writekey)
crypttext = enc.process(privkey) crypttext = aes.encrypt_data(encryptor, privkey)
return crypttext return crypttext
def _decrypt_privkey(self, enc_privkey): def _decrypt_privkey(self, enc_privkey):
enc = AES(self._writekey) encryptor = aes.create_encryptor(self._writekey)
privkey = enc.process(enc_privkey) privkey = aes.encrypt_data(encryptor, enc_privkey)
return privkey return privkey
def _populate_pubkey(self, pubkey): def _populate_pubkey(self, pubkey):

View File

@ -5,7 +5,7 @@ from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
from twisted.python import failure from twisted.python import failure
from allmydata.crypto.aes import AES from allmydata.crypto import aes
from allmydata.crypto import rsa from allmydata.crypto import rsa
from allmydata.interfaces import IPublishStatus, SDMF_VERSION, MDMF_VERSION, \ from allmydata.interfaces import IPublishStatus, SDMF_VERSION, MDMF_VERSION, \
IMutableUploadable IMutableUploadable
@ -712,8 +712,8 @@ class Publish(object):
key = hashutil.ssk_readkey_data_hash(salt, self.readkey) key = hashutil.ssk_readkey_data_hash(salt, self.readkey)
self._status.set_status("Encrypting") self._status.set_status("Encrypting")
enc = AES(key) encryptor = aes.create_encryptor(key)
crypttext = enc.process(data) crypttext = aes.encrypt_data(encryptor, data)
assert len(crypttext) == len(data) assert len(crypttext) == len(data)
now = time.time() now = time.time()

View File

@ -8,7 +8,7 @@ from twisted.internet.interfaces import IPushProducer, IConsumer
from foolscap.api import eventually, fireEventually, DeadReferenceError, \ from foolscap.api import eventually, fireEventually, DeadReferenceError, \
RemoteException RemoteException
from allmydata.crypto.aes import AES from allmydata.crypto import aes
from allmydata.crypto import rsa from allmydata.crypto import rsa
from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \ from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
DownloadStopped, MDMF_VERSION, SDMF_VERSION DownloadStopped, MDMF_VERSION, SDMF_VERSION
@ -899,8 +899,9 @@ class Retrieve(object):
self.log("decrypting segment %d" % self._current_segment) self.log("decrypting segment %d" % self._current_segment)
started = time.time() started = time.time()
key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey()) key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
decryptor = AES(key) # XXX make aes.* functions for decryption too .. even though its the same
plaintext = decryptor.process(segment) decryptor = aes.create_encryptor(key)
plaintext = aes.encrypt_data(decryptor, segment)
self._status.accumulate_decrypt_time(time.time() - started) self._status.accumulate_decrypt_time(time.time() - started)
return plaintext return plaintext

View File

@ -5,7 +5,7 @@ from base64 import b64decode
from binascii import a2b_hex, b2a_hex from binascii import a2b_hex, b2a_hex
from os import path from os import path
from allmydata.crypto.aes import AES from allmydata.crypto import aes
from allmydata.crypto import ed25519, rsa from allmydata.crypto import ed25519, rsa
RESOURCE_DIR = path.join(path.dirname(__file__), 'data') RESOURCE_DIR = path.join(path.dirname(__file__), 'data')
@ -52,33 +52,34 @@ class TestRegression(unittest.TestCase):
This was the old startup test run at import time in `pycryptopp.cipher.aes`. This was the old startup test run at import time in `pycryptopp.cipher.aes`.
""" """
enc0 = b"dc95c078a2408989ad48a21492842087530f8afbc74536b9a963b4f1c4cb738b" enc0 = b"dc95c078a2408989ad48a21492842087530f8afbc74536b9a963b4f1c4cb738b"
cryptor = AES(key=b"\x00" * 32) cryptor = aes.create_encryptor(key=b"\x00" * 32)
ct = cryptor.process(b"\x00" * 32) ct = aes.decrypt_data(cryptor, b"\x00" * 32)
self.failUnlessEqual(enc0, b2a_hex(ct)) self.failUnlessEqual(enc0, b2a_hex(ct))
cryptor = AES(key=b"\x00" * 32) cryptor = aes.create_encryptor(key=b"\x00" * 32)
ct1 = cryptor.process(b"\x00" * 15) ct1 = aes.decrypt_data(cryptor, b"\x00" * 15)
ct2 = cryptor.process(b"\x00" * 17) ct2 = aes.decrypt_data(cryptor, b"\x00" * 17)
self.failUnlessEqual(enc0, b2a_hex(ct1+ct2)) self.failUnlessEqual(enc0, b2a_hex(ct1+ct2))
enc0 = b"66e94bd4ef8a2c3b884cfa59ca342b2e" enc0 = b"66e94bd4ef8a2c3b884cfa59ca342b2e"
cryptor = AES(key=b"\x00" * 16) cryptor = aes.create_encryptor(key=b"\x00" * 16)
ct = cryptor.process(b"\x00" * 16) ct = aes.decrypt_data(cryptor, b"\x00" * 16)
self.failUnlessEqual(enc0, b2a_hex(ct)) self.failUnlessEqual(enc0, b2a_hex(ct))
cryptor = AES(key=b"\x00" * 16) cryptor = aes.create_encryptor(key=b"\x00" * 16)
ct1 = cryptor.process(b"\x00" * 8) ct1 = aes.decrypt_data(cryptor, b"\x00" * 8)
ct2 = cryptor.process(b"\x00" * 8) ct2 = aes.decrypt_data(cryptor, b"\x00" * 8)
self.failUnlessEqual(enc0, b2a_hex(ct1+ct2)) self.failUnlessEqual(enc0, b2a_hex(ct1+ct2))
def _test_from_Niels_AES(keysize, result): def _test_from_Niels_AES(keysize, result):
def fake_ecb_using_ctr(k, p): def fake_ecb_using_ctr(k, p):
return AES(key=k, iv=p).process(b'\x00' * 16) encryptor = aes.create_encryptor(key=k, iv=p)
return aes.encrypt_data(encryptor, b'\x00' * 16)
E = fake_ecb_using_ctr E = fake_ecb_using_ctr
b = 16 b = 16
k = keysize k = keysize
S = '\x00' * (k+b) S = '\x00' * (k + b)
for i in range(1000): for i in range(1000):
K = S[-k:] K = S[-k:]
@ -104,8 +105,8 @@ class TestRegression(unittest.TestCase):
plaintext = b'test' plaintext = b'test'
expected_ciphertext = b'\x7fEK\\' expected_ciphertext = b'\x7fEK\\'
aes = AES(self.AES_KEY) k = aes.create_encryptor(self.AES_KEY)
ciphertext = aes.process(plaintext) ciphertext = aes.decrypt_data(k, plaintext)
self.failUnlessEqual(ciphertext, expected_ciphertext) self.failUnlessEqual(ciphertext, expected_ciphertext)
@ -126,8 +127,8 @@ class TestRegression(unittest.TestCase):
b'\x1f\xa1|\xd2$E\xb5\xe7\x9d\xae\xd1\x1f)\xe4\xc7\x83\xb8\xd5|dHhU\xc8\x9a\xb1\x10\xed' b'\x1f\xa1|\xd2$E\xb5\xe7\x9d\xae\xd1\x1f)\xe4\xc7\x83\xb8\xd5|dHhU\xc8\x9a\xb1\x10\xed'
b'\xd1\xe7|\xd1') b'\xd1\xe7|\xd1')
aes = AES(self.AES_KEY) k = aes.create_encryptor(self.AES_KEY)
ciphertext = aes.process(plaintext) ciphertext = aes.decrypt_data(k, plaintext)
self.failUnlessEqual(ciphertext, expected_ciphertext) self.failUnlessEqual(ciphertext, expected_ciphertext)
@ -145,8 +146,8 @@ class TestRegression(unittest.TestCase):
plaintext = b'test' plaintext = b'test'
expected_ciphertext = b'\x82\x0e\rt' expected_ciphertext = b'\x82\x0e\rt'
aes = AES(self.AES_KEY, iv=self.IV) k = aes.create_encryptor(self.AES_KEY, iv=self.IV)
ciphertext = aes.process(plaintext) ciphertext = aes.decrypt_data(k, plaintext)
self.failUnlessEqual(ciphertext, expected_ciphertext) self.failUnlessEqual(ciphertext, expected_ciphertext)
@ -167,8 +168,8 @@ class TestRegression(unittest.TestCase):
b'\x97a\xdc\x100?\xf5L\x9f\xd9\xeeO\x98\xda\xf5g\x93\xa7q\xe1\xb1~\xf8\x1b\xe8[\\s' b'\x97a\xdc\x100?\xf5L\x9f\xd9\xeeO\x98\xda\xf5g\x93\xa7q\xe1\xb1~\xf8\x1b\xe8[\\s'
b'\x144$\x86\xeaC^f') b'\x144$\x86\xeaC^f')
aes = AES(self.AES_KEY, iv=self.IV) k = aes.create_encryptor(self.AES_KEY, iv=self.IV)
ciphertext = aes.process(plaintext) ciphertext = aes.decrypt_data(k, plaintext)
self.failUnlessEqual(ciphertext, expected_ciphertext) self.failUnlessEqual(ciphertext, expected_ciphertext)

View File

@ -5,7 +5,7 @@ from twisted.application import service
from foolscap.api import Tub, fireEventually, flushEventualQueue from foolscap.api import Tub, fireEventually, flushEventualQueue
from allmydata.crypto.aes import AES from allmydata.crypto import aes
from allmydata.storage.server import si_b2a from allmydata.storage.server import si_b2a
from allmydata.storage_client import StorageFarmBroker from allmydata.storage_client import StorageFarmBroker
from allmydata.immutable import offloaded, upload from allmydata.immutable import offloaded, upload
@ -189,12 +189,12 @@ class AssistedUpload(unittest.TestCase):
key = hashutil.convergence_hash(k, n, segsize, DATA, "test convergence string") key = hashutil.convergence_hash(k, n, segsize, DATA, "test convergence string")
assert len(key) == 16 assert len(key) == 16
encryptor = AES(key) encryptor = aes.create_encryptor(key)
SI = hashutil.storage_index_hash(key) SI = hashutil.storage_index_hash(key)
SI_s = si_b2a(SI) SI_s = si_b2a(SI)
encfile = os.path.join(self.basedir, "CHK_encoding", SI_s) encfile = os.path.join(self.basedir, "CHK_encoding", SI_s)
f = open(encfile, "wb") f = open(encfile, "wb")
f.write(encryptor.process(DATA)) f.write(aes.decrypt_data(encryptor, DATA))
f.close() f.close()
u = upload.Uploader(self.helper_furl) u = upload.Uploader(self.helper_furl)

View File

@ -16,7 +16,7 @@ if sys.platform == "win32":
from twisted.python import log from twisted.python import log
from allmydata.crypto.aes import AES from allmydata.crypto import aes
from allmydata.util.assertutil import _assert from allmydata.util.assertutil import _assert
@ -109,9 +109,10 @@ class EncryptedTemporaryFile(object):
offset_big = offset // 16 offset_big = offset // 16
offset_small = offset % 16 offset_small = offset % 16
iv = binascii.unhexlify("%032x" % offset_big) iv = binascii.unhexlify("%032x" % offset_big)
cipher = AES(self.key, iv=iv) cipher = aes.create_encryptor(self.key, iv)
cipher.process("\x00"*offset_small) # this is just to advance the counter
return cipher.process(data) aes.encrypt_data(cipher, "\x00" * offset_small)
return aes.encrypt_data(ciper, data)
def close(self): def close(self):
self.file.close() self.file.close()