move AES to a helper-function style

This commit is contained in:
meejah 2019-06-17 15:54:46 -06:00
parent 47ccdb0177
commit 310fb60247
10 changed files with 138 additions and 76 deletions

View File

@ -1,32 +1,88 @@
import six
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.ciphers import (
Cipher,
algorithms,
modes,
CipherContext,
)
class AES:
__DEFAULT_IV = '\x00' * 16
DEFAULT_IV = '\x00' * 16
def __init__(self, key, iv=None):
# validate the key
if not isinstance(key, six.binary_type):
raise TypeError('Key was not bytes')
if len(key) not in (16, 32):
raise ValueError('Key was not 16 or 32 bytes long')
# validate the IV
if iv is None:
iv = self.__DEFAULT_IV
if not isinstance(iv, six.binary_type):
raise TypeError('IV was not bytes')
if len(iv) != 16:
raise ValueError('IV was not 16 bytes long')
def create_encryptor(key, iv=None):
"""
Create and return a new object which can do AES encryptions with
the given key and initialization vector (IV). The default IV is 16
zero-bytes.
self._cipher = Cipher(algorithms.AES(key), modes.CTR(iv), backend=default_backend())
self._encryptor = self._cipher.encryptor()
The returned object is suitable for use with `encrypt_data`
"""
key = _validate_key(key)
iv = _validate_iv(iv)
cipher = Cipher(
algorithms.AES(key),
modes.CTR(iv),
backend=default_backend()
)
return cipher.encryptor()
def process(self, plaintext):
if not isinstance(plaintext, six.binary_type):
raise TypeError('Plaintext was not bytes')
return self._encryptor.update(plaintext)
def encrypt_data(encryptor, plaintext):
"""
AES-encrypt `plaintext` with the given `encryptor`.
:param encryptor: an instance previously returned from `create_encryptor`
:param bytes plaintext: the data to encrypt
:returns: ciphertext
"""
_validate_encryptor(encryptor)
if not isinstance(plaintext, six.binary_type):
raise ValueError('Plaintext was not bytes')
return encryptor.update(plaintext)
create_decryptor = create_encryptor
decrypt_data = encrypt_data
def _validate_encryptor(encryptor):
"""
raise ValueError for `encryptor` is not a valid object
"""
if not isinstance(encryptor, CipherContext):
raise ValueError(
"'encryptor' must be a CipherContext"
)
def _validate_key(key):
"""
confirm `key` is suitable for AES encryption, or raise ValueError
"""
if not isinstance(key, six.binary_type):
raise TypeError('Key was not bytes')
if len(key) not in (16, 32):
raise ValueError('Key was not 16 or 32 bytes long')
return key
def _validate_iv(iv):
"""
confirm `iv` is a suitable initialization vector
"""
if iv is None:
return DEFAULT_IV
if not isinstance(iv, six.binary_type):
raise TypeError('IV was not bytes')
if len(iv) != 16:
raise ValueError('IV was not 16 bytes long')
return iv

View File

@ -6,7 +6,7 @@ from twisted.internet import defer
from foolscap.api import fireEventually
import json
from allmydata.crypto.aes import AES
from allmydata.crypto import aes
from allmydata.deep_stats import DeepStats
from allmydata.mutable.common import NotWriteableError
from allmydata.mutable.filenode import MutableFileNode
@ -214,8 +214,8 @@ def _encrypt_rw_uri(writekey, rw_uri):
salt = hashutil.mutable_rwcap_salt_hash(rw_uri)
key = hashutil.mutable_rwcap_key_hash(salt, writekey)
cryptor = AES(key)
crypttext = cryptor.process(rw_uri)
encryptor = aes.create_encryptor(key)
crypttext = aes.encrypt_data(encryptor, rw_uri)
mac = hashutil.hmac(key, salt + crypttext)
assert len(mac) == 32
return salt + crypttext + mac
@ -331,8 +331,11 @@ class DirectoryNode(object):
salt = encwrcap[:16]
crypttext = encwrcap[16:-32]
key = hashutil.mutable_rwcap_key_hash(salt, self._node.get_writekey())
cryptor = AES(key)
plaintext = cryptor.process(crypttext)
# XXX uhm, so maybe this is confusing even if it's what the
# original code was doing; worth making a aes.decrypt_data
# just to be more clear?
encryptor = aes.create_encryptor(key)
plaintext = aes.encrypt_data(encryptor, crypttext)
return plaintext
def _create_and_validate_node(self, rw_uri, ro_uri, name):

View File

@ -7,7 +7,7 @@ from twisted.internet import defer
from allmydata import uri
from twisted.internet.interfaces import IConsumer
from allmydata.crypto.aes import AES
from allmydata.crypto import aes
from allmydata.interfaces import IImmutableFileNode, IUploadResults
from allmydata.util import consumer
from allmydata.check_results import CheckResults, CheckAndRepairResults
@ -201,8 +201,9 @@ class DecryptingConsumer(object):
offset_big = offset // 16
offset_small = offset % 16
iv = binascii.unhexlify("%032x" % offset_big)
self._decryptor = AES(readkey, iv=iv)
self._decryptor.process("\x00"*offset_small)
self._decryptor = aes.create_decryptor(readkey, iv)
# this is just to advance the counter
aes.decrypt_data(self._decryptor, "\x00" * offset_small)
def set_download_status_read_event(self, read_ev):
self._read_ev = read_ev
@ -219,7 +220,7 @@ class DecryptingConsumer(object):
self._consumer.unregisterProducer()
def write(self, ciphertext):
started = now()
plaintext = self._decryptor.process(ciphertext)
plaintext = aes.decrypt_data(self._decryptor, ciphertext)
if self._read_ev:
elapsed = now() - started
self._read_ev.update(0, elapsed, 0)

View File

@ -5,7 +5,7 @@ from twisted.internet import defer
from twisted.application import service
from foolscap.api import Referenceable, Copyable, RemoteCopy, fireEventually
from allmydata.crypto.aes import AES
from allmydata.crypto import aes
from allmydata.util.hashutil import file_renewal_secret_hash, \
file_cancel_secret_hash, bucket_renewal_secret_hash, \
bucket_cancel_secret_hash, plaintext_hasher, \
@ -946,8 +946,7 @@ class EncryptAnUploadable(object):
d = self.original.get_encryption_key()
def _got(key):
e = AES(key)
self._encryptor = e
self._encryptor = aes.create_encryptor(key)
storage_index = storage_index_hash(key)
assert isinstance(storage_index, str)
@ -957,7 +956,7 @@ class EncryptAnUploadable(object):
self._storage_index = storage_index
if self._status:
self._status.set_storage_index(storage_index)
return e
return self._encryptor
d.addCallback(_got)
return d
@ -1067,8 +1066,8 @@ class EncryptAnUploadable(object):
# because the AES-CTR implementation doesn't offer a
# way to change the counter value. Once it acquires
# this ability, change this to simply update the counter
# before each call to (hash_only==False) _encryptor.process()
ciphertext = self._encryptor.process(chunk)
# before each call to (hash_only==False) encrypt_data
ciphertext = aes.encrypt_data(self._encryptor, chunk)
if hash_only:
self.log(" skipping encryption", level=log.NOISY)
else:

View File

@ -4,7 +4,7 @@ from zope.interface import implementer
from twisted.internet import defer, reactor
from foolscap.api import eventually
from allmydata.crypto.aes import AES
from allmydata.crypto import aes
from allmydata.crypto import rsa
from allmydata.interfaces import IMutableFileNode, ICheckable, ICheckResults, \
NotEnoughSharesError, MDMF_VERSION, SDMF_VERSION, IMutableUploadable, \
@ -160,13 +160,13 @@ class MutableFileNode(object):
return contents(self)
def _encrypt_privkey(self, writekey, privkey):
enc = AES(writekey)
crypttext = enc.process(privkey)
encryptor = aes.create_encryptor(writekey)
crypttext = aes.encrypt_data(encryptor, privkey)
return crypttext
def _decrypt_privkey(self, enc_privkey):
enc = AES(self._writekey)
privkey = enc.process(enc_privkey)
encryptor = aes.create_encryptor(self._writekey)
privkey = aes.encrypt_data(encryptor, enc_privkey)
return privkey
def _populate_pubkey(self, pubkey):

View File

@ -5,7 +5,7 @@ from zope.interface import implementer
from twisted.internet import defer
from twisted.python import failure
from allmydata.crypto.aes import AES
from allmydata.crypto import aes
from allmydata.crypto import rsa
from allmydata.interfaces import IPublishStatus, SDMF_VERSION, MDMF_VERSION, \
IMutableUploadable
@ -712,8 +712,8 @@ class Publish(object):
key = hashutil.ssk_readkey_data_hash(salt, self.readkey)
self._status.set_status("Encrypting")
enc = AES(key)
crypttext = enc.process(data)
encryptor = aes.create_encryptor(key)
crypttext = aes.encrypt_data(encryptor, data)
assert len(crypttext) == len(data)
now = time.time()

View File

@ -8,7 +8,7 @@ from twisted.internet.interfaces import IPushProducer, IConsumer
from foolscap.api import eventually, fireEventually, DeadReferenceError, \
RemoteException
from allmydata.crypto.aes import AES
from allmydata.crypto import aes
from allmydata.crypto import rsa
from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
DownloadStopped, MDMF_VERSION, SDMF_VERSION
@ -899,8 +899,9 @@ class Retrieve(object):
self.log("decrypting segment %d" % self._current_segment)
started = time.time()
key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
decryptor = AES(key)
plaintext = decryptor.process(segment)
# XXX make aes.* functions for decryption too .. even though its the same
decryptor = aes.create_encryptor(key)
plaintext = aes.encrypt_data(decryptor, segment)
self._status.accumulate_decrypt_time(time.time() - started)
return plaintext

View File

@ -5,7 +5,7 @@ from base64 import b64decode
from binascii import a2b_hex, b2a_hex
from os import path
from allmydata.crypto.aes import AES
from allmydata.crypto import aes
from allmydata.crypto import ed25519, rsa
RESOURCE_DIR = path.join(path.dirname(__file__), 'data')
@ -52,33 +52,34 @@ class TestRegression(unittest.TestCase):
This was the old startup test run at import time in `pycryptopp.cipher.aes`.
"""
enc0 = b"dc95c078a2408989ad48a21492842087530f8afbc74536b9a963b4f1c4cb738b"
cryptor = AES(key=b"\x00" * 32)
ct = cryptor.process(b"\x00" * 32)
cryptor = aes.create_encryptor(key=b"\x00" * 32)
ct = aes.decrypt_data(cryptor, b"\x00" * 32)
self.failUnlessEqual(enc0, b2a_hex(ct))
cryptor = AES(key=b"\x00" * 32)
ct1 = cryptor.process(b"\x00" * 15)
ct2 = cryptor.process(b"\x00" * 17)
cryptor = aes.create_encryptor(key=b"\x00" * 32)
ct1 = aes.decrypt_data(cryptor, b"\x00" * 15)
ct2 = aes.decrypt_data(cryptor, b"\x00" * 17)
self.failUnlessEqual(enc0, b2a_hex(ct1+ct2))
enc0 = b"66e94bd4ef8a2c3b884cfa59ca342b2e"
cryptor = AES(key=b"\x00" * 16)
ct = cryptor.process(b"\x00" * 16)
cryptor = aes.create_encryptor(key=b"\x00" * 16)
ct = aes.decrypt_data(cryptor, b"\x00" * 16)
self.failUnlessEqual(enc0, b2a_hex(ct))
cryptor = AES(key=b"\x00" * 16)
ct1 = cryptor.process(b"\x00" * 8)
ct2 = cryptor.process(b"\x00" * 8)
cryptor = aes.create_encryptor(key=b"\x00" * 16)
ct1 = aes.decrypt_data(cryptor, b"\x00" * 8)
ct2 = aes.decrypt_data(cryptor, b"\x00" * 8)
self.failUnlessEqual(enc0, b2a_hex(ct1+ct2))
def _test_from_Niels_AES(keysize, result):
def fake_ecb_using_ctr(k, p):
return AES(key=k, iv=p).process(b'\x00' * 16)
encryptor = aes.create_encryptor(key=k, iv=p)
return aes.encrypt_data(encryptor, b'\x00' * 16)
E = fake_ecb_using_ctr
b = 16
k = keysize
S = '\x00' * (k+b)
S = '\x00' * (k + b)
for i in range(1000):
K = S[-k:]
@ -104,8 +105,8 @@ class TestRegression(unittest.TestCase):
plaintext = b'test'
expected_ciphertext = b'\x7fEK\\'
aes = AES(self.AES_KEY)
ciphertext = aes.process(plaintext)
k = aes.create_encryptor(self.AES_KEY)
ciphertext = aes.decrypt_data(k, plaintext)
self.failUnlessEqual(ciphertext, expected_ciphertext)
@ -126,8 +127,8 @@ class TestRegression(unittest.TestCase):
b'\x1f\xa1|\xd2$E\xb5\xe7\x9d\xae\xd1\x1f)\xe4\xc7\x83\xb8\xd5|dHhU\xc8\x9a\xb1\x10\xed'
b'\xd1\xe7|\xd1')
aes = AES(self.AES_KEY)
ciphertext = aes.process(plaintext)
k = aes.create_encryptor(self.AES_KEY)
ciphertext = aes.decrypt_data(k, plaintext)
self.failUnlessEqual(ciphertext, expected_ciphertext)
@ -145,8 +146,8 @@ class TestRegression(unittest.TestCase):
plaintext = b'test'
expected_ciphertext = b'\x82\x0e\rt'
aes = AES(self.AES_KEY, iv=self.IV)
ciphertext = aes.process(plaintext)
k = aes.create_encryptor(self.AES_KEY, iv=self.IV)
ciphertext = aes.decrypt_data(k, plaintext)
self.failUnlessEqual(ciphertext, expected_ciphertext)
@ -167,8 +168,8 @@ class TestRegression(unittest.TestCase):
b'\x97a\xdc\x100?\xf5L\x9f\xd9\xeeO\x98\xda\xf5g\x93\xa7q\xe1\xb1~\xf8\x1b\xe8[\\s'
b'\x144$\x86\xeaC^f')
aes = AES(self.AES_KEY, iv=self.IV)
ciphertext = aes.process(plaintext)
k = aes.create_encryptor(self.AES_KEY, iv=self.IV)
ciphertext = aes.decrypt_data(k, plaintext)
self.failUnlessEqual(ciphertext, expected_ciphertext)

View File

@ -5,7 +5,7 @@ from twisted.application import service
from foolscap.api import Tub, fireEventually, flushEventualQueue
from allmydata.crypto.aes import AES
from allmydata.crypto import aes
from allmydata.storage.server import si_b2a
from allmydata.storage_client import StorageFarmBroker
from allmydata.immutable import offloaded, upload
@ -189,12 +189,12 @@ class AssistedUpload(unittest.TestCase):
key = hashutil.convergence_hash(k, n, segsize, DATA, "test convergence string")
assert len(key) == 16
encryptor = AES(key)
encryptor = aes.create_encryptor(key)
SI = hashutil.storage_index_hash(key)
SI_s = si_b2a(SI)
encfile = os.path.join(self.basedir, "CHK_encoding", SI_s)
f = open(encfile, "wb")
f.write(encryptor.process(DATA))
f.write(aes.decrypt_data(encryptor, DATA))
f.close()
u = upload.Uploader(self.helper_furl)

View File

@ -16,7 +16,7 @@ if sys.platform == "win32":
from twisted.python import log
from allmydata.crypto.aes import AES
from allmydata.crypto import aes
from allmydata.util.assertutil import _assert
@ -109,9 +109,10 @@ class EncryptedTemporaryFile(object):
offset_big = offset // 16
offset_small = offset % 16
iv = binascii.unhexlify("%032x" % offset_big)
cipher = AES(self.key, iv=iv)
cipher.process("\x00"*offset_small)
return cipher.process(data)
cipher = aes.create_encryptor(self.key, iv)
# this is just to advance the counter
aes.encrypt_data(cipher, "\x00" * offset_small)
return aes.encrypt_data(ciper, data)
def close(self):
self.file.close()