remove wrong assert in norm

WA for permute(0,1,3,2) mul_mat
ggml-ci
This commit is contained in:
Meng, Hengyu 2024-10-25 07:41:48 +00:00
parent 958367bf53
commit c263ca767b
2 changed files with 7 additions and 3 deletions

View File

@ -5173,6 +5173,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (op->op == GGML_OP_MUL_MAT) { if (op->op == GGML_OP_MUL_MAT) {
a = op->src[0]; a = op->src[0];
b = op->src[1]; b = op->src[1];
if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
// TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
return false;
}
} else { } else {
a = op->src[2]; a = op->src[2];
b = op->src[1]; b = op->src[1];

View File

@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
const int nthreads = item_ct1.get_local_range(2); const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE; const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
sycl::float2 mean_var = sycl::float2(0.f, 0.f); sycl::float2 mean_var = sycl::float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
int end = start + group_size; int end = start + group_size;
const int nthreads = item_ct1.get_local_range(2); const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE; const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
start += item_ct1.get_local_id(2); start += item_ct1.get_local_id(2);
int nreduce = nwarps / WARP_SIZE; int nreduce = nwarps / WARP_SIZE;
@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
const int tid = item_ct1.get_local_id(2); const int tid = item_ct1.get_local_id(2);
const int nthreads = item_ct1.get_local_range(2); const int nthreads = item_ct1.get_local_range(2);
const int nwarps = nthreads / WARP_SIZE; const int nwarps = nthreads / WARP_SIZE;
assert(nwarps % WARP_SIZE == 0);
float tmp = 0.0f; // partial sum for thread in warp float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
} }
else { else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size); const sycl::range<3> block_dims(1, 1, work_group_size);
/* /*
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
} }
else { else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size); const sycl::range<3> block_dims(1, 1, work_group_size);
/* /*
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
} }
else { else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
const sycl::range<3> block_dims(1, 1, work_group_size); const sycl::range<3> block_dims(1, 1, work_group_size);
/* /*
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed DPCT1049:19: The work-group size passed to the SYCL kernel may exceed