mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
sync : ggml (fix im2col) (#4591)
* cuda : fix im2col_f32_f16 (ggml/#658)
  ggml-ci
* ggml-alloc : fix ggml_tallocr_is_own
---------
Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
parent
a55876955b
commit
ba66175132
@ -72,7 +72,7 @@ static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * t
|
|||||||
|
|
||||||
// check if a tensor is allocated by this buffer
|
// check if a tensor is allocated by this buffer
|
||||||
static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
|
static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
|
||||||
return tensor->buffer == alloc->buffer;
|
return tensor->buffer == alloc->buffer && (!tensor->view_src || tensor->view_src->buffer == alloc->buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_is_view(struct ggml_tensor * t) {
|
static bool ggml_is_view(struct ggml_tensor * t) {
|
||||||
|
@ -5273,17 +5273,17 @@ static __global__ void im2col_f32_f16(
|
|||||||
const int ky = (i - kd) / OW;
|
const int ky = (i - kd) / OW;
|
||||||
const int ix = i % OW;
|
const int ix = i % OW;
|
||||||
|
|
||||||
const int iiw = ix * s0 + kx * d0 - p0;
|
const int64_t iiw = ix * s0 + kx * d0 - p0;
|
||||||
const int iih = blockIdx.y * s1 + ky * d1 - p1;
|
const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
|
||||||
|
|
||||||
const int offset_dst =
|
const int64_t offset_dst =
|
||||||
(blockIdx.y * OW + ix) * CHW +
|
(blockIdx.y * OW + ix) * CHW +
|
||||||
(blockIdx.z * (KW * KH) + ky * KW + kx);
|
(blockIdx.z * (KW * KH) + ky * KW + kx);
|
||||||
|
|
||||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||||
dst[offset_dst] = __float2half(0.0f);
|
dst[offset_dst] = __float2half(0.0f);
|
||||||
} else {
|
} else {
|
||||||
const int offset_src = blockIdx.z * offset_delta;
|
const int64_t offset_src = blockIdx.z * offset_delta;
|
||||||
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
|
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user