mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
parent
eea986f215
commit
76b27d29c2
@ -1813,11 +1813,13 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||||||
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
|
||||||
l1, r1)), l2, r2)), l3, r3))), scale);
|
l1, r1)), l2, r2)), l3, r3))), scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
|
float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
|
||||||
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
|
||||||
|
|
||||||
vst1_f32(s, vget_low_f32 (sumv2));
|
vst1_f32(s, vget_low_f32 (sumv2));
|
||||||
vst1_f32(s + bs, vget_high_f32(sumv2));
|
vst1_f32(s + bs, vget_high_f32(sumv2));
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -7576,14 +7576,6 @@ UseGgmlGemm2:;
|
|||||||
// This is the size of the rest of the dimensions of the result
|
// This is the size of the rest of the dimensions of the result
|
||||||
const int64_t nr1 = ne1 * ne2 * ne3;
|
const int64_t nr1 = ne1 * ne2 * ne3;
|
||||||
|
|
||||||
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
|
||||||
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
|
||||||
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
|
||||||
// this check can be removed once they are extended to support odd numbered rows/cols too
|
|
||||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
|
|
||||||
num_rows_per_vec_dot = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now select a reasonable chunk size.
|
// Now select a reasonable chunk size.
|
||||||
int chunk_size = 16;
|
int chunk_size = 16;
|
||||||
|
|
||||||
@ -7646,6 +7638,15 @@ UseGgmlGemm2:;
|
|||||||
const int64_t ir1_start = dr1 * ith1;
|
const int64_t ir1_start = dr1 * ith1;
|
||||||
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
||||||
|
|
||||||
|
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
||||||
|
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
||||||
|
|
||||||
|
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
||||||
|
// this check can be removed once they are extended to support odd numbered rows/cols too
|
||||||
|
if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
|
||||||
|
num_rows_per_vec_dot = 1;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
||||||
|
|
||||||
if (nth >= nchunk0 * nchunk1) {
|
if (nth >= nchunk0 * nchunk1) {
|
||||||
|
Loading…
Reference in New Issue
Block a user