mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-15 07:19:53 +00:00
ggml : remove unused fast broadcast path in GGML_MUL
This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity.
This commit is contained in:
parent
038d958333
commit
805512a73b
@ -10173,37 +10173,7 @@ static void ggml_compute_forward_mul_f32(
|
|||||||
GGML_ASSERT( nb0 == sizeof(float));
|
GGML_ASSERT( nb0 == sizeof(float));
|
||||||
GGML_ASSERT(nb00 == sizeof(float));
|
GGML_ASSERT(nb00 == sizeof(float));
|
||||||
|
|
||||||
if (ne00 > 1 && ne10 == 1) {
|
if (nb10 == sizeof(float)) {
|
||||||
// fast broadcast path
|
|
||||||
for (int64_t ir = ith; ir < nr; ir += nth) {
|
|
||||||
// src0 and dst are same shape => same indices
|
|
||||||
const int64_t i03 = ir/(ne02*ne01);
|
|
||||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
|
||||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
||||||
|
|
||||||
const int64_t i13 = i03 % ne13;
|
|
||||||
const int64_t i12 = i02 % ne12;
|
|
||||||
const int64_t i11 = i01 % ne11;
|
|
||||||
|
|
||||||
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
|
||||||
|
|
||||||
const float scale = src1_ptr[0];
|
|
||||||
|
|
||||||
if (scale == 0.0f) {
|
|
||||||
// NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
|
|
||||||
// but it is useful when resetting the state of recurrent models.
|
|
||||||
memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float));
|
|
||||||
} else {
|
|
||||||
if (dst->data != src0->data) {
|
|
||||||
// src0 is same shape as dst => same indices
|
|
||||||
memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float));
|
|
||||||
}
|
|
||||||
if (scale != 1.0f) {
|
|
||||||
ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (nb10 == sizeof(float)) {
|
|
||||||
for (int64_t ir = ith; ir < nr; ir += nth) {
|
for (int64_t ir = ith; ir < nr; ir += nth) {
|
||||||
// src0 and dst are same shape => same indices
|
// src0 and dst are same shape => same indices
|
||||||
const int64_t i03 = ir/(ne02*ne01);
|
const int64_t i03 = ir/(ne02*ne01);
|
||||||
|
Loading…
Reference in New Issue
Block a user