mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 19:04:35 +00:00
Fix im2col with 32fp (#5286)
This commit is contained in:
parent
191221178f
commit
a305dba8ff
@ -8247,7 +8247,8 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
|
|||||||
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void im2col_f32_f16(const float *x, sycl::half *dst, int offset_delta,
|
template <typename T>
|
||||||
|
static void im2col_kernel(const float *x, T *dst, int offset_delta,
|
||||||
int IW, int IH, int OW, int KW, int KH,
|
int IW, int IH, int OW, int KW, int KH,
|
||||||
int pelements, int CHW, int s0, int s1, int p0,
|
int pelements, int CHW, int s0, int s1, int p0,
|
||||||
int p1, int d0, int d1,
|
int p1, int d0, int d1,
|
||||||
@ -11019,7 +11020,8 @@ static void soft_max_f32_sycl(const float *x, const float *y, float *dst,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
|
template <typename T>
|
||||||
|
static void im2col_sycl(const float *x, T *dst, int IW, int IH,
|
||||||
int OW, int OH, int KW, int KH, int IC,
|
int OW, int OH, int KW, int KH, int IC,
|
||||||
int offset_delta, int s0, int s1, int p0,
|
int offset_delta, int s0, int s1, int p0,
|
||||||
int p1, int d0, int d1,
|
int p1, int d0, int d1,
|
||||||
@ -11036,7 +11038,7 @@ static void im2col_f32_f16_sycl(const float *x, sycl::half *dst, int IW, int IH,
|
|||||||
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
|
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
|
||||||
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
|
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
im2col_f32_f16(x, dst, offset_delta, IW, IH, OW, KW, KH,
|
im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
|
||||||
parallel_elements, (IC * KH * KW), s0, s1, p0,
|
parallel_elements, (IC * KH * KW), s0, s1, p0,
|
||||||
p1, d0, d1, item_ct1);
|
p1, d0, d1, item_ct1);
|
||||||
});
|
});
|
||||||
@ -12424,7 +12426,7 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
|
|||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
||||||
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
||||||
@ -12447,8 +12449,11 @@ inline void ggml_sycl_op_im2col(const ggml_tensor *src0,
|
|||||||
|
|
||||||
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
||||||
|
|
||||||
im2col_f32_f16_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH,
|
if (dst->type == GGML_TYPE_F16) {
|
||||||
IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
||||||
|
} else {
|
||||||
|
im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
||||||
|
}
|
||||||
|
|
||||||
(void) src0;
|
(void) src0;
|
||||||
(void) src0_dd;
|
(void) src0_dd;
|
||||||
|
Loading…
Reference in New Issue
Block a user