mirror of
https://github.com/tahoe-lafs/tahoe-lafs.git
synced 2025-02-21 02:01:31 +00:00
move AES to a helper-function style
This commit is contained in:
parent
47ccdb0177
commit
310fb60247
@ -1,32 +1,88 @@
|
||||
import six
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.ciphers import (
|
||||
Cipher,
|
||||
algorithms,
|
||||
modes,
|
||||
CipherContext,
|
||||
)
|
||||
|
||||
class AES:
|
||||
|
||||
__DEFAULT_IV = '\x00' * 16
|
||||
DEFAULT_IV = '\x00' * 16
|
||||
|
||||
def __init__(self, key, iv=None):
|
||||
# validate the key
|
||||
if not isinstance(key, six.binary_type):
|
||||
raise TypeError('Key was not bytes')
|
||||
if len(key) not in (16, 32):
|
||||
raise ValueError('Key was not 16 or 32 bytes long')
|
||||
|
||||
# validate the IV
|
||||
if iv is None:
|
||||
iv = self.__DEFAULT_IV
|
||||
if not isinstance(iv, six.binary_type):
|
||||
raise TypeError('IV was not bytes')
|
||||
if len(iv) != 16:
|
||||
raise ValueError('IV was not 16 bytes long')
|
||||
def create_encryptor(key, iv=None):
|
||||
"""
|
||||
Create and return a new object which can do AES encryptions with
|
||||
the given key and initialization vector (IV). The default IV is 16
|
||||
zero-bytes.
|
||||
|
||||
self._cipher = Cipher(algorithms.AES(key), modes.CTR(iv), backend=default_backend())
|
||||
self._encryptor = self._cipher.encryptor()
|
||||
The returned object is suitable for use with `encrypt_data`
|
||||
"""
|
||||
key = _validate_key(key)
|
||||
iv = _validate_iv(iv)
|
||||
cipher = Cipher(
|
||||
algorithms.AES(key),
|
||||
modes.CTR(iv),
|
||||
backend=default_backend()
|
||||
)
|
||||
return cipher.encryptor()
|
||||
|
||||
def process(self, plaintext):
|
||||
if not isinstance(plaintext, six.binary_type):
|
||||
raise TypeError('Plaintext was not bytes')
|
||||
|
||||
return self._encryptor.update(plaintext)
|
||||
def encrypt_data(encryptor, plaintext):
|
||||
"""
|
||||
AES-encrypt `plaintext` with the given `encryptor`.
|
||||
|
||||
:param encryptor: an instance previously returned from `create_encryptor`
|
||||
|
||||
:param bytes plaintext: the data to encrypt
|
||||
|
||||
:returns: ciphertext
|
||||
"""
|
||||
|
||||
_validate_encryptor(encryptor)
|
||||
if not isinstance(plaintext, six.binary_type):
|
||||
raise ValueError('Plaintext was not bytes')
|
||||
|
||||
return encryptor.update(plaintext)
|
||||
|
||||
|
||||
create_decryptor = create_encryptor
|
||||
|
||||
|
||||
decrypt_data = encrypt_data
|
||||
|
||||
|
||||
def _validate_encryptor(encryptor):
|
||||
"""
|
||||
raise ValueError for `encryptor` is not a valid object
|
||||
"""
|
||||
if not isinstance(encryptor, CipherContext):
|
||||
raise ValueError(
|
||||
"'encryptor' must be a CipherContext"
|
||||
)
|
||||
|
||||
|
||||
def _validate_key(key):
|
||||
"""
|
||||
confirm `key` is suitable for AES encryption, or raise ValueError
|
||||
"""
|
||||
if not isinstance(key, six.binary_type):
|
||||
raise TypeError('Key was not bytes')
|
||||
if len(key) not in (16, 32):
|
||||
raise ValueError('Key was not 16 or 32 bytes long')
|
||||
return key
|
||||
|
||||
|
||||
def _validate_iv(iv):
|
||||
"""
|
||||
confirm `iv` is a suitable initialization vector
|
||||
"""
|
||||
if iv is None:
|
||||
return DEFAULT_IV
|
||||
if not isinstance(iv, six.binary_type):
|
||||
raise TypeError('IV was not bytes')
|
||||
if len(iv) != 16:
|
||||
raise ValueError('IV was not 16 bytes long')
|
||||
return iv
|
||||
|
@ -6,7 +6,7 @@ from twisted.internet import defer
|
||||
from foolscap.api import fireEventually
|
||||
import json
|
||||
|
||||
from allmydata.crypto.aes import AES
|
||||
from allmydata.crypto import aes
|
||||
from allmydata.deep_stats import DeepStats
|
||||
from allmydata.mutable.common import NotWriteableError
|
||||
from allmydata.mutable.filenode import MutableFileNode
|
||||
@ -214,8 +214,8 @@ def _encrypt_rw_uri(writekey, rw_uri):
|
||||
|
||||
salt = hashutil.mutable_rwcap_salt_hash(rw_uri)
|
||||
key = hashutil.mutable_rwcap_key_hash(salt, writekey)
|
||||
cryptor = AES(key)
|
||||
crypttext = cryptor.process(rw_uri)
|
||||
encryptor = aes.create_encryptor(key)
|
||||
crypttext = aes.encrypt_data(encryptor, rw_uri)
|
||||
mac = hashutil.hmac(key, salt + crypttext)
|
||||
assert len(mac) == 32
|
||||
return salt + crypttext + mac
|
||||
@ -331,8 +331,11 @@ class DirectoryNode(object):
|
||||
salt = encwrcap[:16]
|
||||
crypttext = encwrcap[16:-32]
|
||||
key = hashutil.mutable_rwcap_key_hash(salt, self._node.get_writekey())
|
||||
cryptor = AES(key)
|
||||
plaintext = cryptor.process(crypttext)
|
||||
# XXX uhm, so maybe this is confusing even if it's what the
|
||||
# original code was doing; worth making a aes.decrypt_data
|
||||
# just to be more clear?
|
||||
encryptor = aes.create_encryptor(key)
|
||||
plaintext = aes.encrypt_data(encryptor, crypttext)
|
||||
return plaintext
|
||||
|
||||
def _create_and_validate_node(self, rw_uri, ro_uri, name):
|
||||
|
@ -7,7 +7,7 @@ from twisted.internet import defer
|
||||
|
||||
from allmydata import uri
|
||||
from twisted.internet.interfaces import IConsumer
|
||||
from allmydata.crypto.aes import AES
|
||||
from allmydata.crypto import aes
|
||||
from allmydata.interfaces import IImmutableFileNode, IUploadResults
|
||||
from allmydata.util import consumer
|
||||
from allmydata.check_results import CheckResults, CheckAndRepairResults
|
||||
@ -201,8 +201,9 @@ class DecryptingConsumer(object):
|
||||
offset_big = offset // 16
|
||||
offset_small = offset % 16
|
||||
iv = binascii.unhexlify("%032x" % offset_big)
|
||||
self._decryptor = AES(readkey, iv=iv)
|
||||
self._decryptor.process("\x00"*offset_small)
|
||||
self._decryptor = aes.create_decryptor(readkey, iv)
|
||||
# this is just to advance the counter
|
||||
aes.decrypt_data(self._decryptor, "\x00" * offset_small)
|
||||
|
||||
def set_download_status_read_event(self, read_ev):
|
||||
self._read_ev = read_ev
|
||||
@ -219,7 +220,7 @@ class DecryptingConsumer(object):
|
||||
self._consumer.unregisterProducer()
|
||||
def write(self, ciphertext):
|
||||
started = now()
|
||||
plaintext = self._decryptor.process(ciphertext)
|
||||
plaintext = aes.decrypt_data(self._decryptor, ciphertext)
|
||||
if self._read_ev:
|
||||
elapsed = now() - started
|
||||
self._read_ev.update(0, elapsed, 0)
|
||||
|
@ -5,7 +5,7 @@ from twisted.internet import defer
|
||||
from twisted.application import service
|
||||
from foolscap.api import Referenceable, Copyable, RemoteCopy, fireEventually
|
||||
|
||||
from allmydata.crypto.aes import AES
|
||||
from allmydata.crypto import aes
|
||||
from allmydata.util.hashutil import file_renewal_secret_hash, \
|
||||
file_cancel_secret_hash, bucket_renewal_secret_hash, \
|
||||
bucket_cancel_secret_hash, plaintext_hasher, \
|
||||
@ -946,8 +946,7 @@ class EncryptAnUploadable(object):
|
||||
|
||||
d = self.original.get_encryption_key()
|
||||
def _got(key):
|
||||
e = AES(key)
|
||||
self._encryptor = e
|
||||
self._encryptor = aes.create_encryptor(key)
|
||||
|
||||
storage_index = storage_index_hash(key)
|
||||
assert isinstance(storage_index, str)
|
||||
@ -957,7 +956,7 @@ class EncryptAnUploadable(object):
|
||||
self._storage_index = storage_index
|
||||
if self._status:
|
||||
self._status.set_storage_index(storage_index)
|
||||
return e
|
||||
return self._encryptor
|
||||
d.addCallback(_got)
|
||||
return d
|
||||
|
||||
@ -1067,8 +1066,8 @@ class EncryptAnUploadable(object):
|
||||
# because the AES-CTR implementation doesn't offer a
|
||||
# way to change the counter value. Once it acquires
|
||||
# this ability, change this to simply update the counter
|
||||
# before each call to (hash_only==False) _encryptor.process()
|
||||
ciphertext = self._encryptor.process(chunk)
|
||||
# before each call to (hash_only==False) encrypt_data
|
||||
ciphertext = aes.encrypt_data(self._encryptor, chunk)
|
||||
if hash_only:
|
||||
self.log(" skipping encryption", level=log.NOISY)
|
||||
else:
|
||||
|
@ -4,7 +4,7 @@ from zope.interface import implementer
|
||||
from twisted.internet import defer, reactor
|
||||
from foolscap.api import eventually
|
||||
|
||||
from allmydata.crypto.aes import AES
|
||||
from allmydata.crypto import aes
|
||||
from allmydata.crypto import rsa
|
||||
from allmydata.interfaces import IMutableFileNode, ICheckable, ICheckResults, \
|
||||
NotEnoughSharesError, MDMF_VERSION, SDMF_VERSION, IMutableUploadable, \
|
||||
@ -160,13 +160,13 @@ class MutableFileNode(object):
|
||||
return contents(self)
|
||||
|
||||
def _encrypt_privkey(self, writekey, privkey):
|
||||
enc = AES(writekey)
|
||||
crypttext = enc.process(privkey)
|
||||
encryptor = aes.create_encryptor(writekey)
|
||||
crypttext = aes.encrypt_data(encryptor, privkey)
|
||||
return crypttext
|
||||
|
||||
def _decrypt_privkey(self, enc_privkey):
|
||||
enc = AES(self._writekey)
|
||||
privkey = enc.process(enc_privkey)
|
||||
encryptor = aes.create_encryptor(self._writekey)
|
||||
privkey = aes.encrypt_data(encryptor, enc_privkey)
|
||||
return privkey
|
||||
|
||||
def _populate_pubkey(self, pubkey):
|
||||
|
@ -5,7 +5,7 @@ from zope.interface import implementer
|
||||
from twisted.internet import defer
|
||||
from twisted.python import failure
|
||||
|
||||
from allmydata.crypto.aes import AES
|
||||
from allmydata.crypto import aes
|
||||
from allmydata.crypto import rsa
|
||||
from allmydata.interfaces import IPublishStatus, SDMF_VERSION, MDMF_VERSION, \
|
||||
IMutableUploadable
|
||||
@ -712,8 +712,8 @@ class Publish(object):
|
||||
|
||||
key = hashutil.ssk_readkey_data_hash(salt, self.readkey)
|
||||
self._status.set_status("Encrypting")
|
||||
enc = AES(key)
|
||||
crypttext = enc.process(data)
|
||||
encryptor = aes.create_encryptor(key)
|
||||
crypttext = aes.encrypt_data(encryptor, data)
|
||||
assert len(crypttext) == len(data)
|
||||
|
||||
now = time.time()
|
||||
|
@ -8,7 +8,7 @@ from twisted.internet.interfaces import IPushProducer, IConsumer
|
||||
from foolscap.api import eventually, fireEventually, DeadReferenceError, \
|
||||
RemoteException
|
||||
|
||||
from allmydata.crypto.aes import AES
|
||||
from allmydata.crypto import aes
|
||||
from allmydata.crypto import rsa
|
||||
from allmydata.interfaces import IRetrieveStatus, NotEnoughSharesError, \
|
||||
DownloadStopped, MDMF_VERSION, SDMF_VERSION
|
||||
@ -899,8 +899,9 @@ class Retrieve(object):
|
||||
self.log("decrypting segment %d" % self._current_segment)
|
||||
started = time.time()
|
||||
key = hashutil.ssk_readkey_data_hash(salt, self._node.get_readkey())
|
||||
decryptor = AES(key)
|
||||
plaintext = decryptor.process(segment)
|
||||
# XXX make aes.* functions for decryption too .. even though its the same
|
||||
decryptor = aes.create_encryptor(key)
|
||||
plaintext = aes.encrypt_data(decryptor, segment)
|
||||
self._status.accumulate_decrypt_time(time.time() - started)
|
||||
return plaintext
|
||||
|
||||
|
@ -5,7 +5,7 @@ from base64 import b64decode
|
||||
from binascii import a2b_hex, b2a_hex
|
||||
from os import path
|
||||
|
||||
from allmydata.crypto.aes import AES
|
||||
from allmydata.crypto import aes
|
||||
from allmydata.crypto import ed25519, rsa
|
||||
|
||||
RESOURCE_DIR = path.join(path.dirname(__file__), 'data')
|
||||
@ -52,33 +52,34 @@ class TestRegression(unittest.TestCase):
|
||||
This was the old startup test run at import time in `pycryptopp.cipher.aes`.
|
||||
"""
|
||||
enc0 = b"dc95c078a2408989ad48a21492842087530f8afbc74536b9a963b4f1c4cb738b"
|
||||
cryptor = AES(key=b"\x00" * 32)
|
||||
ct = cryptor.process(b"\x00" * 32)
|
||||
cryptor = aes.create_encryptor(key=b"\x00" * 32)
|
||||
ct = aes.decrypt_data(cryptor, b"\x00" * 32)
|
||||
self.failUnlessEqual(enc0, b2a_hex(ct))
|
||||
|
||||
cryptor = AES(key=b"\x00" * 32)
|
||||
ct1 = cryptor.process(b"\x00" * 15)
|
||||
ct2 = cryptor.process(b"\x00" * 17)
|
||||
cryptor = aes.create_encryptor(key=b"\x00" * 32)
|
||||
ct1 = aes.decrypt_data(cryptor, b"\x00" * 15)
|
||||
ct2 = aes.decrypt_data(cryptor, b"\x00" * 17)
|
||||
self.failUnlessEqual(enc0, b2a_hex(ct1+ct2))
|
||||
|
||||
enc0 = b"66e94bd4ef8a2c3b884cfa59ca342b2e"
|
||||
cryptor = AES(key=b"\x00" * 16)
|
||||
ct = cryptor.process(b"\x00" * 16)
|
||||
cryptor = aes.create_encryptor(key=b"\x00" * 16)
|
||||
ct = aes.decrypt_data(cryptor, b"\x00" * 16)
|
||||
self.failUnlessEqual(enc0, b2a_hex(ct))
|
||||
|
||||
cryptor = AES(key=b"\x00" * 16)
|
||||
ct1 = cryptor.process(b"\x00" * 8)
|
||||
ct2 = cryptor.process(b"\x00" * 8)
|
||||
cryptor = aes.create_encryptor(key=b"\x00" * 16)
|
||||
ct1 = aes.decrypt_data(cryptor, b"\x00" * 8)
|
||||
ct2 = aes.decrypt_data(cryptor, b"\x00" * 8)
|
||||
self.failUnlessEqual(enc0, b2a_hex(ct1+ct2))
|
||||
|
||||
def _test_from_Niels_AES(keysize, result):
|
||||
def fake_ecb_using_ctr(k, p):
|
||||
return AES(key=k, iv=p).process(b'\x00' * 16)
|
||||
encryptor = aes.create_encryptor(key=k, iv=p)
|
||||
return aes.encrypt_data(encryptor, b'\x00' * 16)
|
||||
|
||||
E = fake_ecb_using_ctr
|
||||
b = 16
|
||||
k = keysize
|
||||
S = '\x00' * (k+b)
|
||||
S = '\x00' * (k + b)
|
||||
|
||||
for i in range(1000):
|
||||
K = S[-k:]
|
||||
@ -104,8 +105,8 @@ class TestRegression(unittest.TestCase):
|
||||
plaintext = b'test'
|
||||
expected_ciphertext = b'\x7fEK\\'
|
||||
|
||||
aes = AES(self.AES_KEY)
|
||||
ciphertext = aes.process(plaintext)
|
||||
k = aes.create_encryptor(self.AES_KEY)
|
||||
ciphertext = aes.decrypt_data(k, plaintext)
|
||||
|
||||
self.failUnlessEqual(ciphertext, expected_ciphertext)
|
||||
|
||||
@ -126,8 +127,8 @@ class TestRegression(unittest.TestCase):
|
||||
b'\x1f\xa1|\xd2$E\xb5\xe7\x9d\xae\xd1\x1f)\xe4\xc7\x83\xb8\xd5|dHhU\xc8\x9a\xb1\x10\xed'
|
||||
b'\xd1\xe7|\xd1')
|
||||
|
||||
aes = AES(self.AES_KEY)
|
||||
ciphertext = aes.process(plaintext)
|
||||
k = aes.create_encryptor(self.AES_KEY)
|
||||
ciphertext = aes.decrypt_data(k, plaintext)
|
||||
|
||||
self.failUnlessEqual(ciphertext, expected_ciphertext)
|
||||
|
||||
@ -145,8 +146,8 @@ class TestRegression(unittest.TestCase):
|
||||
plaintext = b'test'
|
||||
expected_ciphertext = b'\x82\x0e\rt'
|
||||
|
||||
aes = AES(self.AES_KEY, iv=self.IV)
|
||||
ciphertext = aes.process(plaintext)
|
||||
k = aes.create_encryptor(self.AES_KEY, iv=self.IV)
|
||||
ciphertext = aes.decrypt_data(k, plaintext)
|
||||
|
||||
self.failUnlessEqual(ciphertext, expected_ciphertext)
|
||||
|
||||
@ -167,8 +168,8 @@ class TestRegression(unittest.TestCase):
|
||||
b'\x97a\xdc\x100?\xf5L\x9f\xd9\xeeO\x98\xda\xf5g\x93\xa7q\xe1\xb1~\xf8\x1b\xe8[\\s'
|
||||
b'\x144$\x86\xeaC^f')
|
||||
|
||||
aes = AES(self.AES_KEY, iv=self.IV)
|
||||
ciphertext = aes.process(plaintext)
|
||||
k = aes.create_encryptor(self.AES_KEY, iv=self.IV)
|
||||
ciphertext = aes.decrypt_data(k, plaintext)
|
||||
|
||||
self.failUnlessEqual(ciphertext, expected_ciphertext)
|
||||
|
||||
|
@ -5,7 +5,7 @@ from twisted.application import service
|
||||
|
||||
from foolscap.api import Tub, fireEventually, flushEventualQueue
|
||||
|
||||
from allmydata.crypto.aes import AES
|
||||
from allmydata.crypto import aes
|
||||
from allmydata.storage.server import si_b2a
|
||||
from allmydata.storage_client import StorageFarmBroker
|
||||
from allmydata.immutable import offloaded, upload
|
||||
@ -189,12 +189,12 @@ class AssistedUpload(unittest.TestCase):
|
||||
|
||||
key = hashutil.convergence_hash(k, n, segsize, DATA, "test convergence string")
|
||||
assert len(key) == 16
|
||||
encryptor = AES(key)
|
||||
encryptor = aes.create_encryptor(key)
|
||||
SI = hashutil.storage_index_hash(key)
|
||||
SI_s = si_b2a(SI)
|
||||
encfile = os.path.join(self.basedir, "CHK_encoding", SI_s)
|
||||
f = open(encfile, "wb")
|
||||
f.write(encryptor.process(DATA))
|
||||
f.write(aes.decrypt_data(encryptor, DATA))
|
||||
f.close()
|
||||
|
||||
u = upload.Uploader(self.helper_furl)
|
||||
|
@ -16,7 +16,7 @@ if sys.platform == "win32":
|
||||
|
||||
from twisted.python import log
|
||||
|
||||
from allmydata.crypto.aes import AES
|
||||
from allmydata.crypto import aes
|
||||
from allmydata.util.assertutil import _assert
|
||||
|
||||
|
||||
@ -109,9 +109,10 @@ class EncryptedTemporaryFile(object):
|
||||
offset_big = offset // 16
|
||||
offset_small = offset % 16
|
||||
iv = binascii.unhexlify("%032x" % offset_big)
|
||||
cipher = AES(self.key, iv=iv)
|
||||
cipher.process("\x00"*offset_small)
|
||||
return cipher.process(data)
|
||||
cipher = aes.create_encryptor(self.key, iv)
|
||||
# this is just to advance the counter
|
||||
aes.encrypt_data(cipher, "\x00" * offset_small)
|
||||
return aes.encrypt_data(ciper, data)
|
||||
|
||||
def close(self):
|
||||
self.file.close()
|
||||
|
Loading…
x
Reference in New Issue
Block a user