From fadde6713506d9e6c124f5680ab8c7abebe31837 Mon Sep 17 00:00:00 2001 From: AidanBeltonS Date: Wed, 3 Jul 2024 02:55:34 +0100 Subject: [PATCH] Dequant improvements rebase (#8255) * Single load for half2 * Store scales in local mem * Vec load quantized values --- ggml/src/ggml-sycl/common.hpp | 6 ++++++ ggml/src/ggml-sycl/convert.cpp | 7 +++++-- ggml/src/ggml-sycl/dequantize.hpp | 30 +++++++++++++++++++----------- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index dfd4a7c2c..476d847ca 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -351,4 +351,10 @@ static __dpct_inline__ float warp_reduce_max(float x, return x; } +// Helper for vec loading aligned data +template +inline sycl::vec vec_aligned_load(const Tp* aligned_ptr) { + return *reinterpret_cast*>(aligned_ptr); +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index ce9de2b42..a15271b51 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -152,12 +152,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { - dequantize_block_q4_K(vx, y, item_ct1); + dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1); }); + }); } } diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index b6080d83a..ed8ad098b 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri #if QK_K == 256 static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { - d = q[j] & 63; m = q[j + 4] & 63; + d = q[j] & 63; + m = q[j + 4] & 63; } else { d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); @@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8 template static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy, - const sycl::nd_item<3> &item_ct1) { + uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { const block_q4_K * x = (const block_q4_K *) vx; const int i = item_ct1.get_group(2); @@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri dst_t * y = yy + i*QK_K + 64*il + n*ir; - const float dall = x[i].dm[0]; - const float dmin = x[i].dm[1]; + const sycl::half2 dm = x[i].dm; + const float dall = dm[0]; + const float dmin = dm[1]; - const uint8_t * q = x[i].qs + 32*il + n*ir; + if (tid < 12) + scales_local[tid] = x[i].scales[tid]; + item_ct1.barrier(sycl::access::fence_space::local_space); uint8_t sc, m; - get_scale_min_k4(is + 0, x[i].scales, sc, m); - const float d1 = dall * sc; const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[i].scales, sc, m); - const float d2 = dall * sc; const float m2 = dmin * m; + get_scale_min_k4(is + 0, scales_local, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + get_scale_min_k4(is + 1, scales_local, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + sycl::vec q_vec = vec_aligned_load(x[i].qs + 32*il + n*ir); for (int l = 0; l < n; ++l) { - y[l + 0] = d1 * (q[l] & 0xF) - m1; - y[l +32] = d2 * (q[l] >> 4) - m2; + y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; + y[l +32] = d2 * (q_vec[l] >> 4) - m2; } #else const int tid = item_ct1.get_local_id(2);