mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 18:34:36 +00:00
Retire the ggml_mul_mat() branch for transposed src0 (#500)
* Retire the ggml_mul_mat() for transposed src0 - It can always be made contiguous with ggml_cpy() - The code is now simplified - The results are deterministic in respect to num threads * SIMD-ify dequantize_row_q4_0() for ARM_NEON (#502) * Attempt to SIMD-ify dequantize_row_q4_0() for ARM_NEON * Fix dequantization - forgot to interleave the quants
This commit is contained in:
parent
502a400192
commit
ecbe466a36
609
ggml.c
609
ggml.c
@ -496,7 +496,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric
|
||||
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
|
||||
assert(k % QK == 0);
|
||||
|
||||
#if __ARM_NEON || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
|
||||
#if defined(__ARM_NEON) || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
|
||||
const int nb = k / QK;
|
||||
const size_t bs = sizeof(float) + QK/2;
|
||||
|
||||
@ -507,7 +507,6 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
|
||||
#endif
|
||||
|
||||
#if defined(__POWER9_VECTOR__)
|
||||
#if QK == 32
|
||||
const vector float v85 = vec_splats(8.5f);
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float amax = 0.0f; // absolute max
|
||||
@ -548,11 +547,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
|
||||
//memcpy(pb, pp, sizeof(pp));
|
||||
pb += bs;
|
||||
}
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#elif __ARM_NEON
|
||||
#if QK == 32
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
@ -589,11 +584,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
|
||||
memcpy(pb, pp, sizeof(pp));
|
||||
pb += bs;
|
||||
}
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#elif defined(__AVX2__)
|
||||
#if QK == 32
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x );
|
||||
@ -660,11 +651,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
|
||||
_mm_storeu_si128( ( __m128i* )pb, res );
|
||||
pb += bs;
|
||||
}
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#elif defined(__wasm_simd128__)
|
||||
#if QK == 32
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
@ -701,9 +688,6 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
|
||||
memcpy(pb, pp, sizeof(pp));
|
||||
pb += bs;
|
||||
}
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#else
|
||||
// scalar
|
||||
quantize_row_q4_0_reference(x, y, k);
|
||||
@ -771,7 +755,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
|
||||
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
|
||||
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
|
||||
|
||||
#if defined(__AVX2__) && QK % 32 == 0
|
||||
#if defined(__AVX2__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// scale factor
|
||||
const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
|
||||
@ -804,6 +788,59 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(__ARM_NEON)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d = *(const float *) (pd + i*bs);
|
||||
|
||||
const uint8_t * restrict pp = pb + i*bs;
|
||||
|
||||
const float32x4_t vd = vdupq_n_f32(d);
|
||||
|
||||
for (int l = 0; l < QK; l += 16) {
|
||||
// Load 16x4-bit integers into 8x8-bit integers
|
||||
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
||||
|
||||
// Expand 4-bit nibbles to 8-bit bytes
|
||||
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
|
||||
const uint8x8_t v1 = vshr_n_u8(v8, 4);
|
||||
|
||||
// Convert to signed 8-bit integers
|
||||
const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
|
||||
const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
|
||||
|
||||
// Subtract 8 from each byte
|
||||
const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
|
||||
const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
|
||||
|
||||
// Interleave and combine
|
||||
const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
|
||||
const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
|
||||
|
||||
const int8x16_t vq = vcombine_s8(vx_0, vx_1);
|
||||
|
||||
// convert to 2x int16x8_t
|
||||
const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
|
||||
const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
|
||||
|
||||
// convert to 4x float32x4_t
|
||||
const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
|
||||
const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
|
||||
const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
|
||||
const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
|
||||
|
||||
// Multiply by d
|
||||
const float32x4_t r0 = vmulq_f32(vf_0, vd);
|
||||
const float32x4_t r1 = vmulq_f32(vf_1, vd);
|
||||
const float32x4_t r2 = vmulq_f32(vf_2, vd);
|
||||
const float32x4_t r3 = vmulq_f32(vf_3, vd);
|
||||
|
||||
// Store
|
||||
vst1q_f32(y + i*QK + l + 0, r0);
|
||||
vst1q_f32(y + i*QK + l + 4, r1);
|
||||
vst1q_f32(y + i*QK + l + 8, r2);
|
||||
vst1q_f32(y + i*QK + l + 12, r3);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < nb; i++) {
|
||||
@ -1500,8 +1537,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
||||
|
||||
float sumf = 0.0;
|
||||
|
||||
#ifdef __ARM_NEON
|
||||
#if QK == 32
|
||||
#if defined(__ARM_NEON)
|
||||
float sum0 = 0.0f;
|
||||
float sum1 = 0.0f;
|
||||
|
||||
@ -1600,12 +1636,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
||||
}
|
||||
|
||||
sumf = sum0 + sum1;
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#elif defined(__AVX512F__)
|
||||
|
||||
#if QK == 32
|
||||
// Initialize accumulator with zeros
|
||||
__m512 acc0 = _mm512_setzero_ps();
|
||||
__m512 acc1 = _mm512_setzero_ps();
|
||||
@ -1634,11 +1665,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
||||
|
||||
// Horizontal sum of all lanes of the accumulator
|
||||
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#elif defined(__AVX2__)
|
||||
#if QK == 32
|
||||
const size_t countBlocks = nb;
|
||||
|
||||
// Initialize accumulator with zeros
|
||||
@ -1689,11 +1716,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
||||
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
||||
|
||||
sumf = _mm_cvtss_f32( res );
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#elif defined(__wasm_simd128__)
|
||||
#if QK == 32
|
||||
// wasm simd
|
||||
float sum0 = 0.0f;
|
||||
float sum1 = 0.0f;
|
||||
@ -1776,9 +1799,6 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
||||
}
|
||||
|
||||
sumf = sum0 + sum1;
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < nb; i++) {
|
||||
@ -1823,7 +1843,6 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
|
||||
float sumf = 0.0;
|
||||
|
||||
#if defined(__AVX2__)
|
||||
#if QK == 32
|
||||
// Initialize accumulator with zeros
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
// Accumulator for constant offsets
|
||||
@ -1898,9 +1917,6 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
|
||||
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
||||
|
||||
sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
|
||||
#else
|
||||
#error "not implemented for QK"
|
||||
#endif
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < nb; i++) {
|
||||
@ -2017,167 +2033,6 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
|
||||
#if defined(GGML_SIMD)
|
||||
const int np = (n & ~(GGML_F16_STEP - 1));
|
||||
|
||||
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
|
||||
|
||||
GGML_F16_VEC ax[GGML_F16_ARR];
|
||||
GGML_F16_VEC ay[GGML_F16_ARR];
|
||||
|
||||
for (int i = 0; i < np; i += GGML_F16_STEP) {
|
||||
for (int j = 0; j < GGML_F16_ARR; j++) {
|
||||
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
|
||||
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
|
||||
|
||||
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
|
||||
}
|
||||
}
|
||||
|
||||
// leftovers
|
||||
for (int i = np; i < n; ++i) {
|
||||
GGML_ASSERT(false);
|
||||
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) {
|
||||
assert(n % QK == 0);
|
||||
|
||||
const int nb = n / QK;
|
||||
const size_t bs = sizeof(float) + QK/2;
|
||||
|
||||
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
|
||||
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
|
||||
|
||||
#if __ARM_NEON
|
||||
#if QK == 32
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d0 = v*(*(const float *) (pd + i*bs));
|
||||
|
||||
const uint8_t * restrict pp = pb + i*bs;
|
||||
|
||||
const uint8x8_t m4b = vdup_n_u8(0xf);
|
||||
const int8x8_t s8b = vdup_n_s8(0x8);
|
||||
|
||||
const float32x4_t vd = vdupq_n_f32(d0);
|
||||
|
||||
for (int j = 0; j < 2; j++) {
|
||||
const uint8x8_t vx = vld1_u8(pp + j*8);
|
||||
|
||||
const int8x8_t vxl = vreinterpret_s8_u8(vand_u8(vx, m4b));
|
||||
const int8x8_t vxh = vreinterpret_s8_u8(vshr_n_u8(vx, 4));
|
||||
|
||||
// sub 8
|
||||
const int8x8_t vxls = vsub_s8(vxl, s8b);
|
||||
const int8x8_t vxhs = vsub_s8(vxh, s8b);
|
||||
|
||||
//const int8x8_t vxlt = vzip_s8(vxls, vxhs)[0];
|
||||
//const int8x8_t vxht = vzip_s8(vxls, vxhs)[1];
|
||||
const int8x8_t vxlt = vzip1_s8(vxls, vxhs);
|
||||
const int8x8_t vxht = vzip2_s8(vxls, vxhs);
|
||||
|
||||
const int8x16_t vxq = vcombine_s8(vxlt, vxht);
|
||||
|
||||
// convert to 2x int16x8_t
|
||||
const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq));
|
||||
const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq));
|
||||
|
||||
// convert to 4x float32x4_t
|
||||
const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0)));
|
||||
const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0)));
|
||||
const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1)));
|
||||
const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1)));
|
||||
|
||||
const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0);
|
||||
const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4);
|
||||
const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8);
|
||||
const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12);
|
||||
|
||||
const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
|
||||
const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
|
||||
const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
|
||||
const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
|
||||
|
||||
vst1q_f32(y + i*32 + j*16 + 0, vr0);
|
||||
vst1q_f32(y + i*32 + j*16 + 4, vr1);
|
||||
vst1q_f32(y + i*32 + j*16 + 8, vr2);
|
||||
vst1q_f32(y + i*32 + j*16 + 12, vr3);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
// scalar
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d = *(const float *) (pd + i*bs);
|
||||
|
||||
const uint8_t * restrict pp = pb + i*bs;
|
||||
|
||||
for (int l = 0; l < QK; l += 2) {
|
||||
const uint8_t vi = pp[l/2];
|
||||
|
||||
const int8_t vi0 = vi & 0xf;
|
||||
const int8_t vi1 = vi >> 4;
|
||||
|
||||
const float v0 = (vi0 - 8)*d;
|
||||
const float v1 = (vi1 - 8)*d;
|
||||
|
||||
y[i*QK + l + 0] += v0*v;
|
||||
y[i*QK + l + 1] += v1*v;
|
||||
|
||||
assert(!isnan(y[i*QK + l + 0]));
|
||||
assert(!isnan(y[i*QK + l + 1]));
|
||||
assert(!isinf(y[i*QK + l + 0]));
|
||||
assert(!isinf(y[i*QK + l + 1]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) {
|
||||
assert(n % QK == 0);
|
||||
|
||||
const int nb = n / QK;
|
||||
const size_t bs = 2*sizeof(float) + QK/2;
|
||||
|
||||
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
|
||||
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
|
||||
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const float d = *(const float *) (pd + i*bs);
|
||||
const float m = *(const float *) (pm + i*bs);
|
||||
|
||||
const uint8_t * restrict pp = pb + i*bs;
|
||||
|
||||
for (int l = 0; l < QK; l += 2) {
|
||||
const uint8_t vi = pp[l/2];
|
||||
|
||||
const uint8_t vi0 = vi & 0xf;
|
||||
const uint8_t vi1 = vi >> 4;
|
||||
|
||||
const float v0 = d*vi0 + m;
|
||||
const float v1 = d*vi1 + m;
|
||||
|
||||
y[i*QK + l + 0] += v0*v;
|
||||
y[i*QK + l + 1] += v1*v;
|
||||
|
||||
assert(!isnan(y[i*QK + l + 0]));
|
||||
assert(!isnan(y[i*QK + l + 1]));
|
||||
assert(!isinf(y[i*QK + l + 0]));
|
||||
assert(!isinf(y[i*QK + l + 1]));
|
||||
//printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
|
||||
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||
#if defined(GGML_SIMD)
|
||||
@ -2617,6 +2472,10 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
|
||||
(t0->ne[3] == t1->ne[3]);
|
||||
}
|
||||
|
||||
static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
|
||||
return tensor->nb[0] > tensor->nb[1];
|
||||
}
|
||||
|
||||
static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
@ -4010,6 +3869,7 @@ struct ggml_tensor * ggml_mul_mat(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b) {
|
||||
GGML_ASSERT(ggml_can_mul_mat(a, b));
|
||||
GGML_ASSERT(!ggml_is_transposed(a));
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
@ -5949,7 +5809,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||
assert(ne3 == ne13);
|
||||
|
||||
// TODO: we don't support permuted src0
|
||||
assert(nb00 == sizeof(float) || nb01 == sizeof(float));
|
||||
assert(nb00 == sizeof(float));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
assert(nb0 == sizeof(float));
|
||||
@ -5964,9 +5824,6 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||
|
||||
// nb01 >= nb00 - src0 is not transposed
|
||||
// compute by src0 rows
|
||||
//
|
||||
// nb00 < nb01 - src0 is transposed
|
||||
// compute by src0 columns
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||
@ -6007,42 +5864,13 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||
#endif
|
||||
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
if (nb01 >= nb00) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: fix this memset (wsize is overestimated)
|
||||
memset(params->wdata, 0, params->wsize);
|
||||
return;
|
||||
}
|
||||
|
||||
if (params->type == GGML_TASK_FINALIZE) {
|
||||
if (nb01 >= nb00) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: fix this memset (wsize is overestimated)
|
||||
//assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
|
||||
|
||||
float * const wdata = params->wdata;
|
||||
|
||||
// cols per thread
|
||||
const int dc = (ne + nth - 1)/nth;
|
||||
|
||||
// col range for this thread
|
||||
const int ic0 = dc*ith;
|
||||
const int ic1 = MIN(ic0 + dc, ne);
|
||||
|
||||
ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
|
||||
|
||||
for (int k = 1; k < nth; k++) {
|
||||
ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (nb01 >= nb00) {
|
||||
// TODO: do not support transposed src1
|
||||
assert(nb10 == sizeof(float));
|
||||
|
||||
@ -6082,53 +5910,6 @@ static void ggml_compute_forward_mul_mat_f32(
|
||||
(float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// parallelize by src1 columns using ggml_vec_mad_f32
|
||||
// each thread has its own work data
|
||||
// during FINALIZE we accumulate all work data into dst
|
||||
|
||||
// total columns in src1
|
||||
const int nc = ne10;
|
||||
|
||||
// columns per thread
|
||||
const int dc = (nc + nth - 1)/nth;
|
||||
|
||||
// column range for this thread
|
||||
const int ic0 = dc*ith;
|
||||
const int ic1 = MIN(ic0 + dc, nc);
|
||||
|
||||
// work data for thread
|
||||
const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
|
||||
float * const wdata = params->wdata;
|
||||
|
||||
for (int i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int i11 = 0; i11 < ne11; ++i11) {
|
||||
for (int ic = ic0; ic < ic1; ++ic) {
|
||||
// src1 indices
|
||||
const int i10 = ic;
|
||||
|
||||
// src0 indices
|
||||
const int i03 = i13;
|
||||
const int i02 = i12;
|
||||
const int i00 = ic;
|
||||
|
||||
// dst indices
|
||||
const int i1 = i11;
|
||||
const int i2 = i12;
|
||||
const int i3 = i13;
|
||||
|
||||
assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
|
||||
|
||||
ggml_vec_mad_f32(ne01,
|
||||
(float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0),
|
||||
(float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)),
|
||||
*(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//int64_t t1 = ggml_perf_time_us();
|
||||
//static int64_t acc = 0;
|
||||
@ -6192,7 +5973,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
// TODO: we don't support permuted src0
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
@ -6207,9 +5988,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||
|
||||
// nb01 >= nb00 - src0 is not transposed
|
||||
// compute by src0 rows
|
||||
//
|
||||
// nb00 < nb01 - src0 is transposed
|
||||
// compute by src0 columns
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||
@ -6261,7 +6039,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||
#endif
|
||||
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
if (nb01 >= nb00) {
|
||||
ggml_fp16_t * const wdata = params->wdata;
|
||||
|
||||
size_t id = 0;
|
||||
@ -6280,42 +6057,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: fix this memset (wsize is overestimated)
|
||||
memset(params->wdata, 0, params->wsize);
|
||||
return;
|
||||
}
|
||||
|
||||
if (params->type == GGML_TASK_FINALIZE) {
|
||||
if (nb01 >= nb00) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: fix this memset (wsize is overestimated)
|
||||
//assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
|
||||
|
||||
ggml_fp16_t * const wdata = params->wdata;
|
||||
|
||||
// cols per thread
|
||||
const int dc = (ne + nth - 1)/nth;
|
||||
|
||||
// col range for this thread
|
||||
const int ic0 = dc*ith;
|
||||
const int ic1 = MIN(ic0 + dc, ne);
|
||||
|
||||
for (int i = ic0; i < ic1; ++i) {
|
||||
((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]);
|
||||
}
|
||||
|
||||
for (int k = 1; k < nth; k++) {
|
||||
for (int i = ic0; i < ic1; ++i) {
|
||||
((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (nb01 >= nb00) {
|
||||
// fp16 -> half the size, so divide by 2
|
||||
// TODO: do not support transposed src1
|
||||
assert(nb10/2 == sizeof(ggml_fp16_t));
|
||||
@ -6356,55 +6101,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
||||
ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// parallelize by src1 columns using ggml_vec_mad_f16
|
||||
// each thread has its own work data
|
||||
// during FINALIZE we accumulate all work data into dst
|
||||
|
||||
// total columns in src1
|
||||
const int nc = ne10;
|
||||
|
||||
// columns per thread
|
||||
const int dc = (nc + nth - 1)/nth;
|
||||
|
||||
// column range for this thread
|
||||
const int ic0 = dc*ith;
|
||||
const int ic1 = MIN(ic0 + dc, nc);
|
||||
|
||||
// work data for thread
|
||||
const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
|
||||
ggml_fp16_t * const wdata = params->wdata;
|
||||
|
||||
for (int i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int i11 = 0; i11 < ne11; ++i11) {
|
||||
// dst indices
|
||||
const int i1 = i11;
|
||||
const int i2 = i12;
|
||||
const int i3 = i13;
|
||||
|
||||
ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
|
||||
|
||||
for (int ic = ic0; ic < ic1; ++ic) {
|
||||
// src1 indices
|
||||
const int i10 = ic;
|
||||
|
||||
// src0 indices
|
||||
const int i03 = i13;
|
||||
const int i02 = i12;
|
||||
const int i00 = ic;
|
||||
|
||||
assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
|
||||
|
||||
ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
|
||||
float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
|
||||
ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//int64_t t1 = ggml_time_us();
|
||||
//static int64_t acc = 0;
|
||||
@ -6467,7 +6163,7 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
// TODO: we don't support permuted src0
|
||||
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
|
||||
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
@ -6482,9 +6178,6 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
|
||||
|
||||
// nb01 >= nb00 - src0 is not transposed
|
||||
// compute by src0 rows
|
||||
//
|
||||
// nb00 < nb01 - src0 is transposed
|
||||
// compute by src0 columns
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||
@ -6509,9 +6202,6 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
|
||||
{
|
||||
size_t id = 0;
|
||||
for (int i01 = 0; i01 < ne01; ++i01) {
|
||||
//for (int i00 = 0; i00 < ne00; ++i00) {
|
||||
// wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
|
||||
//}
|
||||
dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
|
||||
id += ne00;
|
||||
}
|
||||
@ -6538,16 +6228,11 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
|
||||
#endif
|
||||
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
//printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth);
|
||||
if (nb01 >= nb00) {
|
||||
char * wdata = params->wdata;
|
||||
|
||||
for (int i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int i11 = 0; i11 < ne11; ++i11) {
|
||||
//for (int i10 = 0; i10 < ne10; ++i10) {
|
||||
// wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
|
||||
//}
|
||||
quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
||||
wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
|
||||
}
|
||||
@ -6557,35 +6242,10 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: fix this memset (wsize is overestimated)
|
||||
memset(params->wdata, 0, params->wsize);
|
||||
return;
|
||||
}
|
||||
|
||||
if (params->type == GGML_TASK_FINALIZE) {
|
||||
if (nb01 >= nb00) {
|
||||
return;
|
||||
}
|
||||
|
||||
float * const wdata = params->wdata;
|
||||
|
||||
// cols per thread
|
||||
const int dc = (ne + nth - 1)/nth;
|
||||
|
||||
// col range for this thread
|
||||
const int ic0 = dc*ith;
|
||||
const int ic1 = MIN(ic0 + dc, ne);
|
||||
|
||||
ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
|
||||
|
||||
for (int k = 1; k < nth; k++) {
|
||||
ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (nb01 >= nb00) {
|
||||
// TODO: do not support transposed src1
|
||||
|
||||
// parallelize by src0 rows using ggml_vec_dot_q4_0
|
||||
@ -6626,56 +6286,6 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
|
||||
ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
//printf("AAAAA ith = %d, nth = %d\n", ith, nth);
|
||||
// parallelize by src1 columns using ggml_vec_mad_q4_0
|
||||
// each thread has its own work data
|
||||
// during FINALIZE we accumulate all work data into dst
|
||||
|
||||
// total columns in src1
|
||||
const int nc = ne10;
|
||||
|
||||
// columns per thread
|
||||
const int dc = (nc + nth - 1)/nth;
|
||||
|
||||
// column range for this thread
|
||||
const int ic0 = dc*ith;
|
||||
const int ic1 = MIN(ic0 + dc, nc);
|
||||
|
||||
// work data for thread
|
||||
const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
|
||||
float * const wdata = params->wdata;
|
||||
|
||||
for (int i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int i11 = 0; i11 < ne11; ++i11) {
|
||||
// dst indices
|
||||
const int i1 = i11;
|
||||
const int i2 = i12;
|
||||
const int i3 = i13;
|
||||
|
||||
float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
|
||||
|
||||
for (int ic = ic0; ic < ic1; ++ic) {
|
||||
// src1 indices
|
||||
const int i10 = ic;
|
||||
|
||||
// src0 indices
|
||||
const int i03 = i13;
|
||||
const int i02 = i12;
|
||||
const int i00 = ic;
|
||||
|
||||
assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
|
||||
|
||||
void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
|
||||
float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
|
||||
ggml_vec_mad_q4_0(ne01, dst_row, src0_col, src1_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//int64_t t1 = ggml_time_us();
|
||||
//static int64_t acc = 0;
|
||||
@ -6738,7 +6348,7 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
// TODO: we don't support permuted src0
|
||||
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
|
||||
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
@ -6753,9 +6363,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
||||
|
||||
// nb01 >= nb00 - src0 is not transposed
|
||||
// compute by src0 rows
|
||||
//
|
||||
// nb00 < nb01 - src0 is transposed
|
||||
// compute by src0 columns
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||
@ -6780,9 +6387,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
||||
{
|
||||
size_t id = 0;
|
||||
for (int i01 = 0; i01 < ne01; ++i01) {
|
||||
//for (int i00 = 0; i00 < ne00; ++i00) {
|
||||
// wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
|
||||
//}
|
||||
dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
|
||||
id += ne00;
|
||||
}
|
||||
@ -6809,8 +6413,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
||||
#endif
|
||||
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
//printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth);
|
||||
if (nb01 >= nb00) {
|
||||
char * wdata = params->wdata;
|
||||
|
||||
for (int i13 = 0; i13 < ne13; ++i13) {
|
||||
@ -6828,35 +6430,10 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: fix this memset (wsize is overestimated)
|
||||
memset(params->wdata, 0, params->wsize);
|
||||
return;
|
||||
}
|
||||
|
||||
if (params->type == GGML_TASK_FINALIZE) {
|
||||
if (nb01 >= nb00) {
|
||||
return;
|
||||
}
|
||||
|
||||
float * const wdata = params->wdata;
|
||||
|
||||
// cols per thread
|
||||
const int dc = (ne + nth - 1)/nth;
|
||||
|
||||
// col range for this thread
|
||||
const int ic0 = dc*ith;
|
||||
const int ic1 = MIN(ic0 + dc, ne);
|
||||
|
||||
ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
|
||||
|
||||
for (int k = 1; k < nth; k++) {
|
||||
ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if (nb01 >= nb00) {
|
||||
// TODO: do not support transposed src1
|
||||
|
||||
// parallelize by src0 rows using ggml_vec_dot_q4_1
|
||||
@ -6897,56 +6474,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
||||
ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1])));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
//printf("AAAAA ith = %d, nth = %d\n", ith, nth);
|
||||
// parallelize by src1 columns using ggml_vec_mad_q4_1
|
||||
// each thread has its own work data
|
||||
// during FINALIZE we accumulate all work data into dst
|
||||
|
||||
// total columns in src1
|
||||
const int nc = ne10;
|
||||
|
||||
// columns per thread
|
||||
const int dc = (nc + nth - 1)/nth;
|
||||
|
||||
// column range for this thread
|
||||
const int ic0 = dc*ith;
|
||||
const int ic1 = MIN(ic0 + dc, nc);
|
||||
|
||||
// work data for thread
|
||||
const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
|
||||
float * const wdata = params->wdata;
|
||||
|
||||
for (int i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int i11 = 0; i11 < ne11; ++i11) {
|
||||
// dst indices
|
||||
const int i1 = i11;
|
||||
const int i2 = i12;
|
||||
const int i3 = i13;
|
||||
|
||||
float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
|
||||
|
||||
for (int ic = ic0; ic < ic1; ++ic) {
|
||||
// src1 indices
|
||||
const int i10 = ic;
|
||||
|
||||
// src0 indices
|
||||
const int i03 = i13;
|
||||
const int i02 = i12;
|
||||
const int i00 = ic;
|
||||
|
||||
assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
|
||||
|
||||
void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
|
||||
float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
|
||||
ggml_vec_mad_q4_1(ne01, dst_row, src0_col, src1_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//int64_t t1 = ggml_time_us();
|
||||
//static int64_t acc = 0;
|
||||
@ -9588,11 +9115,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||
|
||||
size_t cur = 0;
|
||||
|
||||
// TODO: better way to determine if the matrix is transposed
|
||||
if (node->src0->nb[1] < node->src0->nb[0]) {
|
||||
cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
|
||||
// TODO: overestimated by factor of x2 for FP16
|
||||
} else {
|
||||
if (node->src0->type == GGML_TYPE_F16 &&
|
||||
node->src1->type == GGML_TYPE_F32) {
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
@ -9639,7 +9161,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
} break;
|
||||
|
Loading…
Reference in New Issue
Block a user