sycl : fix grid type

This commit is contained in:
Georgi Gerganov 2024-03-11 15:17:08 +02:00
parent cb5a702e9c
commit 76be02aebc
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@ -4891,7 +4891,7 @@ static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restr
template<typename dst_t> template<typename dst_t>
static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy, static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1, const sycl::nd_item<3> &item_ct1,
const uint64_t *iq1s_grid, const uint32_t *iq1s_grid,
const uint8_t *ksigns_iq2xs, const uint8_t *ksigns_iq2xs,
const uint8_t *kmask_iq2xs) { const uint8_t *kmask_iq2xs) {
const int i = item_ct1.get_group(2); const int i = item_ct1.get_group(2);
@ -7806,7 +7806,7 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
static __dpct_inline__ float static __dpct_inline__ float
vec_dot_iq1_s_q8_1(const void *__restrict__ vbq, vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs, const block_q8_1 *__restrict__ bq8_1, const int &iqs,
const uint64_t *iq1s_grid, const uint64_t *ksigns64) { const uint32_t *iq1s_grid, const uint64_t *ksigns64) {
#if QK_K == 256 #if QK_K == 256
const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
@ -8646,7 +8646,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
template <int qk, int qi, typename block_q_t, int vdr> template <int qk, int qi, typename block_q_t, int vdr>
static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1, const sycl::nd_item<3> &item_ct1,
const uint64_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) { const uint32_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1); item_ct1.get_local_id(1);