mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 04:14:35 +00:00
ggml : remove Q4_0 bit shufling (ARM NEON)
This commit is contained in:
parent
fe60904eef
commit
a546dc6d60
466
ggml.c
466
ggml.c
@ -837,348 +837,52 @@ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block s
|
||||
|
||||
// reference implementation for deterministic creation of model files
|
||||
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
|
||||
assert(QK4_0 / 16 == 0);
|
||||
assert(k % QK4_0 == 0);
|
||||
const int nb = k / QK4_0;
|
||||
|
||||
uint8_t pp[QK4_0/2];
|
||||
const int nb = k / QK4_0;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
float max = 0.0f;
|
||||
|
||||
for (int l = 0; l < QK4_0; l++) {
|
||||
const float v = x[i*QK4_0 + l];
|
||||
if (amax < fabsf(v)) {
|
||||
amax = fabsf(v);
|
||||
max = v;
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -8;
|
||||
const float d = max / -8;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
y[i].d = d;
|
||||
|
||||
for (int l = 0; l < QK4_0; l += 2) {
|
||||
const float v0 = x[i*QK4_0 + l + 0]*id;
|
||||
const float v1 = x[i*QK4_0 + l + 1]*id;
|
||||
uint64_t qs[QK4_0 / 16] = {0};
|
||||
|
||||
const uint8_t vi0 = MIN(15, (int8_t)roundf(v0) + 8);
|
||||
const uint8_t vi1 = MIN(15, (int8_t)roundf(v1) + 8);
|
||||
// pack first half of weights into low nibbles and second half into high nibbles
|
||||
for (int l = 0; l < QK4_0/2; ++l) {
|
||||
const float v0 = x[i*QK4_0 + 0 + l]*id;
|
||||
const float v1 = x[i*QK4_0 + QK4_0/2 + l]*id;
|
||||
|
||||
assert(vi0 < 16);
|
||||
assert(vi1 < 16);
|
||||
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
|
||||
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
|
||||
|
||||
pp[l/2] = vi0 | (vi1 << 4);
|
||||
qs[l/8] |= vi0 << (8*(l & 7));
|
||||
qs[l/8] |= vi1 << (8*(l & 7) + 4);
|
||||
}
|
||||
|
||||
memcpy(y[i].qs, pp, sizeof(pp));
|
||||
memcpy(y[i].qs, qs, sizeof(qs));
|
||||
}
|
||||
}
|
||||
|
||||
static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
|
||||
assert(k % QK4_0 == 0);
|
||||
const int nb = k / QK4_0;
|
||||
|
||||
block_q4_0 * restrict y = vy;
|
||||
|
||||
#if defined(__POWER9_VECTOR__)
|
||||
const vector float v85 = vec_splats(8.5f);
|
||||
const vector signed int v15 = vec_splats(15);
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float max = 0.0f;
|
||||
float min = 0.0f;
|
||||
|
||||
vector float asrcv [8];
|
||||
vector float srcv [8];
|
||||
vector float maxv[8];
|
||||
vector float minv[8];
|
||||
|
||||
for (int l = 0; l < 8; l++) srcv[l] = *(vector float *)(x + i*32 + 4*l);
|
||||
//for (int l = 0; l < 8; l++) asrcv[l] = vec_abs(srcv[l]);
|
||||
|
||||
for (int l = 0; l < 4; l++) maxv[2*l] = vec_max(asrcv[2*l], asrcv[2*l+1]);
|
||||
//for (int l = 0; l < 2; l++) maxv[4*l] = vec_max(maxv[4*l], maxv[4*l+2]);
|
||||
maxv[0] = vec_max(maxv[0], maxv[2]);
|
||||
maxv[4] = vec_max(maxv[4], maxv[6]);
|
||||
//for (int l = 0; l < 1; l++) maxv[8*l] = vec_max(maxv[8*l], maxv[8*l+4]);
|
||||
maxv[0] = vec_max(maxv[0], maxv[4]);
|
||||
|
||||
for (int l = 0; l < 4; l++) minv[2*l] = vec_min(asrcv[2*l], asrcv[2*l+1]);
|
||||
//for (int l = 0; l < 2; l++) minv[4*l] = vec_min(minv[4*l], minv[4*l+2]);
|
||||
minv[0] = vec_min(minv[0], minv[2]);
|
||||
minv[4] = vec_min(minv[4], minv[6]);
|
||||
//for (int l = 0; l < 1; l++) minv[8*l] = vec_min(minv[8*l], minv[8*l+4]);
|
||||
minv[0] = vec_min(minv[0], minv[4]);
|
||||
|
||||
|
||||
max = MAX(
|
||||
MAX(vec_extract(maxv[0], 0), vec_extract(maxv[0], 1)),
|
||||
MAX(vec_extract(maxv[0], 2), vec_extract(maxv[0], 3)));
|
||||
min = MIN(
|
||||
MIN(vec_extract(minv[0], 0), vec_extract(minv[0], 1)),
|
||||
MIN(vec_extract(minv[0], 2), vec_extract(minv[0], 3)));
|
||||
|
||||
const float magnitude = max >= fabsf(min) ? max : min;
|
||||
const float d = magnitude / -8;
|
||||
const float id = d ? 1.0/d : 0.0;
|
||||
|
||||
y[i].d = d;
|
||||
|
||||
const vector float vid = vec_splats(id);
|
||||
uint8_t * restrict pb = y[i].qs;
|
||||
for (int l = 0; l < 8; l++) {
|
||||
const vector float vf = vec_madd(srcv[l], vid, v85);
|
||||
const vector signed int vi = vec_signed(vf);
|
||||
const vector signed int vc = vec_min(vi, v15);
|
||||
|
||||
pb[2*l + 0] = vec_extract(vc, 0) | (vec_extract(vc, 1) << 4);
|
||||
pb[2*l + 1] = vec_extract(vc, 2) | (vec_extract(vc, 3) << 4);
|
||||
}
|
||||
}
|
||||
#elif __ARM_NEON
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float32x4_t srcv [8];
|
||||
float32x4_t maxv[8];
|
||||
float32x4_t minv[8];
|
||||
|
||||
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
|
||||
|
||||
for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l+1]);
|
||||
for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l+2]);
|
||||
for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l+4]);
|
||||
|
||||
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l+1]);
|
||||
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l+2]);
|
||||
for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l+4]);
|
||||
|
||||
const float max = vmaxvq_f32(maxv[0]);
|
||||
const float min = vminvq_f32(minv[0]);
|
||||
|
||||
const float magnitude = max >= fabsf(min) ? max : min;
|
||||
const float d = magnitude / -8;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
y[i].d = d;
|
||||
|
||||
for (int l = 0; l < 8; l++) {
|
||||
const float32x4_t v = vmulq_n_f32(srcv[l], id);
|
||||
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
|
||||
const int32x4_t vi = vcvtq_s32_f32(vf);
|
||||
const int32x4_t vc = vminq_s32(vi, vdupq_n_s32(15));
|
||||
|
||||
y[i].qs[2*l + 0] = vgetq_lane_s32(vc, 0) | (vgetq_lane_s32(vc, 1) << 4);
|
||||
y[i].qs[2*l + 1] = vgetq_lane_s32(vc, 2) | (vgetq_lane_s32(vc, 3) << 4);
|
||||
}
|
||||
}
|
||||
#elif defined(__AVX2__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x );
|
||||
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
||||
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
||||
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
||||
x += 32;
|
||||
|
||||
// Compute max for the block
|
||||
__m256 max = _mm256_max_ps( v0, v1 );
|
||||
__m256 maxTmp = _mm256_max_ps( v2, v3 );
|
||||
max = _mm256_max_ps( max, maxTmp );
|
||||
|
||||
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
|
||||
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
||||
const float maxScalar = _mm_cvtss_f32( max4 );
|
||||
|
||||
// Compute min for the block
|
||||
__m256 min = _mm256_min_ps( v0, v1 );
|
||||
__m256 minTmp = _mm256_min_ps( v2, v3 );
|
||||
min = _mm256_min_ps( min, minTmp );
|
||||
|
||||
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
|
||||
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
|
||||
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
|
||||
const float minScalar = _mm_cvtss_f32( min4 );
|
||||
|
||||
// Quantize these floats
|
||||
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
|
||||
const float d = magnitude / -8.0f;
|
||||
y[i].d = d;
|
||||
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
|
||||
const __m256 mul = _mm256_set1_ps( id );
|
||||
|
||||
// Apply the multiplier
|
||||
v0 = _mm256_mul_ps( v0, mul );
|
||||
v1 = _mm256_mul_ps( v1, mul );
|
||||
v2 = _mm256_mul_ps( v2, mul );
|
||||
v3 = _mm256_mul_ps( v3, mul );
|
||||
|
||||
// Round to nearest integer
|
||||
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
||||
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
||||
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
||||
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
||||
|
||||
// Convert floats to integers
|
||||
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
||||
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
||||
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
||||
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
||||
|
||||
// Convert int32 to int16
|
||||
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
|
||||
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
|
||||
// Convert int16 to int8
|
||||
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
|
||||
|
||||
// We got our precious signed bytes, but the order is now wrong
|
||||
// These AVX2 pack instructions process 16-byte pieces independently
|
||||
// The following instruction is fixing the order
|
||||
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
||||
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
||||
|
||||
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
|
||||
const __m256i off = _mm256_set1_epi8( 8 );
|
||||
i0 = _mm256_add_epi8( i0, off );
|
||||
const __m256i maxNibble = _mm256_set1_epi8( 15 );
|
||||
i0 = _mm256_min_epi8( i0, maxNibble );
|
||||
|
||||
// Compress the vector into 4 bit/value, and store
|
||||
__m128i res = packNibbles( i0 );
|
||||
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
|
||||
}
|
||||
#elif defined(__AVX__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x );
|
||||
__m256 v1 = _mm256_loadu_ps( x + 8 );
|
||||
__m256 v2 = _mm256_loadu_ps( x + 16 );
|
||||
__m256 v3 = _mm256_loadu_ps( x + 24 );
|
||||
x += 32;
|
||||
|
||||
// Compute max for the block
|
||||
__m256 max = _mm256_max_ps( v0, v1 );
|
||||
__m256 maxTmp = _mm256_max_ps( v2, v3 );
|
||||
max = _mm256_max_ps( max, maxTmp );
|
||||
|
||||
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
|
||||
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
||||
const float maxScalar = _mm_cvtss_f32( max4 );
|
||||
|
||||
// Compute min for the block
|
||||
__m256 min = _mm256_min_ps( v0, v1 );
|
||||
__m256 minTmp = _mm256_min_ps( v2, v3 );
|
||||
min = _mm256_min_ps( min, minTmp );
|
||||
|
||||
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
|
||||
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
|
||||
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
|
||||
const float minScalar = _mm_cvtss_f32( min4 );
|
||||
|
||||
// Quantize these floats
|
||||
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
|
||||
const float d = magnitude / -8.0f;
|
||||
y[i].d = d;
|
||||
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
|
||||
const __m256 mul = _mm256_set1_ps( id );
|
||||
|
||||
// Apply the multiplier
|
||||
v0 = _mm256_mul_ps( v0, mul );
|
||||
v1 = _mm256_mul_ps( v1, mul );
|
||||
v2 = _mm256_mul_ps( v2, mul );
|
||||
v3 = _mm256_mul_ps( v3, mul );
|
||||
|
||||
// Round to nearest integer
|
||||
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
||||
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
||||
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
||||
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
||||
|
||||
// Convert floats to integers
|
||||
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
||||
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
||||
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
||||
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
||||
|
||||
// Since we don't have in AVX some necessary functions,
|
||||
// we split the registers in half and call AVX2 analogs from SSE
|
||||
__m128i ni0 = _mm256_castsi256_si128( i0 );
|
||||
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
|
||||
__m128i ni2 = _mm256_castsi256_si128( i1 );
|
||||
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
|
||||
__m128i ni4 = _mm256_castsi256_si128( i2 );
|
||||
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
|
||||
__m128i ni6 = _mm256_castsi256_si128( i3 );
|
||||
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
|
||||
|
||||
// Convert int32 to int16
|
||||
ni0 = _mm_packs_epi32( ni0, ni1 );
|
||||
ni2 = _mm_packs_epi32( ni2, ni3 );
|
||||
ni4 = _mm_packs_epi32( ni4, ni5 );
|
||||
ni6 = _mm_packs_epi32( ni6, ni7 );
|
||||
// Convert int16 to int8
|
||||
ni0 = _mm_packs_epi16( ni0, ni2 );
|
||||
ni4 = _mm_packs_epi16( ni4, ni6 );
|
||||
|
||||
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
|
||||
const __m128i off = _mm_set1_epi8( 8 );
|
||||
ni0 = _mm_add_epi8( ni0, off );
|
||||
ni4 = _mm_add_epi8( ni4, off );
|
||||
const __m128i maxNibble = _mm_set1_epi8( 15 );
|
||||
ni0 = _mm_min_epi8( ni0, maxNibble );
|
||||
ni4 = _mm_min_epi8( ni4, maxNibble );
|
||||
|
||||
// Compress the vector into 4 bit/value, and store
|
||||
__m128i res = packNibbles( ni0, ni4 );
|
||||
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
|
||||
}
|
||||
#elif defined(__wasm_simd128__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float max = 0.0f;
|
||||
float min = 0.0f;
|
||||
|
||||
v128_t srcv [8];
|
||||
v128_t maxv[8];
|
||||
v128_t minv[8];
|
||||
|
||||
for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
|
||||
|
||||
for (int l = 0; l < 4; l++) maxv[2*l] = wasm_f32x4_max(srcv[2*l], srcv[2*l+1]);
|
||||
for (int l = 0; l < 2; l++) maxv[4*l] = wasm_f32x4_max(maxv[4*l], maxv[4*l+2]);
|
||||
for (int l = 0; l < 1; l++) maxv[8*l] = wasm_f32x4_max(maxv[8*l], maxv[8*l+4]);
|
||||
|
||||
for (int l = 0; l < 4; l++) minv[2*l] = wasm_f32x4_min(srcv[2*l], srcv[2*l+1]);
|
||||
for (int l = 0; l < 2; l++) minv[4*l] = wasm_f32x4_min(minv[4*l], minv[4*l+2]);
|
||||
for (int l = 0; l < 1; l++) minv[8*l] = wasm_f32x4_min(minv[8*l], minv[8*l+4]);
|
||||
|
||||
max = MAX(
|
||||
MAX(wasm_f32x4_extract_lane(maxv[0], 0), wasm_f32x4_extract_lane(maxv[0], 1)),
|
||||
MAX(wasm_f32x4_extract_lane(maxv[0], 2), wasm_f32x4_extract_lane(maxv[0], 3)));
|
||||
min = MIN(
|
||||
MIN(wasm_f32x4_extract_lane(minv[0], 0), wasm_f32x4_extract_lane(minv[0], 1)),
|
||||
MIN(wasm_f32x4_extract_lane(minv[0], 2), wasm_f32x4_extract_lane(minv[0], 3)));
|
||||
|
||||
const float magnitude = max >= fabsf(min) ? max : min;
|
||||
const float d = magnitude / -8;
|
||||
const float id = d ? 1.0/d : 0.0;
|
||||
|
||||
y[i].d = d;
|
||||
|
||||
for (int l = 0; l < 8; l++) {
|
||||
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
|
||||
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
|
||||
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
|
||||
const v128_t vc = wasm_i32x4_min(vi, wasm_i32x4_splat(15));
|
||||
|
||||
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vc, 0) | (wasm_i32x4_extract_lane(vc, 1) << 4);
|
||||
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vc, 2) | (wasm_i32x4_extract_lane(vc, 3) << 4);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// scalar
|
||||
quantize_row_q4_0_reference(x, y, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
|
||||
@ -1843,121 +1547,33 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
|
||||
}
|
||||
|
||||
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
||||
assert(QK4_0 / 16 == 0);
|
||||
assert(k % QK4_0 == 0);
|
||||
|
||||
const int nb = k / QK4_0;
|
||||
|
||||
const block_q4_0 * restrict x = vx;
|
||||
|
||||
#if defined(__AVX2__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// scale factor
|
||||
const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
|
||||
|
||||
const uint8_t * restrict pp = x[i].qs;
|
||||
|
||||
for (int l = 0; l < QK4_0; l += 32) {
|
||||
// Load 32x4-bit integers into 32x8-bit integers
|
||||
__m256i vx8 = bytes_from_nibbles_32(pp+l/2);
|
||||
|
||||
// Subtract 8 from the integers
|
||||
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
|
||||
|
||||
// Convert to 16-bit int
|
||||
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
|
||||
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
|
||||
|
||||
// Convert to 32-bit int -> float 32
|
||||
const __m256 vf[4] = {
|
||||
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
|
||||
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
|
||||
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
|
||||
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
|
||||
};
|
||||
|
||||
// Scale and store
|
||||
for (int j = 0; j < 4; j++) {
|
||||
const __m256 result = _mm256_mul_ps(vf[j], d_v);
|
||||
_mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(__ARM_NEON)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float32x4_t vd = vdupq_n_f32(x[i].d);
|
||||
|
||||
const uint8_t * restrict pp = x[i].qs;
|
||||
|
||||
for (int l = 0; l < QK4_0; l += 16) {
|
||||
// Load 16x4-bit integers into 8x8-bit integers
|
||||
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
||||
|
||||
// Expand 4-bit qs to 8-bit bytes
|
||||
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
|
||||
const uint8x8_t v1 = vshr_n_u8(v8, 4);
|
||||
|
||||
// Convert to signed 8-bit integers
|
||||
const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
|
||||
const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
|
||||
|
||||
// Subtract 8 from each byte
|
||||
const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
|
||||
const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
|
||||
|
||||
// Interleave and combine
|
||||
const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
|
||||
const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
|
||||
|
||||
const int8x16_t vq = vcombine_s8(vx_0, vx_1);
|
||||
|
||||
// convert to 2x int16x8_t
|
||||
const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
|
||||
const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
|
||||
|
||||
// convert to 4x float32x4_t
|
||||
const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
|
||||
const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
|
||||
const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
|
||||
const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
|
||||
|
||||
// Multiply by d
|
||||
const float32x4_t r0 = vmulq_f32(vf_0, vd);
|
||||
const float32x4_t r1 = vmulq_f32(vf_1, vd);
|
||||
const float32x4_t r2 = vmulq_f32(vf_2, vd);
|
||||
const float32x4_t r3 = vmulq_f32(vf_3, vd);
|
||||
|
||||
// Store
|
||||
vst1q_f32(y + i*QK4_0 + l + 0, r0);
|
||||
vst1q_f32(y + i*QK4_0 + l + 4, r1);
|
||||
vst1q_f32(y + i*QK4_0 + l + 8, r2);
|
||||
vst1q_f32(y + i*QK4_0 + l + 12, r3);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d = x[i].d;
|
||||
|
||||
const uint8_t * restrict pp = x[i].qs;
|
||||
// unpack nibbles into bytes
|
||||
uint64_t qs[QK4_0 / 8] = {0};
|
||||
|
||||
for (int l = 0; l < QK4_0; l += 2) {
|
||||
const uint8_t vi = pp[l/2];
|
||||
memcpy(qs + 0, x[i].qs, sizeof(x[i].qs));
|
||||
memcpy(qs + QK4_0 / 16, x[i].qs, sizeof(x[i].qs));
|
||||
|
||||
const int8_t vi0 = vi & 0x0F;
|
||||
const int8_t vi1 = vi >> 4;
|
||||
for (int l = 0; l < QK4_0 / 16; ++l) {
|
||||
qs[l ] = (qs[l ] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
|
||||
qs[l + QK4_0/16] = (qs[l + QK4_0/16] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
|
||||
}
|
||||
|
||||
const float v0 = (vi0 - 8)*d;
|
||||
const float v1 = (vi1 - 8)*d;
|
||||
const uint8_t * restrict qsp = (const uint8_t * restrict) qs;
|
||||
|
||||
//printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
|
||||
|
||||
y[i*QK4_0 + l + 0] = v0;
|
||||
y[i*QK4_0 + l + 1] = v1;
|
||||
|
||||
assert(!isnan(y[i*QK4_0 + l + 0]));
|
||||
assert(!isnan(y[i*QK4_0 + l + 1]));
|
||||
for (int l = 0; l < QK4_0; ++l) {
|
||||
y[i*QK4_0 + l] = (qsp[l] - 8)*d;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
|
||||
@ -2887,12 +2503,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
||||
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
||||
|
||||
// interleave
|
||||
const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
|
||||
const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
|
||||
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
|
||||
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
|
||||
|
||||
// load y
|
||||
const int8x16_t v1_0l = vld1q_s8(y0->qs);
|
||||
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
|
||||
@ -2901,21 +2511,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
// dot product into int32x4_t
|
||||
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
|
||||
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
|
||||
const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
|
||||
const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
|
||||
|
||||
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
|
||||
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
|
||||
#else
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
|
||||
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
|
||||
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
|
||||
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
|
||||
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
|
||||
|
||||
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
|
||||
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
|
||||
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
|
||||
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
|
||||
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
|
||||
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
|
||||
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
|
||||
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
|
||||
|
||||
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
|
||||
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
|
||||
|
Loading…
Reference in New Issue
Block a user