diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 5ebdf96d1..81c814f9b 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3823,13 +3823,13 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r svfloat32_t sumv1 = svdup_n_f32(0.0f); const int vector_length = ggml_sve_cnt_b*8; - // VLA Implementation using switch case + // VLA Implementation using switch case switch(vector_length) - { + { case 128: // predicate for activating higher lanes for 4 float32 elements const svbool_t pg =svptrue_pat_b32(SV_VL4); - + for (; ib + 1 < nb; ib += 2) { const block_q4_0 * restrict x0 = &x[ib + 0]; const block_q4_0 * restrict x1 = &x[ib + 1]; @@ -3866,15 +3866,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - break; + break; - case 256: + case 256: // predicate for activating higher lanes for 16 int8 elements const svbool_t ptrueh_256 = svptrue_pat_b8(SV_VL16); // predicate for activating lower lanes for 16 int8 elements const svbool_t ptruel_256 = svnot_b_z(svptrue_b8(), ptrueh_256); - + for (; ib + 1 < nb; ib += 2) { const block_q4_0 * restrict x0 = &x[ib + 0]; const block_q4_0 * restrict x1 = &x[ib + 1]; @@ -3904,14 +3904,14 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - break; + break; case 512: // predicate for activating higher lanes for 32 int8 elements const svbool_t ptrue = svptrue_pat_b8(SV_VL32); - // predicate for activating higher lanes for 16 int8 elements + // predicate for activating higher lanes for 16 int8 elements const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes + // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes const svbool_t ptruel = svnot_b_z(ptrue, ptrueh); for (; ib < nb; ib += 2) { @@ -3942,9 +3942,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } sumf = svaddv_f32(ptrue, svadd_f32_x(ptrue, sumv0, sumv1)); break; - - default: - break; + + default: + break; } @@ -5403,20 +5403,20 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r const int vector_length = ggml_sve_cnt_b*8; - //VLA Implemenation for SVE + //VLA Implemenation for SVE switch(vector_length) - { + { case 128: // predicate for activating lanes for 16 Int8 elements svbool_t pg1 =svptrue_pat_b8(SV_VL16); svbool_t pg =svptrue_pat_b32(SV_VL4); - for (; ib + 1 < nb; ib += 2) { + for (; ib + 1 < nb; ib += 2) { const block_q8_0 * restrict x0 = &x[ib + 0]; const block_q8_0 * restrict x1 = &x[ib + 1]; const block_q8_0 * restrict y0 = &y[ib + 0]; const block_q8_0 * restrict y1 = &y[ib + 1]; - + // load x const svint8_t qx0_0 = svld1_s8(pg1, x0->qs); const svint8_t qx0_1 = svld1_s8(pg1, x0->qs+16); @@ -5434,11 +5434,11 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } - sumf = svaddv_f32(pg, svadd_f32_x(pg, sumv0, sumv1)); - break; + sumf = svaddv_f32(pg, svadd_f32_x(pg, sumv0, sumv1)); + break; case 256: - //printf("sve256"); + //printf("sve256"); for (; ib + 1 < nb; ib += 2) { const block_q8_0 * restrict x0 = &x[ib + 0]; const block_q8_0 * restrict x1 = &x[ib + 1]; @@ -5452,24 +5452,24 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r // load y const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - - } - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - break; + + } + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + break; case 512: - // predicate for activating high 256 bit + // predicate for activating high 256 bit const svbool_t ptrueh = svptrue_pat_b8(SV_VL32); - // predicate for activating low 256 bit + // predicate for activating low 256 bit const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); - + // predicate for activating high lanes for 8 float32 elements svbool_t asd = svptrue_pat_b32(SV_VL8); // predicate for activating low lanes for 8 float32 elements - svbool_t dsa = svnot_b_z(svptrue_b32(), asd); + svbool_t dsa = svnot_b_z(svptrue_b32(), asd); svfloat32_t sumv00 = svdup_n_f32(0.0f); @@ -5480,9 +5480,9 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r const block_q8_0 * restrict y0 = &y[ib + 0]; const block_q8_0 * restrict y1 = &y[ib + 1]; - //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits + //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits // and add them to make one 64 element vector - // load x + // load x const svint8_t qx_32 = svld1_s8(ptrueh,x0->qs); svint8_t qx_64 = svld1_s8(ptruel,x0->qs+2); qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); @@ -5491,11 +5491,11 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r const svint8_t qy_32 = svld1_s8(ptrueh,y0->qs); svint8_t qy_64 = svld1_s8(ptruel,y0->qs+2); qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); - - // scale creation + + // scale creation float32_t deq1= GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d); float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d); - + // duplicate deq1 in first half of vector and deq2 in second half of vector svfloat32_t temp = svdup_f32_m(svdup_f32_z(asd, deq1), dsa,deq2); @@ -5503,16 +5503,16 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r svfloat32_t sumvt = svdup_n_f32(0.0f); sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); - + sumv00 = svmla_f32_m(svptrue_b32(),sumv00,sumvt,temp); } - sumf = svaddv_f32(svptrue_b32(), sumv00); + sumf = svaddv_f32(svptrue_b32(), sumv00); break; - default: - break; + default: + break; } #elif defined(__ARM_NEON)