diff --git a/node/AES.hpp b/node/AES.hpp
index 31cd1e21b..b84a9a167 100644
--- a/node/AES.hpp
+++ b/node/AES.hpp
@@ -114,51 +114,6 @@ public:
 		_decryptSW(in,out);
 	}
 
-	inline void ecbScramble(const void *in,unsigned int inlen,void *out)
-	{
-		if (inlen < 16)
-			return;
-
-#ifdef ZT_AES_AESNI
-		if (likely(HW_ACCEL)) {
-			const uint8_t *i = (const uint8_t *)in;
-			uint8_t *o = (uint8_t *)out;
-			while (inlen >= 128) {
-				_encrypt_8xecb_aesni(i,o);
-				i += 128;
-				o += 128;
-				inlen -= 128;
-			}
-			while (inlen >= 16) {
-				_encrypt_aesni(i,o);
-				i += 16;
-				o += 16;
-				inlen -= 16;
-			}
-			if (inlen) {
-				i -= (16 - inlen);
-				o -= (16 - inlen);
-				_encrypt_aesni(i,o);
-			}
-			return;
-		}
-#endif
-
-		const uint8_t *i = (const uint8_t *)in;
-		uint8_t *o = (uint8_t *)out;
-		while (inlen >= 16) {
-			_encryptSW(i,o);
-			i += 16;
-			o += 16;
-			inlen -= 16;
-		}
-		if (inlen) {
-			i -= (16 - inlen);
-			o -= (16 - inlen);
-			_encryptSW(i,o);
-		}
-	}
-
 	inline void gcmEncrypt(const uint8_t iv[12],const void *in,unsigned int inlen,const void *assoc,unsigned int assoclen,void *out,uint8_t *tag,unsigned int taglen)
 	{
 #ifdef ZT_AES_AESNI
@@ -183,6 +138,32 @@ public:
 		return false;
 	}
 
+	static inline void scramble(const uint8_t key[16],const void *in,unsigned int inlen,void *out)
+	{
+		if (inlen < 16)
+			return;
+
+#ifdef ZT_AES_AESNI
+		if (likely(HW_ACCEL)) {
+			_scramble_aesni(key,(const uint8_t *)in,(uint8_t *)out,inlen);
+			return;
+		}
+#endif
+	}
+
+	static inline void unscramble(const uint8_t key[16],const void *in,unsigned int inlen,void *out)
+	{
+		if (inlen < 16)
+			return;
+
+#ifdef ZT_AES_AESNI
+		if (likely(HW_ACCEL)) {
+			_unscramble_aesni(key,(const uint8_t *)in,(uint8_t *)out,inlen);
+			return;
+		}
+#endif
+	}
+
 private:
 	void _initSW(const uint8_t key[32]);
 	void _encryptSW(const uint8_t in[16],uint8_t out[16]) const;
@@ -376,6 +357,169 @@ private:
 		_k.ni.hhhh = _swap128_aesni(hhhh);
 	}
 
+	static inline __m128i _assist128_aesni(__m128i a,__m128i b)
+	{
+		__m128i c;
+		b = _mm_shuffle_epi32(b,0xff);
+		c = _mm_slli_si128(a, 0x04);
+		a = _mm_xor_si128(a, c);
+		c = _mm_slli_si128(c, 0x04);
+		a = _mm_xor_si128(a, c);
+		c = _mm_slli_si128(c, 0x04);
+		a = _mm_xor_si128(a, c);
+		a = _mm_xor_si128(a, b);
+		return a;
+	}
+	// Expand a 128-bit key into the full AES-128 round key schedule; the
+	// schedule array must hold all eleven round keys.
+	static inline void _expand128_aesni(__m128i schedule[11],const void *const key)
+	{
+		__m128i t;
+		schedule[0] = t = _mm_loadu_si128((const __m128i *)key);
+		schedule[1] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x01));
+		schedule[2] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x02));
+		schedule[3] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x04));
+		schedule[4] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x08));
+		schedule[5] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x10));
+		schedule[6] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x20));
+		schedule[7] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x40));
+		schedule[8] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x80));
+		schedule[9] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x1b));
+		schedule[10] = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x36));
+	}
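+	// _scramble_aesni is a reduced-round AES-128 ECB pass over whole 16-byte
+	// blocks: AddRoundKey with k0, four aesenc rounds (k1..k4), and a final
+	// aesenclast with k5, with the round keys expanded on the fly. A trailing
+	// partial block is instead XORed with this cipher's encryption of an
+	// all-zero block. This whitens data; it is not full-strength AES.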
+	static inline void _scramble_aesni(const uint8_t key[16],const uint8_t *in,uint8_t *out,unsigned int len)
+	{
+		__m128i t = _mm_loadu_si128((const __m128i *)key);
+		__m128i k0 = t;
+		__m128i k1 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x01));
+		__m128i k2 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x02));
+		__m128i k3 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x04));
+		__m128i k4 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x08));
+		__m128i k5 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x10));
+
+		// Two blocks at a time to keep the AES unit's pipeline busy.
+		while (len >= 32) {
+			len -= 32;
+
+			__m128i d0 = _mm_loadu_si128((const __m128i *)in);
+			in += 16;
+			__m128i d1 = _mm_loadu_si128((const __m128i *)in);
+			in += 16;
+
+			d0 = _mm_xor_si128(d0,k0);
+			d1 = _mm_xor_si128(d1,k0);
+			d0 = _mm_aesenc_si128(d0,k1);
+			d1 = _mm_aesenc_si128(d1,k1);
+			d0 = _mm_aesenc_si128(d0,k2);
+			d1 = _mm_aesenc_si128(d1,k2);
+			d0 = _mm_aesenc_si128(d0,k3);
+			d1 = _mm_aesenc_si128(d1,k3);
+			d0 = _mm_aesenc_si128(d0,k4);
+			d1 = _mm_aesenc_si128(d1,k4);
+
+			_mm_storeu_si128((__m128i *)out,_mm_aesenclast_si128(d0,k5));
+			out += 16;
+			_mm_storeu_si128((__m128i *)out,_mm_aesenclast_si128(d1,k5));
+			out += 16;
+		}
+
+		while (len >= 16) {
+			len -= 16;
+
+			__m128i d0 = _mm_loadu_si128((const __m128i *)in);
+			in += 16;
+
+			d0 = _mm_xor_si128(d0,k0);
+			d0 = _mm_aesenc_si128(d0,k1);
+			d0 = _mm_aesenc_si128(d0,k2);
+			d0 = _mm_aesenc_si128(d0,k3);
+			d0 = _mm_aesenc_si128(d0,k4);
+
+			_mm_storeu_si128((__m128i *)out,_mm_aesenclast_si128(d0,k5));
+			out += 16;
+		}
+
+		if (len) {
+			// Trailing partial block: XOR with the encryption of a zero block.
+			__m128i last = _mm_setzero_si128();
+			last = _mm_xor_si128(last,k0);
+			last = _mm_aesenc_si128(last,k1);
+			last = _mm_aesenc_si128(last,k2);
+			last = _mm_aesenc_si128(last,k3);
+			last = _mm_aesenc_si128(last,k4);
+			uint8_t lpad[16];
+			_mm_storeu_si128((__m128i *)lpad,_mm_aesenclast_si128(last,k5));
+			for(unsigned int i=0;i<len;++i)
+				out[i] = in[i] ^ lpad[i];
+		}
+	}
+	// Inverse of _scramble_aesni: the same on-the-fly schedule, with the
+	// decryption keys in reverse order and the inner ones run through
+	// aesimc for use with aesdec.
+	static inline void _unscramble_aesni(const uint8_t key[16],const uint8_t *in,uint8_t *out,unsigned int len)
+	{
+		__m128i t = _mm_loadu_si128((const __m128i *)key);
+		__m128i k0 = t;
+		__m128i k1 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x01));
+		__m128i k2 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x02));
+		__m128i k3 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x04));
+		__m128i k4 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x08));
+		__m128i k5 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x10));
+		__m128i dk0 = k5;
+		__m128i dk1 = _mm_aesimc_si128(k4);
+		__m128i dk2 = _mm_aesimc_si128(k3);
+		__m128i dk3 = _mm_aesimc_si128(k2);
+		__m128i dk4 = _mm_aesimc_si128(k1);
+		__m128i dk5 = k0;
+
+		while (len >= 32) {
+			len -= 32;
+
+			__m128i d0 = _mm_loadu_si128((const __m128i *)in);
+			in += 16;
+			__m128i d1 = _mm_loadu_si128((const __m128i *)in);
+			in += 16;
+
+			d0 = _mm_xor_si128(d0,dk0);
+			d1 = _mm_xor_si128(d1,dk0);
+			d0 = _mm_aesdec_si128(d0,dk1);
+			d1 = _mm_aesdec_si128(d1,dk1);
+			d0 = _mm_aesdec_si128(d0,dk2);
+			d1 = _mm_aesdec_si128(d1,dk2);
+			d0 = _mm_aesdec_si128(d0,dk3);
+			d1 = _mm_aesdec_si128(d1,dk3);
+			d0 = _mm_aesdec_si128(d0,dk4);
+			d1 = _mm_aesdec_si128(d1,dk4);
+
+			_mm_storeu_si128((__m128i *)out,_mm_aesdeclast_si128(d0,dk5));
+			out += 16;
+			_mm_storeu_si128((__m128i *)out,_mm_aesdeclast_si128(d1,dk5));
+			out += 16;
+		}
+
+		while (len >= 16) {
+			len -= 16;
+
+			__m128i d0 = _mm_loadu_si128((const __m128i *)in);
+			in += 16;
+
+			d0 = _mm_xor_si128(d0,dk0);
+			d0 = _mm_aesdec_si128(d0,dk1);
+			d0 = _mm_aesdec_si128(d0,dk2);
+			d0 = _mm_aesdec_si128(d0,dk3);
+			d0 = _mm_aesdec_si128(d0,dk4);
+
+			_mm_storeu_si128((__m128i *)out,_mm_aesdeclast_si128(d0,dk5));
+			out += 16;
+		}
+
+		if (len) {
+			// The trailing partial block was pad-XORed, so recompute the same
+			// pad with the encryption-order keys and XOR it off again.
+			__m128i last = _mm_setzero_si128();
+			last = _mm_xor_si128(last,dk5); // k0
+			last = _mm_aesenc_si128(last,k1);
+			last = _mm_aesenc_si128(last,k2);
+			last = _mm_aesenc_si128(last,k3);
+			last = _mm_aesenc_si128(last,k4);
+			uint8_t lpad[16];
+			_mm_storeu_si128((__m128i *)lpad,_mm_aesenclast_si128(last,dk0)); // k5
+			for(unsigned int i=0;i<len;++i)
+				out[i] = in[i] ^ lpad[i];
+		}
+	}
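
A minimal round-trip sketch of the new calls (not from the patch; it assumes a
build where ZT_AES_AESNI is defined and HW_ACCEL is true, since neither
function writes any output otherwise, and it assumes the AES class sits in
namespace ZeroTier as in the rest of the codebase):

	#include "node/AES.hpp"

	void roundTrip()
	{
		uint8_t key[16] = {0};      // illustrative key; real callers derive it elsewhere
		uint8_t in[137] = {0};      // any length >= 16; the 9-byte tail is pad-XORed
		uint8_t scrambled[137],recovered[137];
		ZeroTier::AES::scramble(key,in,sizeof(in),scrambled);
		ZeroTier::AES::unscramble(key,scrambled,sizeof(scrambled),recovered);
		// recovered now equals in, byte for byte
	}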