faster without const variable second-guessing of the compiler

Adam Ierymenko 2019-09-05 17:31:12 -07:00
parent 274b2682d6
commit c0e92d06a5
No known key found for this signature in database
GPG Key ID: C8877CF2D7A5D7F3


@@ -167,6 +167,9 @@ public:
* to use makes the IV itself a secret. This is not strictly necessary
* but comes at little cost.
*
* This code is ZeroTier-specific in a few ways, like the way the IV
* is specified, but would not be hard to generalize.
*
* @param k1 GMAC key
* @param k2 GMAC auth tag keyed hash key
* @param k3 CTR IV keyed hash key
@@ -199,7 +202,7 @@ public:
miv[10] = (uint8_t)(len >> 8);
miv[11] = (uint8_t)len;
// Compute auth TAG: AES-ECB[k2](GMAC[k1](miv,plaintext))[0:8]
// Compute auth tag: AES-ECB[k2](GMAC[k1](miv,plaintext))[0:8]
k1.gmac(miv,in,len,ctrIv);
k2.encrypt(ctrIv,ctrIv); // ECB mode encrypt step is because GMAC is not a PRF
#ifdef ZT_NO_TYPE_PUNNING
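Note: the formula in the comment above is easy to restate outside the diff. A minimal sketch, assuming a hypothetical AesKey type whose gmac() and encrypt() mirror the methods called here (the real signatures live elsewhere in AES.hpp):

#include <cstdint>
#include <cstring>

// Hypothetical stand-ins for the methods invoked above (k1.gmac, k2.encrypt);
// bodies elided, signatures assumed.
struct AesKey {
	void gmac(const uint8_t iv[12], const void *in, unsigned int len, uint8_t out[16]) const;
	void encrypt(const void *in, void *out) const; // one AES-ECB block
};

// tag = AES-ECB[k2](GMAC[k1](miv, plaintext))[0:8]. The extra ECB step is
// there because GMAC by itself is not a PRF; encrypting its output under an
// independent key hides the GHASH structure before truncation.
static inline void computeAuthTag(const AesKey &k1, const AesKey &k2,
                                  const uint8_t miv[12],
                                  const void *plaintext, unsigned int len,
                                  uint8_t tag[8])
{
	uint8_t t[16];
	k1.gmac(miv, plaintext, len, t);
	k2.encrypt(t, t);
	memcpy(tag, t, 8); // truncate to the 64-bit wire tag
}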
@@ -525,22 +528,6 @@ private:
const __m64 iv0 = (__m64)(*((const uint64_t *)iv));
uint64_t ctr = Utils::ntoh(*((const uint64_t *)(iv+8)));
const __m128i k0 = _k.ni.k[0];
const __m128i k1 = _k.ni.k[1];
const __m128i k2 = _k.ni.k[2];
const __m128i k3 = _k.ni.k[3];
const __m128i k4 = _k.ni.k[4];
const __m128i k5 = _k.ni.k[5];
const __m128i k6 = _k.ni.k[6];
const __m128i k7 = _k.ni.k[7];
const __m128i k8 = _k.ni.k[8];
const __m128i k9 = _k.ni.k[9];
const __m128i k10 = _k.ni.k[10];
const __m128i k11 = _k.ni.k[11];
const __m128i k12 = _k.ni.k[12];
const __m128i k13 = _k.ni.k[13];
const __m128i k14 = _k.ni.k[14];
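This removal is the whole commit: hoisting every round key into a const __m128i local looks like a hint, but it can force the compiler to keep fifteen live xmm values and spill some to the stack. Indexing _k.ni.k[] at the point of use leaves register allocation to the compiler. The resulting access pattern, sketched on an illustrative ExpandedKey type rather than ZeroTier's _k.ni (compile with AES-NI enabled, e.g. -maes):

#include <immintrin.h>

struct ExpandedKey { __m128i k[15]; }; // AES-256: initial key + 14 rounds

// This commit's style: index the key schedule directly and let the register
// allocator decide what stays resident.
static inline __m128i encryptBlock(const ExpandedKey &ek, __m128i b)
{
	b = _mm_xor_si128(b, ek.k[0]);            // initial AddRoundKey
	for (int r = 1; r < 14; ++r)
		b = _mm_aesenc_si128(b, ek.k[r]);     // 13 middle rounds
	return _mm_aesenclast_si128(b, ek.k[14]); // final round
}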
#define ZT_AES_CTR_AESNI_ROUND(k) \
c0 = _mm_aesenc_si128(c0,k); \
c1 = _mm_aesenc_si128(c1,k); \
@@ -552,36 +539,41 @@ private:
c7 = _mm_aesenc_si128(c7,k)
while (len >= 128) {
__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr),iv0),k0);
__m128i c1 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+1ULL)),iv0),k0);
__m128i c2 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+2ULL)),iv0),k0);
__m128i c3 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+3ULL)),iv0),k0);
__m128i c4 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+4ULL)),iv0),k0);
__m128i c5 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+5ULL)),iv0),k0);
__m128i c6 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+6ULL)),iv0),k0);
__m128i c7 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+7ULL)),iv0),k0);
_mm_prefetch(in,_MM_HINT_T0);
_mm_prefetch(in + 32,_MM_HINT_T0);
_mm_prefetch(in + 64,_MM_HINT_T0);
_mm_prefetch(in + 96,_MM_HINT_T0);
_mm_prefetch(in + 128,_MM_HINT_T0);
__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr),iv0),_k.ni.k[0]);
__m128i c1 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+1ULL)),iv0),_k.ni.k[0]);
__m128i c2 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+2ULL)),iv0),_k.ni.k[0]);
__m128i c3 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+3ULL)),iv0),_k.ni.k[0]);
__m128i c4 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+4ULL)),iv0),_k.ni.k[0]);
__m128i c5 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+5ULL)),iv0),_k.ni.k[0]);
__m128i c6 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+6ULL)),iv0),_k.ni.k[0]);
__m128i c7 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+7ULL)),iv0),_k.ni.k[0]);
ctr += 8;
ZT_AES_CTR_AESNI_ROUND(k1);
ZT_AES_CTR_AESNI_ROUND(k2);
ZT_AES_CTR_AESNI_ROUND(k3);
ZT_AES_CTR_AESNI_ROUND(k4);
ZT_AES_CTR_AESNI_ROUND(k5);
ZT_AES_CTR_AESNI_ROUND(k6);
ZT_AES_CTR_AESNI_ROUND(k7);
ZT_AES_CTR_AESNI_ROUND(k8);
ZT_AES_CTR_AESNI_ROUND(k9);
ZT_AES_CTR_AESNI_ROUND(k10);
ZT_AES_CTR_AESNI_ROUND(k11);
ZT_AES_CTR_AESNI_ROUND(k12);
ZT_AES_CTR_AESNI_ROUND(k13);
_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,k14)));
_mm_storeu_si128((__m128i *)(out + 16),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 16)),_mm_aesenclast_si128(c1,k14)));
_mm_storeu_si128((__m128i *)(out + 32),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 32)),_mm_aesenclast_si128(c2,k14)));
_mm_storeu_si128((__m128i *)(out + 48),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 48)),_mm_aesenclast_si128(c3,k14)));
_mm_storeu_si128((__m128i *)(out + 64),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 64)),_mm_aesenclast_si128(c4,k14)));
_mm_storeu_si128((__m128i *)(out + 80),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 80)),_mm_aesenclast_si128(c5,k14)));
_mm_storeu_si128((__m128i *)(out + 96),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 96)),_mm_aesenclast_si128(c6,k14)));
_mm_storeu_si128((__m128i *)(out + 112),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 112)),_mm_aesenclast_si128(c7,k14)));
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[1]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[2]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[3]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[4]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[5]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[6]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[7]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[8]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[9]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[10]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[11]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[12]);
ZT_AES_CTR_AESNI_ROUND(_k.ni.k[13]);
_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,_k.ni.k[14])));
_mm_storeu_si128((__m128i *)(out + 16),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 16)),_mm_aesenclast_si128(c1,_k.ni.k[14])));
_mm_storeu_si128((__m128i *)(out + 32),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 32)),_mm_aesenclast_si128(c2,_k.ni.k[14])));
_mm_storeu_si128((__m128i *)(out + 48),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 48)),_mm_aesenclast_si128(c3,_k.ni.k[14])));
_mm_storeu_si128((__m128i *)(out + 64),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 64)),_mm_aesenclast_si128(c4,_k.ni.k[14])));
_mm_storeu_si128((__m128i *)(out + 80),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 80)),_mm_aesenclast_si128(c5,_k.ni.k[14])));
_mm_storeu_si128((__m128i *)(out + 96),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 96)),_mm_aesenclast_si128(c6,_k.ni.k[14])));
_mm_storeu_si128((__m128i *)(out + 112),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 112)),_mm_aesenclast_si128(c7,_k.ni.k[14])));
in += 128;
out += 128;
len -= 128;
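The rewritten loop above keeps eight counter blocks in flight so each aesenc's multi-cycle latency is hidden behind the other chains, and the new _mm_prefetch calls pull the next plaintext cache lines in while the rounds run. The shape of the interleave, reduced to four blocks and the illustrative ExpandedKey from the previous sketch:

#include <immintrin.h>

struct ExpandedKey4 { __m128i k[15]; };

// Four independent dependency chains: each round is issued for every block
// before the next round starts, so aesenc latency overlaps with throughput.
static inline void ctrEncrypt4(const ExpandedKey4 &ek, const __m128i ctr[4], __m128i ks[4])
{
	__m128i c0 = _mm_xor_si128(ctr[0], ek.k[0]);
	__m128i c1 = _mm_xor_si128(ctr[1], ek.k[0]);
	__m128i c2 = _mm_xor_si128(ctr[2], ek.k[0]);
	__m128i c3 = _mm_xor_si128(ctr[3], ek.k[0]);
	for (int r = 1; r < 14; ++r) {
		c0 = _mm_aesenc_si128(c0, ek.k[r]);
		c1 = _mm_aesenc_si128(c1, ek.k[r]);
		c2 = _mm_aesenc_si128(c2, ek.k[r]);
		c3 = _mm_aesenc_si128(c3, ek.k[r]);
	}
	ks[0] = _mm_aesenclast_si128(c0, ek.k[14]);
	ks[1] = _mm_aesenclast_si128(c1, ek.k[14]);
	ks[2] = _mm_aesenclast_si128(c2, ek.k[14]);
	ks[3] = _mm_aesenclast_si128(c3, ek.k[14]);
}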
@@ -590,42 +582,42 @@ private:
#undef ZT_AES_CTR_AESNI_ROUND
while (len >= 16) {
__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr++),(__m64)iv0),k0);
c0 = _mm_aesenc_si128(c0,k1);
c0 = _mm_aesenc_si128(c0,k2);
c0 = _mm_aesenc_si128(c0,k3);
c0 = _mm_aesenc_si128(c0,k4);
c0 = _mm_aesenc_si128(c0,k5);
c0 = _mm_aesenc_si128(c0,k6);
c0 = _mm_aesenc_si128(c0,k7);
c0 = _mm_aesenc_si128(c0,k8);
c0 = _mm_aesenc_si128(c0,k9);
c0 = _mm_aesenc_si128(c0,k10);
c0 = _mm_aesenc_si128(c0,k11);
c0 = _mm_aesenc_si128(c0,k12);
c0 = _mm_aesenc_si128(c0,k13);
_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,k14)));
__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr++),(__m64)iv0),_k.ni.k[0]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[1]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[2]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[3]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[4]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[5]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[6]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[7]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[8]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[9]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[10]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[11]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[12]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[13]);
_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,_k.ni.k[14])));
in += 16;
out += 16;
len -= 16;
}
if (len) {
__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr++),(__m64)iv0),k0);
c0 = _mm_aesenc_si128(c0,k1);
c0 = _mm_aesenc_si128(c0,k2);
c0 = _mm_aesenc_si128(c0,k3);
c0 = _mm_aesenc_si128(c0,k4);
c0 = _mm_aesenc_si128(c0,k5);
c0 = _mm_aesenc_si128(c0,k6);
c0 = _mm_aesenc_si128(c0,k7);
c0 = _mm_aesenc_si128(c0,k8);
c0 = _mm_aesenc_si128(c0,k9);
c0 = _mm_aesenc_si128(c0,k10);
c0 = _mm_aesenc_si128(c0,k11);
c0 = _mm_aesenc_si128(c0,k12);
c0 = _mm_aesenc_si128(c0,k13);
c0 = _mm_aesenclast_si128(c0,k14);
__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr++),(__m64)iv0),_k.ni.k[0]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[1]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[2]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[3]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[4]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[5]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[6]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[7]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[8]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[9]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[10]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[11]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[12]);
c0 = _mm_aesenc_si128(c0,_k.ni.k[13]);
c0 = _mm_aesenclast_si128(c0,_k.ni.k[14]);
for(unsigned int i=0;i<len;++i)
out[i] = in[i] ^ ((const uint8_t *)&c0)[i];
}
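The tail path encrypts one final counter block and XORs only the remaining bytes; reading the keystream through (const uint8_t *)&c0 is the byte-wise view of that block. The same step as a plain sketch, where ks stands in for the fully encrypted counter:

#include <cstdint>

// CTR tail: ciphertext = plaintext XOR keystream, for len < 16 bytes.
static inline void ctrTail(uint8_t *out, const uint8_t *in,
                           const uint8_t ks[16], unsigned int len)
{
	for (unsigned int i = 0; i < len; ++i)
		out[i] = in[i] ^ ks[i];
}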
@@ -680,10 +672,6 @@ private:
unsigned int pblocks = blocks - (blocks % 4);
unsigned int rem = len % 16;
const __m128i h1 = _k.ni.hhhh;
const __m128i h2 = _k.ni.hhh;
const __m128i h3 = _k.ni.hh;
const __m128i h4 = _k.ni.h;
const __m128i shuf = _mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
__m128i y = _mm_setzero_si128();
unsigned int i = 0;
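The same de-hoisting applies here: h1..h4 were const copies of the precomputed GHASH key powers _k.ni.hhhh (H^4), .hhh (H^3), .hh (H^2), and .h (H). Four powers enable the aggregated four-block update, since y' = ((((y XOR d1)*H XOR d2)*H XOR d3)*H XOR d4)*H = (y XOR d1)*H^4 XOR d2*H^3 XOR d3*H^2 XOR d4*H in GF(2^128), turning four dependent multiplies into four independent ones. A sketch of that algebra, assuming a hypothetical gf128_mul whose body (carryless multiply plus reduction) is elided:

#include <cstdint>

struct u128 { uint64_t hi, lo; };
u128 gf128_mul(u128 a, u128 b); // GF(2^128) multiply, body elided
static inline u128 gf128_xor(u128 a, u128 b) { return u128{a.hi ^ b.hi, a.lo ^ b.lo}; }

// Aggregated four-block GHASH update using precomputed H, H^2, H^3, H^4.
static inline u128 ghashUpdate4(u128 y, u128 d1, u128 d2, u128 d3, u128 d4,
                                u128 h, u128 h2, u128 h3, u128 h4)
{
	return gf128_xor(gf128_xor(gf128_mul(gf128_xor(y, d1), h4),
	                           gf128_mul(d2, h3)),
	                 gf128_xor(gf128_mul(d3, h2),
	                           gf128_mul(d4, h)));
}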
@@ -692,35 +680,37 @@ private:
__m128i d2 = _mm_shuffle_epi8(_mm_loadu_si128(ab + i + 1),shuf);
__m128i d3 = _mm_shuffle_epi8(_mm_loadu_si128(ab + i + 2),shuf);
__m128i d4 = _mm_shuffle_epi8(_mm_loadu_si128(ab + i + 3),shuf);
__m128i t0 = _mm_clmulepi64_si128(h1,d1,0x00);
__m128i t1 = _mm_clmulepi64_si128(h2,d2,0x00);
__m128i t2 = _mm_clmulepi64_si128(h3,d3,0x00);
__m128i t3 = _mm_clmulepi64_si128(h4,d4,0x00);
_mm_prefetch(ab + i + 4,_MM_HINT_T0);
_mm_prefetch(ab + i + 6,_MM_HINT_T0);
__m128i t0 = _mm_clmulepi64_si128(_k.ni.hhhh,d1,0x00);
__m128i t1 = _mm_clmulepi64_si128(_k.ni.hhh,d2,0x00);
__m128i t2 = _mm_clmulepi64_si128(_k.ni.hh,d3,0x00);
__m128i t3 = _mm_clmulepi64_si128(_k.ni.h,d4,0x00);
__m128i t8 = _mm_xor_si128(t0,t1);
t8 = _mm_xor_si128(t8,t2);
t8 = _mm_xor_si128(t8,t3);
__m128i t4 = _mm_clmulepi64_si128(h1,d1,0x11);
__m128i t5 = _mm_clmulepi64_si128(h2,d2,0x11);
__m128i t6 = _mm_clmulepi64_si128(h3,d3,0x11);
__m128i t7 = _mm_clmulepi64_si128(h4,d4,0x11);
__m128i t4 = _mm_clmulepi64_si128(_k.ni.hhhh,d1,0x11);
__m128i t5 = _mm_clmulepi64_si128(_k.ni.hhh,d2,0x11);
__m128i t6 = _mm_clmulepi64_si128(_k.ni.hh,d3,0x11);
__m128i t7 = _mm_clmulepi64_si128(_k.ni.h,d4,0x11);
__m128i t9 = _mm_xor_si128(t4,t5);
t9 = _mm_xor_si128(t9,t6);
t9 = _mm_xor_si128(t9,t7);
t0 = _mm_shuffle_epi32(h1,78);
t0 = _mm_shuffle_epi32(_k.ni.hhhh,78);
t4 = _mm_shuffle_epi32(d1,78);
t0 = _mm_xor_si128(t0,h1);
t0 = _mm_xor_si128(t0,_k.ni.hhhh);
t4 = _mm_xor_si128(t4,d1);
t1 = _mm_shuffle_epi32(h2,78);
t1 = _mm_shuffle_epi32(_k.ni.hhh,78);
t5 = _mm_shuffle_epi32(d2,78);
t1 = _mm_xor_si128(t1,h2);
t1 = _mm_xor_si128(t1,_k.ni.hhh);
t5 = _mm_xor_si128(t5,d2);
t2 = _mm_shuffle_epi32(h3,78);
t2 = _mm_shuffle_epi32(_k.ni.hh,78);
t6 = _mm_shuffle_epi32(d3,78);
t2 = _mm_xor_si128(t2,h3);
t2 = _mm_xor_si128(t2,_k.ni.hh);
t6 = _mm_xor_si128(t6,d3);
t3 = _mm_shuffle_epi32(h4,78);
t3 = _mm_shuffle_epi32(_k.ni.h,78);
t7 = _mm_shuffle_epi32(d4,78);
t3 = _mm_xor_si128(t3,h4);
t3 = _mm_xor_si128(t3,_k.ni.h);
t7 = _mm_xor_si128(t7,d4);
t0 = _mm_clmulepi64_si128(t0,t4,0x00);
t1 = _mm_clmulepi64_si128(t1,t5,0x00);
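The _mm_shuffle_epi32(x, 78) calls above are the Karatsuba step: 78 is 0b01001110, which swaps a register's two 64-bit halves, so XORing a value with its swap puts (hi XOR lo) in both lanes. That lets three PCLMULQDQs stand in for four when forming the 256-bit carryless product. The trick in isolation (needs -mpclmul):

#include <immintrin.h>

// Unreduced 128x128 carryless multiply via Karatsuba: three clmuls, not four.
static inline void clmul128(__m128i a, __m128i b, __m128i *lo, __m128i *hi)
{
	const __m128i ll = _mm_clmulepi64_si128(a, b, 0x00); // aLo * bLo
	const __m128i hh = _mm_clmulepi64_si128(a, b, 0x11); // aHi * bHi
	const __m128i as = _mm_xor_si128(a, _mm_shuffle_epi32(a, 78)); // aHi ^ aLo in both lanes
	const __m128i bs = _mm_xor_si128(b, _mm_shuffle_epi32(b, 78)); // bHi ^ bLo in both lanes
	__m128i mid = _mm_clmulepi64_si128(as, bs, 0x00); // (aHi^aLo)*(bHi^bLo)
	mid = _mm_xor_si128(mid, _mm_xor_si128(ll, hh));  // middle term
	*lo = _mm_xor_si128(ll, _mm_slli_si128(mid, 8));  // low 128 bits of product
	*hi = _mm_xor_si128(hh, _mm_srli_si128(mid, 8));  // high 128 bits of product
}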
@@ -763,17 +753,17 @@ private:
t6 = _mm_xor_si128(t6,t3);
y = _mm_shuffle_epi8(t6,shuf);
}
#undef h1
for (;i<blocks;++i)
y = _ghash_aesni(shuf,h4,y,_mm_loadu_si128(ab + i));
y = _ghash_aesni(shuf,_k.ni.h,y,_mm_loadu_si128(ab + i));
if (rem) {
__m128i last = _mm_setzero_si128();
memcpy(&last,ab + blocks,rem);
y = _ghash_aesni(shuf,h4,y,last);
y = _ghash_aesni(shuf,_k.ni.h,y,last);
}
y = _ghash_aesni(shuf,h4,y,_mm_set_epi64((__m64)0LL,(__m64)Utils::hton((uint64_t)len * (uint64_t)8)));
y = _ghash_aesni(shuf,_k.ni.h,y,_mm_set_epi64((__m64)0LL,(__m64)Utils::hton((uint64_t)len * (uint64_t)8)));
__m128i t = _mm_xor_si128(_mm_set_epi32(0x01000000,(int)*((const uint32_t *)(iv+8)),(int)*((const uint32_t *)(iv+4)),(int)*((const uint32_t *)(iv))),_k.ni.k[0]);
t = _mm_aesenc_si128(t,_k.ni.k[1]);
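The closing lines fold the bit length into GHASH and then begin encrypting J0, the initial counter block (_mm_set_epi32 with 0x01000000 in the top lane lays out IV || 0x00000001 in memory); in GCM the tag is E_K(J0) XOR GHASH. A sketch of that finalization, again on the illustrative ExpandedKey type from the earlier sketches rather than _k.ni:

#include <immintrin.h>
#include <cstdint>
#include <cstring>

struct ExpandedKeyF { __m128i k[15]; };

// GCM finalization: tag = AES(K, J0) XOR GHASH(H, AAD, C).
static inline __m128i gcmFinish(const ExpandedKeyF &ek, const uint8_t iv[12], __m128i ghash)
{
	uint8_t j0[16];
	memcpy(j0, iv, 12);
	j0[12] = 0; j0[13] = 0; j0[14] = 0; j0[15] = 1; // 32-bit counter starts at 1
	__m128i t = _mm_xor_si128(_mm_loadu_si128((const __m128i *)j0), ek.k[0]);
	for (int r = 1; r < 14; ++r)
		t = _mm_aesenc_si128(t, ek.k[r]);
	t = _mm_aesenclast_si128(t, ek.k[14]);
	return _mm_xor_si128(t, ghash); // final authentication tag
}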