cleanup, scrambler functions

This commit is contained in:
Adam Ierymenko 2019-08-19 12:49:33 -07:00
parent b34218c8c2
commit ca60d08621
No known key found for this signature in database
GPG Key ID: 1657198823E52A61
2 changed files with 209 additions and 206 deletions

View File

@ -114,51 +114,6 @@ public:
_decryptSW(in,out);
}
inline void ecbScramble(const void *in,unsigned int inlen,void *out)
{
if (inlen < 16)
return;
#ifdef ZT_AES_AESNI
if (likely(HW_ACCEL)) {
const uint8_t *i = (const uint8_t *)in;
uint8_t *o = (uint8_t *)out;
while (inlen >= 128) {
_encrypt_8xecb_aesni(i,o);
i += 128;
o += 128;
inlen -= 128;
}
while (inlen >= 16) {
_encrypt_aesni(i,o);
i += 16;
o += 16;
inlen -= 16;
}
if (inlen) {
i -= (16 - inlen);
o -= (16 - inlen);
_encrypt_aesni(i,o);
}
return;
}
#endif
const uint8_t *i = (const uint8_t *)in;
uint8_t *o = (uint8_t *)out;
while (inlen >= 16) {
_encryptSW(i,o);
i += 16;
o += 16;
inlen -= 16;
}
if (inlen) {
i -= (16 - inlen);
o -= (16 - inlen);
_encryptSW(i,o);
}
}
inline void gcmEncrypt(const uint8_t iv[12],const void *in,unsigned int inlen,const void *assoc,unsigned int assoclen,void *out,uint8_t *tag,unsigned int taglen)
{
#ifdef ZT_AES_AESNI
@ -183,6 +138,32 @@ public:
return false;
}
static inline void scramble(const uint8_t key[16],const void *in,unsigned int inlen,void *out)
{
if (inlen < 16)
return;
#ifdef ZT_AES_AESNI
if (likely(HW_ACCEL)) {
_scramble_aesni(key,(const uint8_t *)in,(uint8_t *)out,inlen);
return;
}
#endif
}
static inline void unscramble(const uint8_t key[16],const void *in,unsigned int inlen,void *out)
{
if (inlen < 16)
return;
#ifdef ZT_AES_AESNI
if (likely(HW_ACCEL)) {
_unscramble_aesni(key,(const uint8_t *)in,(uint8_t *)out,inlen);
return;
}
#endif
}
private:
void _initSW(const uint8_t key[32]);
void _encryptSW(const uint8_t in[16],uint8_t out[16]) const;
@ -376,6 +357,169 @@ private:
_k.ni.hhhh = _swap128_aesni(hhhh);
}
static inline __m128i _assist128_aesni(__m128i a,__m128i b)
{
__m128i c;
b = _mm_shuffle_epi32(b ,0xff);
c = _mm_slli_si128(a, 0x04);
a = _mm_xor_si128(a, c);
c = _mm_slli_si128(c, 0x04);
a = _mm_xor_si128(a, c);
c = _mm_slli_si128(c, 0x04);
a = _mm_xor_si128(a, c);
a = _mm_xor_si128(a, b);
return a;
}
static inline void _expand128_aesni(__m128i schedule[10],const void *const key)
{
__m128i t;
schedule[0] = t = _mm_loadu_si128((const __m128i *)key);
schedule[1] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x01));
schedule[2] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x02));
schedule[3] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x04));
schedule[4] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x08));
schedule[5] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x10));
schedule[6] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x20));
schedule[7] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x40));
schedule[8] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x80));
schedule[9] = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x1b));
schedule[10] = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x36));
}
static inline void _scramble_aesni(const uint8_t key[16],const uint8_t *in,uint8_t *out,unsigned int len)
{
__m128i t = _mm_loadu_si128((const __m128i *)key);
__m128i k0 = t;
__m128i k1 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x01));
__m128i k2 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x02));
__m128i k3 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x04));
__m128i k4 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x08));
__m128i k5 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x10));
while (len >= 32) {
len -= 32;
__m128i d0 = _mm_loadu_si128((const __m128i *)in);
in += 16;
__m128i d1 = _mm_loadu_si128((const __m128i *)in);
in += 16;
d0 = _mm_xor_si128(d0,k0);
d1 = _mm_xor_si128(d1,k0);
d0 = _mm_aesenc_si128(d0,k1);
d1 = _mm_aesenc_si128(d1,k1);
d0 = _mm_aesenc_si128(d0,k2);
d1 = _mm_aesenc_si128(d1,k2);
d0 = _mm_aesenc_si128(d0,k3);
d1 = _mm_aesenc_si128(d1,k3);
d0 = _mm_aesenc_si128(d0,k4);
d1 = _mm_aesenc_si128(d1,k4);
_mm_storeu_si128((__m128i *)out,_mm_aesenclast_si128(d0,k5));
out += 16;
_mm_storeu_si128((__m128i *)out,_mm_aesenclast_si128(d1,k5));
out += 16;
}
while (len >= 16) {
len -= 16;
__m128i d0 = _mm_loadu_si128((const __m128i *)in);
in += 16;
d0 = _mm_xor_si128(d0,k0);
d0 = _mm_aesenc_si128(d0,k1);
d0 = _mm_aesenc_si128(d0,k2);
d0 = _mm_aesenc_si128(d0,k3);
d0 = _mm_aesenc_si128(d0,k4);
_mm_storeu_si128((__m128i *)out,_mm_aesenclast_si128(d0,k5));
out += 16;
}
if (len) {
__m128i last = _mm_setzero_si128();
last = _mm_xor_si128(last,k0);
last = _mm_aesenc_si128(last,k1);
last = _mm_aesenc_si128(last,k2);
last = _mm_aesenc_si128(last,k3);
last = _mm_aesenc_si128(last,k4);
uint8_t lpad[16];
_mm_storeu_si128((__m128i *)lpad,_mm_aesenclast_si128(last,k5));
for(unsigned int i=0;i<len;++i) {
out[i] = in[i] ^ lpad[i];
}
}
}
static inline void _unscramble_aesni(const uint8_t key[16],const uint8_t *in,uint8_t *out,unsigned int len)
{
__m128i t = _mm_loadu_si128((const __m128i *)key);
__m128i dk5 = t; // k0
__m128i k1 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x01));
__m128i k2 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x02));
__m128i k3 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x04));
__m128i k4 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x08));
__m128i dk0 = t = _assist128_aesni(t, _mm_aeskeygenassist_si128(t, 0x10)); // k5
__m128i dk1 = _mm_aesimc_si128(k4);
__m128i dk2 = _mm_aesimc_si128(k3);
__m128i dk3 = _mm_aesimc_si128(k2);
__m128i dk4 = _mm_aesimc_si128(k1);
while (len >= 32) {
len -= 32;
__m128i d0 = _mm_loadu_si128((const __m128i *)in);
in += 16;
__m128i d1 = _mm_loadu_si128((const __m128i *)in);
in += 16;
d0 = _mm_xor_si128(d0,dk0);
d1 = _mm_xor_si128(d1,dk0);
d0 = _mm_aesdec_si128(d0,dk1);
d1 = _mm_aesdec_si128(d1,dk1);
d0 = _mm_aesdec_si128(d0,dk2);
d1 = _mm_aesdec_si128(d1,dk2);
d0 = _mm_aesdec_si128(d0,dk3);
d1 = _mm_aesdec_si128(d1,dk3);
d0 = _mm_aesdec_si128(d0,dk4);
d1 = _mm_aesdec_si128(d1,dk4);
_mm_storeu_si128((__m128i *)out,_mm_aesdeclast_si128(d0,dk5));
out += 16;
_mm_storeu_si128((__m128i *)out,_mm_aesdeclast_si128(d1,dk5));
out += 16;
}
while (len >= 16) {
len -= 16;
__m128i d0 = _mm_loadu_si128((const __m128i *)in);
in += 16;
d0 = _mm_xor_si128(d0,dk0);
d0 = _mm_aesdec_si128(d0,dk1);
d0 = _mm_aesdec_si128(d0,dk2);
d0 = _mm_aesdec_si128(d0,dk3);
d0 = _mm_aesdec_si128(d0,dk4);
_mm_storeu_si128((__m128i *)out,_mm_aesdeclast_si128(d0,dk5));
out += 16;
}
if (len) {
__m128i last = _mm_setzero_si128();
last = _mm_xor_si128(last,dk5); // k0
last = _mm_aesenc_si128(last,k1);
last = _mm_aesenc_si128(last,k2);
last = _mm_aesenc_si128(last,k3);
last = _mm_aesenc_si128(last,k4);
uint8_t lpad[16];
_mm_storeu_si128((__m128i *)lpad,_mm_aesenclast_si128(last,dk0)); // k5
for(unsigned int i=0;i<len;++i) {
out[i] = in[i] ^ lpad[i];
}
}
}
inline void _encrypt_aesni(const void *in,void *out) const
{
__m128i tmp;
@ -396,160 +540,6 @@ private:
tmp = _mm_aesenc_si128(tmp,_k.ni.k[13]);
_mm_storeu_si128((__m128i *)out,_mm_aesenclast_si128(tmp,_k.ni.k[14]));
}
inline void _encrypt_8xecb_aesni(const void *in,void *out) const
{
__m128i tmp0 = _mm_loadu_si128((const __m128i *)in);
__m128i tmp1 = _mm_loadu_si128((const __m128i *)((const uint8_t *)in + 16));
__m128i tmp2 = _mm_loadu_si128((const __m128i *)((const uint8_t *)in + 32));
__m128i tmp3 = _mm_loadu_si128((const __m128i *)((const uint8_t *)in + 48));
__m128i tmp4 = _mm_loadu_si128((const __m128i *)((const uint8_t *)in + 64));
__m128i tmp5 = _mm_loadu_si128((const __m128i *)((const uint8_t *)in + 80));
__m128i tmp6 = _mm_loadu_si128((const __m128i *)((const uint8_t *)in + 96));
__m128i tmp7 = _mm_loadu_si128((const __m128i *)((const uint8_t *)in + 112));
{
__m128i k0 = _k.ni.k[0];
__m128i k1 = _k.ni.k[1];
__m128i k2 = _k.ni.k[2];
__m128i k3 = _k.ni.k[3];
tmp0 = _mm_xor_si128(tmp0,k0);
tmp1 = _mm_xor_si128(tmp1,k0);
tmp2 = _mm_xor_si128(tmp2,k0);
tmp3 = _mm_xor_si128(tmp3,k0);
tmp4 = _mm_xor_si128(tmp4,k0);
tmp5 = _mm_xor_si128(tmp5,k0);
tmp6 = _mm_xor_si128(tmp6,k0);
tmp7 = _mm_xor_si128(tmp7,k0);
tmp0 = _mm_aesenc_si128(tmp0,k1);
tmp1 = _mm_aesenc_si128(tmp1,k1);
tmp2 = _mm_aesenc_si128(tmp2,k1);
tmp3 = _mm_aesenc_si128(tmp3,k1);
tmp4 = _mm_aesenc_si128(tmp4,k1);
tmp5 = _mm_aesenc_si128(tmp5,k1);
tmp6 = _mm_aesenc_si128(tmp6,k1);
tmp7 = _mm_aesenc_si128(tmp7,k1);
tmp0 = _mm_aesenc_si128(tmp0,k2);
tmp1 = _mm_aesenc_si128(tmp1,k2);
tmp2 = _mm_aesenc_si128(tmp2,k2);
tmp3 = _mm_aesenc_si128(tmp3,k2);
tmp4 = _mm_aesenc_si128(tmp4,k2);
tmp5 = _mm_aesenc_si128(tmp5,k2);
tmp6 = _mm_aesenc_si128(tmp6,k2);
tmp7 = _mm_aesenc_si128(tmp7,k2);
tmp0 = _mm_aesenc_si128(tmp0,k3);
tmp1 = _mm_aesenc_si128(tmp1,k3);
tmp2 = _mm_aesenc_si128(tmp2,k3);
tmp3 = _mm_aesenc_si128(tmp3,k3);
tmp4 = _mm_aesenc_si128(tmp4,k3);
tmp5 = _mm_aesenc_si128(tmp5,k3);
tmp6 = _mm_aesenc_si128(tmp6,k3);
tmp7 = _mm_aesenc_si128(tmp7,k3);
}
{
__m128i k4 = _k.ni.k[4];
__m128i k5 = _k.ni.k[5];
__m128i k6 = _k.ni.k[6];
__m128i k7 = _k.ni.k[7];
tmp0 = _mm_aesenc_si128(tmp0,k4);
tmp1 = _mm_aesenc_si128(tmp1,k4);
tmp2 = _mm_aesenc_si128(tmp2,k4);
tmp3 = _mm_aesenc_si128(tmp3,k4);
tmp4 = _mm_aesenc_si128(tmp4,k4);
tmp5 = _mm_aesenc_si128(tmp5,k4);
tmp6 = _mm_aesenc_si128(tmp6,k4);
tmp7 = _mm_aesenc_si128(tmp7,k4);
tmp0 = _mm_aesenc_si128(tmp0,k5);
tmp1 = _mm_aesenc_si128(tmp1,k5);
tmp2 = _mm_aesenc_si128(tmp2,k5);
tmp3 = _mm_aesenc_si128(tmp3,k5);
tmp4 = _mm_aesenc_si128(tmp4,k5);
tmp5 = _mm_aesenc_si128(tmp5,k5);
tmp6 = _mm_aesenc_si128(tmp6,k5);
tmp7 = _mm_aesenc_si128(tmp7,k5);
tmp0 = _mm_aesenc_si128(tmp0,k6);
tmp1 = _mm_aesenc_si128(tmp1,k6);
tmp2 = _mm_aesenc_si128(tmp2,k6);
tmp3 = _mm_aesenc_si128(tmp3,k6);
tmp4 = _mm_aesenc_si128(tmp4,k6);
tmp5 = _mm_aesenc_si128(tmp5,k6);
tmp6 = _mm_aesenc_si128(tmp6,k6);
tmp7 = _mm_aesenc_si128(tmp7,k6);
tmp0 = _mm_aesenc_si128(tmp0,k7);
tmp1 = _mm_aesenc_si128(tmp1,k7);
tmp2 = _mm_aesenc_si128(tmp2,k7);
tmp3 = _mm_aesenc_si128(tmp3,k7);
tmp4 = _mm_aesenc_si128(tmp4,k7);
tmp5 = _mm_aesenc_si128(tmp5,k7);
tmp6 = _mm_aesenc_si128(tmp6,k7);
tmp7 = _mm_aesenc_si128(tmp7,k7);
}
{
__m128i k8 = _k.ni.k[8];
__m128i k9 = _k.ni.k[9];
__m128i k10 = _k.ni.k[10];
__m128i k11 = _k.ni.k[11];
tmp0 = _mm_aesenc_si128(tmp0,k8);
tmp1 = _mm_aesenc_si128(tmp1,k8);
tmp2 = _mm_aesenc_si128(tmp2,k8);
tmp3 = _mm_aesenc_si128(tmp3,k8);
tmp4 = _mm_aesenc_si128(tmp4,k8);
tmp5 = _mm_aesenc_si128(tmp5,k8);
tmp6 = _mm_aesenc_si128(tmp6,k8);
tmp7 = _mm_aesenc_si128(tmp7,k8);
tmp0 = _mm_aesenc_si128(tmp0,k9);
tmp1 = _mm_aesenc_si128(tmp1,k9);
tmp2 = _mm_aesenc_si128(tmp2,k9);
tmp3 = _mm_aesenc_si128(tmp3,k9);
tmp4 = _mm_aesenc_si128(tmp4,k9);
tmp5 = _mm_aesenc_si128(tmp5,k9);
tmp6 = _mm_aesenc_si128(tmp6,k9);
tmp7 = _mm_aesenc_si128(tmp7,k9);
tmp0 = _mm_aesenc_si128(tmp0,k10);
tmp1 = _mm_aesenc_si128(tmp1,k10);
tmp2 = _mm_aesenc_si128(tmp2,k10);
tmp3 = _mm_aesenc_si128(tmp3,k10);
tmp4 = _mm_aesenc_si128(tmp4,k10);
tmp5 = _mm_aesenc_si128(tmp5,k10);
tmp6 = _mm_aesenc_si128(tmp6,k10);
tmp7 = _mm_aesenc_si128(tmp7,k10);
tmp0 = _mm_aesenc_si128(tmp0,k11);
tmp1 = _mm_aesenc_si128(tmp1,k11);
tmp2 = _mm_aesenc_si128(tmp2,k11);
tmp3 = _mm_aesenc_si128(tmp3,k11);
tmp4 = _mm_aesenc_si128(tmp4,k11);
tmp5 = _mm_aesenc_si128(tmp5,k11);
tmp6 = _mm_aesenc_si128(tmp6,k11);
tmp7 = _mm_aesenc_si128(tmp7,k11);
}
{
__m128i k12 = _k.ni.k[12];
__m128i k13 = _k.ni.k[13];
__m128i k14 = _k.ni.k[14];
tmp0 = _mm_aesenc_si128(tmp0,k12);
tmp1 = _mm_aesenc_si128(tmp1,k12);
tmp2 = _mm_aesenc_si128(tmp2,k12);
tmp3 = _mm_aesenc_si128(tmp3,k12);
tmp4 = _mm_aesenc_si128(tmp4,k12);
tmp5 = _mm_aesenc_si128(tmp5,k12);
tmp6 = _mm_aesenc_si128(tmp6,k12);
tmp7 = _mm_aesenc_si128(tmp7,k12);
tmp0 = _mm_aesenc_si128(tmp0,k13);
tmp1 = _mm_aesenc_si128(tmp1,k13);
tmp2 = _mm_aesenc_si128(tmp2,k13);
tmp3 = _mm_aesenc_si128(tmp3,k13);
tmp4 = _mm_aesenc_si128(tmp4,k13);
tmp5 = _mm_aesenc_si128(tmp5,k13);
tmp6 = _mm_aesenc_si128(tmp6,k13);
tmp7 = _mm_aesenc_si128(tmp7,k13);
_mm_storeu_si128((__m128i *)out,_mm_aesenclast_si128(tmp0,k14));
_mm_storeu_si128((__m128i *)((uint8_t *)out + 16),_mm_aesenclast_si128(tmp1,k14));
_mm_storeu_si128((__m128i *)((uint8_t *)out + 32),_mm_aesenclast_si128(tmp2,k14));
_mm_storeu_si128((__m128i *)((uint8_t *)out + 48),_mm_aesenclast_si128(tmp3,k14));
_mm_storeu_si128((__m128i *)((uint8_t *)out + 64),_mm_aesenclast_si128(tmp4,k14));
_mm_storeu_si128((__m128i *)((uint8_t *)out + 80),_mm_aesenclast_si128(tmp5,k14));
_mm_storeu_si128((__m128i *)((uint8_t *)out + 96),_mm_aesenclast_si128(tmp6,k14));
_mm_storeu_si128((__m128i *)((uint8_t *)out + 112),_mm_aesenclast_si128(tmp7,k14));
}
}
inline void _decrypt_aesni(const void *in,void *out) const
{
__m128i tmp;

View File

@ -221,24 +221,30 @@ static int testCrypto()
}
int64_t end = OSUtils::now();
*dummy = buf1[0];
std::cout << ((gcmBytes / 1048576.0) / ((double)(end - start) / 1000.0)) << " MiB/second" << std::endl << " AES-256 ECB scramble (benchmark): "; std::cout.flush();
std::cout << ((gcmBytes / 1048576.0) / ((double)(end - start) / 1000.0)) << " MiB/second" << std::endl << " AES scramble (benchmark): "; std::cout.flush();
double ecbBytes = 0.0;
AES::scramble((const uint8_t *)hexbuf,buf1,sizeof(buf1),buf2);
AES::unscramble((const uint8_t *)hexbuf,buf2,sizeof(buf2),buf3);
if (memcmp(buf1,buf3,sizeof(buf1)) != 0) {
std::cout << "FAILED (scramble/unscramble did not generate identical data)" << std::endl;
return -1;
}
start = OSUtils::now();
for(unsigned long i=0;i<50000;++i) {
tv.ecbScramble(buf1,sizeof(buf1),buf2);
tv.ecbScramble(buf2,sizeof(buf1),buf1);
for(unsigned long i=0;i<200000;++i) {
AES::scramble((const uint8_t *)hexbuf,buf1,sizeof(buf1),buf2);
AES::scramble((const uint8_t *)hexbuf,buf2,sizeof(buf1),buf1);
ecbBytes += (double)(sizeof(buf1) * 2);
}
end = OSUtils::now();
*dummy = buf1[0];
std::cout << ((ecbBytes / 1048576.0) / ((double)(end - start) / 1000.0)) << " MiB/second" << std::endl << " AES-256 GCM + ECB scramble (benchmark): "; std::cout.flush();
std::cout << ((ecbBytes / 1048576.0) / ((double)(end - start) / 1000.0)) << " MiB/second" << std::endl << " AES-256 GCM + scramble (benchmark): "; std::cout.flush();
ecbBytes = 0.0;
start = OSUtils::now();
for(unsigned long i=0;i<50000;++i) {
tv.gcmEncrypt((const uint8_t *)hexbuf,buf1,sizeof(buf1),nullptr,0,buf2,(uint8_t *)(hexbuf + 32),16);
tv.ecbScramble(buf1,sizeof(buf1),buf2);
AES::scramble((const uint8_t *)hexbuf,buf1,sizeof(buf1),buf2);
tv.gcmEncrypt((const uint8_t *)hexbuf,buf2,sizeof(buf2),nullptr,0,buf1,(uint8_t *)(hexbuf + 32),16);
tv.ecbScramble(buf2,sizeof(buf1),buf1);
AES::scramble((const uint8_t *)hexbuf,buf2,sizeof(buf1),buf1);
ecbBytes += (double)(sizeof(buf1) * 2);
}
end = OSUtils::now();
@ -342,6 +348,13 @@ static int testCrypto()
return -1;
}
std::cout << "PASS" << std::endl;
std::cout << "[crypto] Benchmarking SHA-384 (48 byte input)... "; std::cout.flush();
start = OSUtils::now();
for(unsigned int i=0;i<2000000;++i) {
SHA384(buf1,buf1,48);
}
end = OSUtils::now();
std::cout << (uint64_t)(2000000.0 / ((double)(end - start) / 1000.0)) << " hashes/second" << std::endl;
std::cout << "[crypto] Testing Poly1305... "; std::cout.flush();
poly1305(buf1,poly1305TV0Input,sizeof(poly1305TV0Input),poly1305TV0Key);