mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
Update ggml_sycl_op_mul_mat_vec_q (#5502)
* Update ggml_sycl_op_mul_mat_vec_q * Apply suggestions from code review Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> * revert suggestion on macro * fix bug * Add quant type GGML_TYPE_IQ1_S to unsupported * fix format --------- Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
This commit is contained in:
parent
633782b8d9
commit
b9111bd209
220
ggml-sycl.cpp
220
ggml-sycl.cpp
@ -9188,7 +9188,9 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
template <int qk, int qi, typename block_q_t, int vdr,
|
||||||
|
vec_dot_q_sycl_t vec_dot_q_sycl>
|
||||||
|
static void mul_mat_vec_q_sycl_submitter(const void *vx, const void *vy,
|
||||||
float *dst, const int ncols,
|
float *dst, const int ncols,
|
||||||
const int nrows,
|
const int nrows,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
@ -9197,164 +9199,10 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
|||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||||
stream->parallel_for(
|
stream->parallel_for(
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
sycl::nd_range<3>(block_nums * block_dims, block_dims), [=
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
||||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ,
|
mul_mat_vec_q<qk, qi, block_q_t, vdr, vec_dot_q_sycl>(
|
||||||
vec_dot_q4_0_q8_1>(vx, vy, dst, ncols, nrows,
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
|
||||||
float *dst, const int ncols,
|
|
||||||
const int nrows,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % QK4_1 == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
|
||||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ,
|
|
||||||
vec_dot_q4_1_q8_1>(vx, vy, dst, ncols, nrows,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
|
||||||
float *dst, const int ncols,
|
|
||||||
const int nrows,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % QK5_0 == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
|
||||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ,
|
|
||||||
vec_dot_q5_0_q8_1>(vx, vy, dst, ncols, nrows,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
|
||||||
float *dst, const int ncols,
|
|
||||||
const int nrows,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % QK5_1 == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
|
||||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ,
|
|
||||||
vec_dot_q5_1_q8_1>(vx, vy, dst, ncols, nrows,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
|
||||||
float *dst, const int ncols,
|
|
||||||
const int nrows,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % QK8_0 == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
|
||||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ,
|
|
||||||
vec_dot_q8_0_q8_1>(vx, vy, dst, ncols, nrows,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
|
||||||
float *dst, const int ncols,
|
|
||||||
const int nrows,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
|
||||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ,
|
|
||||||
vec_dot_q2_K_q8_1>(vx, vy, dst, ncols, nrows,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
|
||||||
float *dst, const int ncols,
|
|
||||||
const int nrows,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
|
||||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ,
|
|
||||||
vec_dot_q3_K_q8_1>(vx, vy, dst, ncols, nrows,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
|
||||||
float *dst, const int ncols,
|
|
||||||
const int nrows,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
|
||||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ,
|
|
||||||
vec_dot_q4_K_q8_1>(vx, vy, dst, ncols, nrows,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|
||||||
float *dst, const int ncols,
|
|
||||||
const int nrows,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
|
||||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ,
|
|
||||||
vec_dot_q5_K_q8_1>(vx, vy, dst, ncols, nrows,
|
|
||||||
item_ct1);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
|
||||||
float *dst, const int ncols,
|
|
||||||
const int nrows,
|
|
||||||
dpct::queue_ptr stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
|
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
|
||||||
stream->parallel_for(
|
|
||||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
|
|
||||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ,
|
|
||||||
vec_dot_q6_K_q8_1>(vx, vy, dst, ncols, nrows,
|
|
||||||
item_ct1);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -12095,36 +11943,62 @@ inline void ggml_sycl_op_mul_mat_vec_q(
|
|||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t row_diff = row_high - row_low;
|
const int64_t row_diff = row_high - row_low;
|
||||||
|
|
||||||
|
// TODO: support these quantization types
|
||||||
|
GGML_ASSERT(!(src0->type == GGML_TYPE_IQ2_XXS ||
|
||||||
|
src0->type == GGML_TYPE_IQ2_XS ||
|
||||||
|
src0->type == GGML_TYPE_IQ3_XXS ||
|
||||||
|
src0->type == GGML_TYPE_IQ1_S));
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK4_0, QI4_0, block_q4_0,
|
||||||
|
VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK4_1, QI4_1, block_q4_1,
|
||||||
|
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK5_0, QI5_0, block_q5_0,
|
||||||
|
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK5_1, QI5_1, block_q5_1,
|
||||||
|
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK8_0, QI8_0, block_q8_0,
|
||||||
|
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK_K, QI2_K, block_q2_K,
|
||||||
|
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK_K, QI3_K, block_q3_K,
|
||||||
|
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK_K, QI4_K, block_q4_K,
|
||||||
|
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK_K, QI5_K, block_q5_K,
|
||||||
|
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_sycl_submitter<QK_K, QI6_K, block_q6_K,
|
||||||
|
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
||||||
|
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
@ -12145,7 +12019,7 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(
|
|||||||
const int64_t src1_ncols, const int64_t src1_padded_row_size,
|
const int64_t src1_ncols, const int64_t src1_padded_row_size,
|
||||||
const dpct::queue_ptr &stream) {
|
const dpct::queue_ptr &stream) {
|
||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS
|
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||||
|
|
||||||
const int64_t row_diff = row_high - row_low;
|
const int64_t row_diff = row_high - row_low;
|
||||||
|
|
||||||
@ -15093,6 +14967,12 @@ static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_ten
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (a->type == GGML_TYPE_IQ1_S) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (a->type == GGML_TYPE_IQ3_XXS) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
if (a->type == GGML_TYPE_IQ2_XXS) {
|
if (a->type == GGML_TYPE_IQ2_XXS) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user