cuda : poc for norm quants (only -b 1 works)

This commit is contained in:
Georgi Gerganov 2023-08-30 21:39:49 +03:00
parent df54d2f1d4
commit 8c2b881281

View File

@ -163,21 +163,31 @@ typedef float2 dfloat2;
#endif //GGML_CUDA_F16
static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
//const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
x8 += sizeof(int) * i32;
int x32 = 0;
x32 |= x16[0] << 0;
x32 |= x16[1] << 16;
//x32 |= x16[0] << 0;
//x32 |= x16[1] << 16;
x32 |= ((uint32_t)(x8[0])) << 0;
x32 |= ((uint32_t)(x8[1])) << 8;
x32 |= ((uint32_t)(x8[2])) << 16;
x32 |= ((uint32_t)(x8[3])) << 24;
return x32;
}
static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
//const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
x8 += sizeof(int) * i32;
int x32 = 0;
x32 |= x16[0] << 0;
x32 |= x16[1] << 16;
//x32 |= x16[0] << 0;
//x32 |= x16[1] << 16;
x32 |= ((uint32_t)(x8[0])) << 0;
x32 |= ((uint32_t)(x8[1])) << 8;
x32 |= ((uint32_t)(x8[2])) << 16;
x32 |= ((uint32_t)(x8[3])) << 24;
return x32;
}
@ -2093,7 +2103,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
// x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
//x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = Q4_0D(bxi->d);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@ -2109,7 +2119,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = Q4_0D(bxi->d);
}
}
@ -2143,7 +2153,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
#pragma unroll
for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i);
v[i] = get_int_from_uint8(bq4_1->qs, iqs + i);
u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1);
}
@ -2151,7 +2161,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
const float d = Q4_1D(bq4_1->dm);
const float m = Q4_1M(bq4_1->dm);
const float2 dm = {d, m};
const half2 dm = {d, m};
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, dm, bq8_1->ds);
}
@ -2189,7 +2199,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@ -2205,7 +2215,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd].x = Q4_1D(bxi->dm);
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd].y = Q4_1M(bxi->dm);
}
}
@ -2353,16 +2364,16 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
#pragma unroll
for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i);
vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i));
vl[i] = get_int_from_uint8(bq5_1->qs, iqs + i);
vh[i] = get_int_from_uint8(bq5_1->qh, 0) >> (4 * (iqs + i));
u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1);
}
const float d = Q5_1D(bq4_1->dm);
const float m = Q5_1M(bq4_1->dm);
const half d = Q5_1D(bq5_1->dm);
const half m = Q5_1M(bq5_1->dm);
const float2 dm = {d, m};
const half2 dm = {d, m};
return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, dm, bq8_1->ds);
}
@ -2400,8 +2411,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
const int ql = get_int_from_uint8(bxi->qs, kqsx);
const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_1));
int qs0 = (ql >> 0) & 0x0F0F0F0F;
qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
@ -2433,7 +2444,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd].x = Q5_1D(bxi->dm);
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd].y = Q5_1M(bxi->dm);
}
}