From dceff23faec99945d3161d24ea209a0c433546db Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Aug 2024 21:49:39 -0400 Subject: [PATCH] ggml : SIMD ggml_ssm_scan for Mamba-2 * ggml : improve ggml_mul speed when masking recurrent states --- ggml/src/ggml.c | 95 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 666820908..f8e708088 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (nb10 == sizeof(float)) { + if (ne00 > 1 && ne10 == 1) { + // fast broadcast path + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + const float scale = src1_ptr[0]; + + if (scale == 0.0f) { + // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, + // but it is useful when resetting the state of recurrent models. + memset((char *)dst->data + ir*nb1, 0, nb1); + } else { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + } + if (scale != 1.0f) { + ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); + } + } + } + } else if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); @@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32( const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; const float dA = expf(dt_soft_plus * A[h]); - // TODO: SIMD implementation // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; +#if defined(GGML_SIMD) + const int np = (nc & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC az[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + + ax[j] = GGML_F32_VEC_MUL(ax[j], adA); + ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); + + ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]); + + GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); +#else + const int np = 0; +#endif // d_state - for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + for (int i0 = np; i0 < nc; ++i0) { + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * dA) + (B[ig] * x_dt); + const float state = (s0[i] * dA) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } } else { @@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32( // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; + // NOTE: can't really use GGML_SIMD here because d_state is usually 16 + // and also because expf is used within the loop. // d_state for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } }