ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat

The compiler seems smart enough to use the same instruction
even when using vget_high_s8 instead.
This commit is contained in:
Francis Couture-Harpin 2024-08-01 01:11:30 -04:00
parent 5417089aeb
commit 45719a2472

View File

@ -5975,25 +5975,25 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
sumi1 = vmlal_high_s8(sumi1, sqx0, qy0);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
sumi1 = vmlal_high_s8(sumi1, sqx1, qy1);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
sumi1 = vmlal_high_s8(sumi1, sqx2, qy2);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
sumi1 = vmlal_high_s8(sumi1, sqx3, qy3);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
sumi1 = vmlal_high_s8(sumi1, sqx4, qy4);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
sumi1 = vmlal_high_s8(sumi1, sqx5, qy5);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
sumi1 = vmlal_high_s8(sumi1, sqx6, qy6);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
sumi1 = vmlal_high_s8(sumi1, sqx7, qy7);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
sumi1 = vmlal_high_s8(sumi1, sqx8, qy8);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
sumi1 = vmlal_high_s8(sumi1, sqx9, qy9);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
}
// last 16 bytes of 5-element, along with the 4 bytes of 4 elements
@ -6024,17 +6024,17 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
sumi1 = vmlal_high_s8(sumi1, sqx0, qy0);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
sumi1 = vmlal_high_s8(sumi1, sqx1, qy1);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
sumi1 = vmlal_high_s8(sumi1, sqx2, qy2);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
sumi1 = vmlal_high_s8(sumi1, sqx3, qy3);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
sumi1 = vmlal_high_s8(sumi1, sqx4, qy4);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
sumi1 = vmlal_high_s8(sumi1, sqx5, qy5);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
}
const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
@ -6254,21 +6254,21 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
sumi1 = vmlal_high_s8(sumi1, sqx0, qy0);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
sumi1 = vmlal_high_s8(sumi1, sqx1, qy1);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
sumi1 = vmlal_high_s8(sumi1, sqx2, qy2);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
sumi1 = vmlal_high_s8(sumi1, sqx3, qy3);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
sumi1 = vmlal_high_s8(sumi1, sqx4, qy4);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
sumi1 = vmlal_high_s8(sumi1, sqx5, qy5);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
sumi1 = vmlal_high_s8(sumi1, sqx6, qy6);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
sumi1 = vmlal_high_s8(sumi1, sqx7, qy7);
sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
}
const int16x8_t ysum0 = vld1q_s16(y[i].bsums);