mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 19:04:35 +00:00
[SYCL] Fix DMMV dequantization (#9279)
Fixed dmmv dequant for ncols== GGML_SYCL_DMMV_X
This commit is contained in:
parent
c8671ae282
commit
5910ea9427
@ -76,8 +76,8 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
|
||||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
for (int mask = mask_start; mask > 0; mask >>= 1) {
|
||||||
tmp +=
|
tmp +=
|
||||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user