Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2024-11-15 15:29:53 +00:00
ggml : SIMD ggml_ssm_scan for Mamba-2
* ggml : improve ggml_mul speed when masking recurrent states
This commit is contained in:
parent 1f0fea70fb
commit dceff23fae
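What this commit vectorizes is the per-head state update of the SSM scan: each state element decays by dA = expf(dt_soft_plus * A[h]), receives B[ig] * x_dt, and is dotted against C to produce the output row. A minimal scalar sketch of one such step follows (hypothetical standalone function and shapes, not the actual ggml_ssm_scan tensor layout); the SIMD path in the first ssm_scan hunk below computes this same loop in GGML_F32_STEP-wide blocks and leaves the remainder to the scalar loop.

// Scalar reference of one Mamba-2 scan step for a single head and row.
// Hypothetical standalone sketch: in ggml the same math runs over the
// ggml_ssm_scan tensors. dA = expf(dt_soft_plus * A[h]) is constant per head.
static float ssm_scan_step_ref(const float * s0, float * s,
                               const float * B, const float * C,
                               float dA, float x, float dt_soft_plus,
                               float D, int d_state) {
    const float x_dt = x * dt_soft_plus;
    float sumf = 0.0f;
    for (int i = 0; i < d_state; ++i) {
        const float state = s0[i] * dA + B[i] * x_dt; // state = prev_state * dA + dB * x
        sumf += state * C[i];                         // y = rowwise_dotprod(state, C)
        s[i] = state;
    }
    return sumf + x * D; // skip connection through D, as in y[ii] = sumf + x[ii] * D[h]
}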
@@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32(
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
 
-    if (nb10 == sizeof(float)) {
+    if (ne00 > 1 && ne10 == 1) {
+        // fast broadcast path
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            const float scale = src1_ptr[0];
+
+            if (scale == 0.0f) {
+                // NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
+                //       but it is useful when resetting the state of recurrent models.
+                memset((char *)dst->data + ir*nb1, 0, nb1);
+            } else {
+                if (dst->data != src0->data) {
+                    // src0 is same shape as dst => same indices
+                    memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float));
+                }
+                if (scale != 1.0f) {
+                    ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
+                }
+            }
+        }
+    } else if (nb10 == sizeof(float)) {
         for (int64_t ir = ith; ir < nr; ir += nth) {
             // src0 and dst are same shape => same indices
             const int64_t i03 = ir/(ne02*ne01);
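The hunk above is what the "ggml : improve ggml_mul speed when masking recurrent states" bullet refers to: when src1 broadcasts a single value per row (ne10 == 1), the multiply reduces to scaling each row by one scalar, and a scale of 0.0f (a cleared sequence) becomes a plain memset. Below is a stripped-down sketch of that row-wise fast path, with hypothetical standalone names in place of the ggml tensor strides; as the NOTE in the diff points out, the memset branch intentionally clears NaNs as well, which is exactly what resetting a recurrent state relies on.

#include <string.h>

// Hypothetical row-wise sketch of the fast broadcast path added above:
// multiply a row of n floats by a single scalar taken from src1.
static void mul_row_by_scalar(float * dst, const float * src, float scale, int n) {
    if (scale == 0.0f) {
        // zeroing also clears NaNs, which is what state masking relies on
        memset(dst, 0, n * sizeof(float));
    } else {
        if (dst != src) {
            memcpy(dst, src, n * sizeof(float));
        }
        if (scale != 1.0f) {
            for (int i = 0; i < n; ++i) {
                dst[i] *= scale; // stand-in for ggml_vec_scale_f32
            }
        }
    }
}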
@@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32(
                     const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
                     const float dA = expf(dt_soft_plus * A[h]);
 
-                    // TODO: SIMD implementation
                     // dim
                     for (int i1 = 0; i1 < nr; ++i1) {
-                        const int i = i1 + h*nr;
-                        const float x_dt = x[i] * dt_soft_plus;
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
                         float sumf = 0.0f;
+#if defined(GGML_SIMD)
+                        const int np = (nc & ~(GGML_F32_STEP - 1));
+
+                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                        GGML_F32_VEC ax[GGML_F32_ARR];
+                        GGML_F32_VEC ay[GGML_F32_ARR];
+                        GGML_F32_VEC az[GGML_F32_ARR];
+
+                        for (int i = 0; i < np; i += GGML_F32_STEP) {
+                            for (int j = 0; j < GGML_F32_ARR; j++) {
+                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
+                                ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+                                az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
+
+                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
+                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
+
+                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
+
+                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
+
+                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
+                            }
+                        }
+
+                        // reduce sum0..sum3 to sum0
+                        GGML_F32_VEC_REDUCE(sumf, sum);
+#else
+                        const int np = 0;
+#endif
                         // d_state
-                        for (int i0 = 0; i0 < nc; ++i0) {
-                            const int ii = i0 + i*nc;
+                        for (int i0 = np; i0 < nc; ++i0) {
+                            const int i = i0 + ii*nc;
                             const int ig = i0 + (h & (ng - 1))*nc;
                             // state = prev_state * dA + dB * x
-                            const float state = (s0[ii] * dA) + (B[ig] * x_dt);
+                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
                             // y = rowwise_dotprod(state, C)
                             sumf += state * C[ig];
-                            s[ii] = state;
+                            s[i] = state;
                         }
-                        y[i] = sumf + x[i] * D[h];
+                        y[ii] = sumf + x[ii] * D[h];
                     }
                 }
             } else {
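The SIMD block above follows ggml's usual blocking scheme: GGML_F32_ARR vector registers of GGML_F32_EPR floats are processed per GGML_F32_STEP, one partial sum is kept per register, the partial sums are folded once with GGML_F32_VEC_REDUCE, and the remaining nc - np elements fall through to the scalar d_state loop. A plain-C sketch of only that accumulate-and-reduce pattern (applied here to a bare dot product, with illustrative EPR/ARR values that are not ggml's actual configuration):

// Illustrative only: plain floats stand in for GGML_F32_VEC registers,
// and the EPR/ARR/STEP values are placeholders.
enum { EPR = 8, ARR = 4, STEP = EPR * ARR };

static float dot_blocked(const float * a, const float * b, int n) {
    float sum[ARR] = { 0.0f };              // one partial sum per "register"
    const int np = n & ~(STEP - 1);         // vectorizable prefix, as in (nc & ~(GGML_F32_STEP - 1))
    for (int i = 0; i < np; i += STEP) {
        for (int j = 0; j < ARR; j++) {
            for (int k = 0; k < EPR; k++) { // one GGML_F32_VEC_FMA in the real code
                sum[j] += a[i + j*EPR + k] * b[i + j*EPR + k];
            }
        }
    }
    float sumf = 0.0f;
    for (int j = 0; j < ARR; j++) {         // GGML_F32_VEC_REDUCE folds the partial sums
        sumf += sum[j];
    }
    for (int i = np; i < n; ++i) {          // scalar tail, like the loop starting at np above
        sumf += a[i] * b[i];
    }
    return sumf;
}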
@@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32(
 
                     // dim
                     for (int i1 = 0; i1 < nr; ++i1) {
-                        const int i = i1 + h*nr;
-                        const float x_dt = x[i] * dt_soft_plus;
+                        const int ii = i1 + h*nr;
+                        const float x_dt = x[ii] * dt_soft_plus;
                         float sumf = 0.0f;
+                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
+                        //       and also because expf is used within the loop.
                         // d_state
                         for (int i0 = 0; i0 < nc; ++i0) {
-                            const int ii = i0 + i*nc;
+                            const int i = i0 + ii*nc;
                             const int ig = i0 + (h & (ng - 1))*nc;
                             // state = prev_state * dA + dB * x
-                            const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
                             // y = rowwise_dotprod(state, C)
                             sumf += state * C[ig];
-                            s[ii] = state;
+                            s[i] = state;
                         }
-                        y[i] = sumf + x[i] * D[h];
+                        y[ii] = sumf + x[ii] * D[h];
                     }
                 }
             }
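This last hunk is the path for the original Mamba layout, where A has one value per state element rather than one per head, so the exponential cannot be hoisted out of the d_state loop; together with d_state typically being 16, that is why the NOTE above says GGML_SIMD is not really usable here. A scalar sketch of that per-element variant (hypothetical standalone names, mirroring the diff):

#include <math.h>

// Per-element decay: A has one entry per state element, so expf() stays
// inside the loop (hypothetical standalone sketch, not ggml's API).
static float ssm_scan_step_mamba1(const float * s0, float * s,
                                  const float * A, const float * B, const float * C,
                                  float dt_soft_plus, float x_dt, int d_state) {
    float y = 0.0f;
    for (int i = 0; i < d_state; ++i) {
        const float state = s0[i] * expf(dt_soft_plus * A[i]) + B[i] * x_dt;
        y += state * C[i];
        s[i] = state;
    }
    return y;
}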