[SYCL] fix scratch size of softmax (#8642)

This commit is contained in:
luoyu-intel 2024-07-23 07:43:28 +00:00 committed by GitHub
parent 081fe431aa
commit 063d99ad11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -152,7 +152,8 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
const sycl::range<3> block_dims(1, 1, nth); const sycl::range<3> block_dims(1, 1, nth);
const sycl::range<3> block_nums(1, 1, nrows_x); const sycl::range<3> block_nums(1, 1, nrows_x);
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE); const size_t n_val_tmp = nth / WARP_SIZE;
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
const uint32_t n_head_kv = nrows_x/nrows_y; const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));