diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index f4692f3c..d8de7531 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7633,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + #ifdef __ARM_FEATURE_SVE + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); + svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); + svfloat32_t r1_vector = GGML_F32_VEC_ZERO; + + for (int64_t k = 0; k < nc; k += svcntw()) { + svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]); + svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]); + svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]); + svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]); + + svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); + t1 = exp_ps_sve(svptrue_b32(), t1); + svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); + + vs0 = GGML_F32_VEC_FMA(vs0, t1, t2); + r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); + + GGML_F32_VEC_STORE(&s[i1*nc + k], vs0); + } + y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector); } - y[i1] = sumf; } } - } + #else + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; + } + } + } + #endif } void ggml_compute_forward_ssm_scan( diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 163a2743..09dbade2 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -647,6 +647,42 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) { #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461" #endif +/* Below function was borrowed from the GitHub repository: +https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */ +#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) { + // Constants + const svfloat32_t log2_e = svdup_n_f32(1.4426950409f); + const svfloat32_t ln2 = svdup_n_f32(0.6931473921f); + const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f); + const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1)); + const svfloat32_t one = svdup_n_f32(1.0f); + const svfloat32_t inactive1 = svdup_n_f32(0.0f); + const svint32_t inactive2 = svdup_n_s32(0); + + // Algorithm starts here + svfloat32_t t0 = svmul_f32_m(pg, src, log2_e); // y = x * log2(e) + svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0); // rount to int (float) + svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1); // n + + t1 = svsub_f32_m(pg, t0, t1); // a = y - floor(y) + t1 = svadd_f32_m(pg, t1, one); // b = a + 1 + + svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17); // v = b >> 17 (u32) + svfloat32_t t4 = svexpa_f32(t3); // c = fexpa(v) + t4 = svscale_f32_m(pg, t4, t2); // fexpa(v) * 2^(n) + + // and_(t2.d, t1.d, not_mask17.d) + svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17)); + t5 = svsub_f32_m(pg, t1, t5); // z + t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z + t0 = svmla_f32_m(pg, one, t5, t0); // 1 + (ln2 * z) + (half_ln2_sq * z * z) + t0 = svmul_f32_m(pg, t0, t4); // Final result + + return t0; + } +#endif + #if defined(__ARM_NEON) && defined(__aarch64__) // adapted from arm limited optimized routine