ggml : SIMD ggml_ssm_scan for Mamba-2

* ggml : improve ggml_mul speed when masking recurrent states
This commit is contained in:
Francis Couture-Harpin 2024-08-18 21:49:39 -04:00
parent 1f0fea70fb
commit dceff23fae

View File

@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32(
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) {
if (ne00 > 1 && ne10 == 1) {
// fast broadcast path
for (int64_t ir = ith; ir < nr; ir += nth) {
// src0 and dst are same shape => same indices
const int64_t i03 = ir/(ne02*ne01);
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
const float scale = src1_ptr[0];
if (scale == 0.0f) {
// NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
// but it is useful when resetting the state of recurrent models.
memset((char *)dst->data + ir*nb1, 0, nb1);
} else {
if (dst->data != src0->data) {
// src0 is same shape as dst => same indices
memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float));
}
if (scale != 1.0f) {
ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
}
}
}
} else if (nb10 == sizeof(float)) {
for (int64_t ir = ith; ir < nr; ir += nth) {
// src0 and dst are same shape => same indices
const int64_t i03 = ir/(ne02*ne01);
@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32(
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);
// TODO: SIMD implementation
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int i = i1 + h*nr;
const float x_dt = x[i] * dt_soft_plus;
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
float sumf = 0.0f;
#if defined(GGML_SIMD)
const int np = (nc & ~(GGML_F32_STEP - 1));
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
GGML_F32_VEC az[GGML_F32_ARR];
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
}
}
// reduce sum0..sum3 to sum0
GGML_F32_VEC_REDUCE(sumf, sum);
#else
const int np = 0;
#endif
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int ii = i0 + i*nc;
for (int i0 = np; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[ii] * dA) + (B[ig] * x_dt);
const float state = (s0[i] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[ii] = state;
s[i] = state;
}
y[i] = sumf + x[i] * D[h];
y[ii] = sumf + x[ii] * D[h];
}
}
} else {
@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32(
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int i = i1 + h*nr;
const float x_dt = x[i] * dt_soft_plus;
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
float sumf = 0.0f;
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
// and also because expf is used within the loop.
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int ii = i0 + i*nc;
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[ii] = state;
s[i] = state;
}
y[i] = sumf + x[i] * D[h];
y[ii] = sumf + x[ii] * D[h];
}
}
}