mirror of https://github.com/ggerganov/llama.cpp.git
add some new ops, fix some operators and add batch operations to certain operators. (ggml/747)

* cuda: fix group_norm
* cuda: add batch inference support for ggml_pad/ggml_upscale
* add ggml_arange
* add ggml_timestep_embedding
* update ggml_arange/ggml_timestep_embedding tests
* cuda: fix im2col
* add ggml_arange/ggml_timestep_embedding support for metal backend
* fix some bugs
* fix some bugs
* Update ggml.h (Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>)
* Update ggml-cuda.cu (Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>)
* Update ggml-metal.m (Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>)
* Update ggml-metal.m (Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>)
* Update ggml-metal.metal (Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>)
* modify according to the review comments
* ggml : fix compile warnings + code style
* ggml : normalize compute_forward calls + fix seg fault in debug
* minor

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
parent 82f3e668ad
commit 7d43c585dc
ggml-cuda.cu (227 changed lines)
@@ -616,6 +616,8 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
#define CUDA_UPSCALE_BLOCK_SIZE 256
#define CUDA_CONCAT_BLOCK_SIZE 256
#define CUDA_PAD_BLOCK_SIZE 256
+ #define CUDA_ARANGE_BLOCK_SIZE 256
+ #define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
#define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256
#define CUDA_POOL2D_BLOCK_SIZE 256

@@ -990,17 +992,21 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
- dst[offset_dst] = x[offset_src];
+ dst[offset_dst] = x[offset_src];
} else {
int offset_src =
nidx +
blockIdx.y * ne0 +
(blockIdx.z - ne02) * ne0 * gridDim.y;
- dst[offset_dst] = y[offset_src];
+ dst[offset_dst] = y[offset_src];
}
}

- static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
+ static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
+ // blockIdx.z: idx of ne02*ne03
+ // blockIdx.y: idx of ne01*scale_factor, aka ne1
+ // blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE
+ // ne00xne01: ne00 * ne01
int ne0 = ne00 * scale_factor;
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {

@@ -1012,7 +1018,7 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
int offset_src =
i00 +
i01 * ne00 +
- blockIdx.z * nb02;
+ blockIdx.z * ne00xne01;
int offset_dst =
nidx +
blockIdx.y * ne0 +

@@ -1020,7 +1026,10 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
dst[offset_dst] = x[offset_src];
}

- static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+ static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+ // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+ // blockIdx.y: idx of ne1
+ // blockIDx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;

@@ -1031,19 +1040,53 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
- if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
+ if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * ne00 * ne01;
- dst[offset_dst] = x[offset_src];
+ dst[offset_dst] = x[offset_src];
} else {
dst[offset_dst] = 0.0f;
}
}

+ static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+ // blockIDx.x: idx of ne0 / BLOCK_SIZE
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+ dst[nidx] = start + step * nidx;
+ }
+
+ static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
+ // blockIDx.y: idx of timesteps->ne[0]
+ // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
+ int i = blockIdx.y;
+ int j = threadIdx.x + blockIdx.x * blockDim.x;
+ float * embed_data = (float *)((char *)dst + i*nb1);
+
+ if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+ embed_data[dim] = 0.f;
+ }
+
+ int half = dim / 2;
+ if (j >= half) {
+ return;
+ }
+
+ float timestep = timesteps[i];
+ float freq = (float)expf(-logf(max_period) * j / half);
+ float arg = timestep * freq;
+ embed_data[j] = cosf(arg);
+ embed_data[j + half] = sinf(arg);
+ }
+
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+ // blockIdx.x: num_groups idx
+ // threadIdx.x: block_size idx
int start = blockIdx.x * group_size;
int end = start + group_size;

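For reference (this note is not part of the diff): the embedding that timestep_embedding_f32 above computes, and that the Metal and CPU implementations later in this commit mirror, is the usual sinusoidal timestep embedding referenced from ggml.h. With half = dim / 2, for each timestep t_i and each j in [0, half),

\[
f_j = \exp\!\left(-\frac{\ln(\mathrm{max\_period}) \cdot j}{\mathrm{half}}\right),
\qquad
\mathrm{embed}[i][j] = \cos(t_i f_j),
\qquad
\mathrm{embed}[i][j + \mathrm{half}] = \sin(t_i f_j),
\]

and when dim is odd the extra trailing element embed[i][dim] is set to 0.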
@@ -6448,7 +6491,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
- const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= ne) {
return;

@@ -6456,17 +6499,17 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,

// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets
- const int i03 = i/(ne00 * ne01 * ne02);
- const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
- const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
- const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
- const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;

- const int i13 = i/(ne10 * ne11 * ne12);
- const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
- const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
- const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;

cpy_1(cx + x_offset, cdst + dst_offset);
}

@@ -6956,23 +6999,23 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,

template <typename T>
static __global__ void im2col_kernel(
- const float * x, T * dst, int batch_offset,
- int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
+ const float * x, T * dst, int64_t batch_offset,
+ int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
- const int i = threadIdx.x + blockIdx.x * blockDim.x;
+ const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
return;
}

- const int ksize = OW * (KH > 1 ? KW : 1);
- const int kx = i / ksize;
- const int kd = kx * ksize;
- const int ky = (i - kd) / OW;
- const int ix = i % OW;
+ const int64_t ksize = OW * (KH > 1 ? KW : 1);
+ const int64_t kx = i / ksize;
+ const int64_t kd = kx * ksize;
+ const int64_t ky = (i - kd) / OW;
+ const int64_t ix = i % OW;

- const int oh = blockIdx.y;
- const int batch = blockIdx.z / IC;
- const int ic = blockIdx.z % IC;
+ const int64_t oh = blockIdx.y;
+ const int64_t batch = blockIdx.z / IC;
+ const int64_t ic = blockIdx.z % IC;

const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = oh * s1 + ky * d1 - p1;

@@ -7298,19 +7341,33 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, const
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
}

- static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
+ static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
+ const int scale_factor, cudaStream_t stream) {
int ne0 = (ne00 * scale_factor);
int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
- dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
+ dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
}

static void pad_f32_cuda(const float * x, float * dst,
- const int ne00, const int ne01, const int ne02,
- const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
+ const int ne00, const int ne01, const int ne02, const int ne03,
+ const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
- dim3 gridDim(num_blocks, ne1, ne2);
- pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
+ dim3 gridDim(num_blocks, ne1, ne2*ne3);
+ pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}
+
+ static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
+ int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
+ arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
+ }
+
+ static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
+ const int dim, const int max_period, cudaStream_t stream) {
+ int half_ceil = (dim + 1) / 2;
+ int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, ne00, 1);
+ timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
+ }

static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {

@@ -8443,8 +8500,8 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *

template <typename T>
static void im2col_cuda(const float* x, T* dst,
- int IW, int IH, int OW, int OH, int KW, int KH, int IC,
- int batch, int batch_offset, int offset_delta,
+ int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+ int64_t batch, int64_t batch_offset, int64_t offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;

@@ -9123,7 +9180,7 @@ static void ggml_cuda_op_group_norm(

int num_groups = dst->op_params[0];
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
- group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
+ group_norm_f32_cuda(src0_dd, dst_dd, num_groups * src0->ne[3], group_size, ggml_nelements(src0), main_stream);

(void) src1;
(void) dst;

@@ -9156,7 +9213,7 @@ static void ggml_cuda_op_upscale(

const int scale_factor = dst->op_params[0];

- upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+ upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, main_stream);

(void) src1;
(void) dst;

@@ -9172,8 +9229,49 @@ static void ggml_cuda_op_pad(
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors

pad_f32_cuda(src0_dd, dst_dd,
- src0->ne[0], src0->ne[1], src0->ne[2],
- dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+ }
+
+ static void ggml_cuda_op_arange(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ float start;
+ float stop;
+ float step;
+ memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
+ memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
+ memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
+
+ int64_t steps = (int64_t)ceil((stop - start) / step);
+ GGML_ASSERT(ggml_nelements(dst) == steps);
+
+ arange_f32_cuda(dst_dd, dst->ne[0], start, step, main_stream);
+
+ (void) src0;
+ (void) src1;
+ (void) src0_dd;
+ (void) src1_dd;
+ }
+
+ static void ggml_cuda_op_timestep_embedding(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
+
+ timestep_embedding_f32_cuda(src0_dd, dst_dd, src0->ne[0], dst->nb[1], dim, max_period, main_stream);

(void) src1;
(void) dst;

@@ -10458,6 +10556,45 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
}

+ static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+
+ const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;
+
+ // dd = data device
+ float * src0_ddf = nullptr;
+ float * src1_ddf = nullptr;
+ float * dst_ddf = nullptr;
+
+ cuda_pool_alloc<float> dst_f;
+
+ ggml_cuda_set_device(g_main_device);
+ cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+ if (dst_on_device) {
+ dst_ddf = (float *) dst_extra->data_device[g_main_device];
+ } else {
+ dst_ddf = dst_f.alloc(ggml_nelements(dst));
+ }
+
+ // do the computation
+ ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ // copy dst to host if necessary
+ if (!dst_on_device) {
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
+ }
+
+ if (dst->backend == GGML_BACKEND_TYPE_CPU) {
+ CUDA_CHECK(cudaDeviceSynchronize());
+ }
+ }
+
+ static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_timestep_embedding);
+ }
+
static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
}

@@ -11358,6 +11495,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_OP_PAD:
func = ggml_cuda_pad;
break;
+ case GGML_OP_ARANGE:
+ func = ggml_cuda_arange;
+ break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ func = ggml_cuda_timestep_embedding;
+ break;
case GGML_OP_LEAKY_RELU:
func = ggml_cuda_leaky_relu;
break;

@@ -12253,6 +12396,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_GROUP_NORM:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
+ case GGML_OP_ARANGE:
+ case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
return true;
default:
ggml-metal.m (62 changed lines)
@@ -163,6 +163,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
+ GGML_METAL_KERNEL_TYPE_ARANGE_F32,
+ GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,

@@ -569,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);

@@ -697,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
return false;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
+ case GGML_OP_ARANGE:
+ case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
return true;

@@ -1091,7 +1097,8 @@ static bool ggml_metal_graph_compute(
{
GGML_ASSERT(ggml_is_contiguous(src0));

- const float scale = *(const float *) dst->op_params;
+ float scale;
+ memcpy(&scale, dst->op_params, sizeof(scale));

int64_t n = ggml_nelements(dst);

@@ -1250,11 +1257,15 @@ static bool ggml_metal_graph_compute(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
}

- const float scale = ((float *) dst->op_params)[0];
- const float max_bias = ((float *) dst->op_params)[1];
+ float scale;
+ float max_bias;
+
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));

const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];

const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

@@ -2086,6 +2097,7 @@ static bool ggml_metal_graph_compute(

//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];

float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

@@ -2300,6 +2312,50 @@ static bool ggml_metal_graph_compute(

[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
+ case GGML_OP_ARANGE:
+ {
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ float start;
+ float step;
+
+ memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
+ memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
+ [encoder setBytes:&start length:sizeof(start) atIndex:2];
+ [encoder setBytes:&step length:sizeof(step) atIndex:3];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
+
+ const int half = dim / 2;
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
+ [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+
+ const int nth = MIN(1024, half);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
ggml-metal.metal

@@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32(
}
}

+ kernel void kernel_arange_f32(
+ device char * dst,
+ constant int64_t & ne0,
+ constant float & start,
+ constant float & step,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ device float * dst_ptr = (device float *) dst;
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = start + step * i0;
+ }
+ }
+
+ kernel void kernel_timestep_embedding_f32(
+ device const char * src0,
+ device char * dst,
+ constant uint64_t & nb1,
+ constant int & dim,
+ constant int & max_period,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ int i = tgpig.x;
+ device float * embed_data = (device float *)(dst + i*nb1);
+
+ int half_ = dim / 2;
+ for (int j = tpitg.x; j < half_; j += ntg.x) {
+ float timestep = ((device float *)src0)[i];
+ float freq = (float)exp(-log((float)max_period) * j / half_);
+ float arg = timestep * freq;
+ embed_data[j ] = cos(arg);
+ embed_data[j + half_] = sin(arg);
+ }
+
+ if (dim % 2 != 0 && tpitg.x == 0) {
+ embed_data[dim] = 0.f;
+ }
+ }
+
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
device const float * x,
ggml.c (207 changed lines)
@@ -1822,6 +1822,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"POOL_2D",
"UPSCALE",
"PAD",
+ "ARANGE",
+ "TIMESTEP_EMBEDDING",
"ARGSORT",
"LEAKY_RELU",

@@ -1850,7 +1852,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};

- static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",

@@ -1908,6 +1910,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"pool_2d(x)",
"upscale(x)",
"pad(x)",
+ "arange(start, stop, step)",
+ "timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"leaky_relu(x)",

@@ -1936,7 +1940,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};

- static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
+ static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -2895,11 +2899,21 @@ static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_
return ((const int32_t *)(tensor->op_params))[i];
}

+ static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
+ assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
+ return ((const float *)(tensor->op_params))[i];
+ }
+
static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
((int32_t *)(tensor->op_params))[i] = value;
}

+ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) {
+ assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
+ ((float *)(tensor->op_params))[i] = value;
+ }
+
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
memset(tensor->data, 0, ggml_nbytes(tensor));
return tensor;

@@ -5898,6 +5912,55 @@ struct ggml_tensor * ggml_upscale(
return ggml_upscale_impl(ctx, a, scale_factor);
}

+ struct ggml_tensor * ggml_arange(
+ struct ggml_context * ctx,
+ float start,
+ float stop,
+ float step) {
+
+ GGML_ASSERT(stop > start);
+
+ const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+ struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
+
+ result->op = GGML_OP_ARANGE;
+ ggml_set_op_params_f32(result, 0, start);
+ ggml_set_op_params_f32(result, 1, stop);
+ ggml_set_op_params_f32(result, 2, step);
+
+ return result;
+ }
+
+ struct ggml_tensor * ggml_timestep_embedding(
+ struct ggml_context * ctx,
+ struct ggml_tensor * timesteps,
+ int dim,
+ int max_period) {
+ bool is_node = false;
+
+ if (timesteps->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ int actual_dim = dim;
+ if (dim % 2 != 0) {
+ actual_dim = dim + 1;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
+
+ result->op = GGML_OP_TIMESTEP_EMBEDDING;
+ ggml_set_op_params_i32(result, 0, dim);
+ ggml_set_op_params_i32(result, 1, max_period);
+
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = timesteps;
+
+ return result;
+ }
+
// ggml_argsort

struct ggml_tensor * ggml_argsort(
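To make the shape rules in the two constructors above concrete (this note is not part of the diff): ggml_arange creates a 1-D F32 tensor of ceilf((stop - start) / step) elements, so ggml_arange(ctx, 0.0f, 10.0f, 2.0f) yields 5 elements (0, 2, 4, 6, 8). ggml_timestep_embedding rounds an odd dim up to the next even number for the output, so dim = 15 over N timesteps gives a tensor with ne[0] = 16 and ne[1] = N, and the forward pass writes 0 into the extra slot at index dim of each row.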
@@ -10231,7 +10294,7 @@ static void ggml_compute_forward_group_norm_f32(
int n_channels = src0->ne[2];
int n_groups = dst->op_params[0];
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
- for (int i = ith; i < n_groups; i+=nth) {
+ for (int i = ith; i < n_groups; i += nth) {
int start = i * n_channels_per_group;
int end = start + n_channels_per_group;
if (end > n_channels) {

@@ -10245,28 +10308,32 @@ static void ggml_compute_forward_group_norm_f32(
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);

+ ggml_float sumr = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
- sum += (ggml_float)x[i00];
+ sumr += (ggml_float)x[i00];
}
+ sum += sumr;
}
}
- float mean = sum / (ne00 * ne01 * step);
- ggml_float sum2 = 0.0;
+ const float mean = sum / (ne00 * ne01 * step);
+
+ ggml_float sum2 = 0.0;
for (int64_t i02 = start; i02 < end; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);

float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);

+ ggml_float sumr = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
float v = x[i00] - mean;
y[i00] = v;
- sum2 += (ggml_float)(v * v);
+ sumr += (ggml_float)(v * v);
}
+ sum2 += sumr;
}
}
- float variance = sum2 / (ne00 * ne01 * step);
+ const float variance = sum2 / (ne00 * ne01 * step);
const float scale = 1.0f / sqrtf(variance + eps);

for (int64_t i02 = start; i02 < end; i02++) {

@@ -13547,6 +13614,106 @@ static void ggml_compute_forward_pad(
}
}

+
+ // ggml_compute_forward_arange
+
+ static void ggml_compute_forward_arange_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const float start = ggml_get_op_params_f32(dst, 0);
+ const float stop = ggml_get_op_params_f32(dst, 1);
+ const float step = ggml_get_op_params_f32(dst, 2);
+
+ const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+ GGML_ASSERT(ggml_nelements(dst) == steps);
+
+ for (int64_t i = ith; i < steps; i+= nth) {
+ float value = start + step * i;
+ ((float *)dst->data)[i] = value;
+ }
+ }
+
+ static void ggml_compute_forward_arange(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+ switch (dst->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_arange_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
+ static void ggml_compute_forward_timestep_embedding_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+ return;
+ }
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ const int dim = ggml_get_op_params_i32(dst, 0);
+ const int max_period = ggml_get_op_params_i32(dst, 1);
+
+ int half = dim / 2;
+
+ for (int64_t i = 0; i < ne00; i++) {
+ float * embed_data = (float *)((char *) dst->data + i*nb1);
+ for (int64_t j = ith; j < half; j += nth) {
+ float timestep = ((float *)src0->data)[i];
+ float freq = (float)expf(-logf(max_period) * j / half);
+ float arg = timestep * freq;
+ embed_data[j] = cosf(arg);
+ embed_data[j + half] = sinf(arg);
+ }
+ if (dim % 2 != 0 && ith == 0) {
+ embed_data[dim] = 0.f;
+ }
+ }
+ }
+
+ static void ggml_compute_forward_timestep_embedding(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_timestep_embedding_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
// ggml_compute_forward_argsort

static void ggml_compute_forward_argsort_f32(

@@ -15615,6 +15782,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_pad(params, tensor);
} break;
+ case GGML_OP_ARANGE:
+ {
+ ggml_compute_forward_arange(params, tensor);
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ ggml_compute_forward_timestep_embedding(params, tensor);
+ } break;
case GGML_OP_ARGSORT:
{
ggml_compute_forward_argsort(params, tensor);

@@ -16617,6 +16792,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
+ case GGML_OP_ARANGE:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(false); // TODO: not implemented

@@ -17368,6 +17551,14 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = n_threads;
} break;
+ case GGML_OP_ARANGE:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ {
+ n_tasks = n_threads;
+ } break;
case GGML_OP_ARGSORT:
{
n_tasks = n_threads;
ggml.h (17 changed lines)
@@ -454,6 +454,8 @@ extern "C" {
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
+ GGML_OP_ARANGE,
+ GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,

@@ -1661,6 +1663,15 @@ extern "C" {
int p2,
int p3);

+ // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+ // timesteps: [N,]
+ // return: [N, dim]
+ GGML_API struct ggml_tensor * ggml_timestep_embedding(
+ struct ggml_context * ctx,
+ struct ggml_tensor * timesteps,
+ int dim,
+ int max_period);
+
// sort rows
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,

@@ -1672,6 +1683,12 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);

+ GGML_API struct ggml_tensor * ggml_arange(
+ struct ggml_context * ctx,
+ float start,
+ float stop,
+ float step);
+
// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
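The two declarations above are the entire public API surface added by this change. As a rough usage sketch (not part of the commit), the following builds both ops on the default CPU backend and computes them with the existing ggml graph API; the context size, thread count, and printed value are illustrative assumptions only:

// Hedged sketch: exercise ggml_arange and ggml_timestep_embedding end to end.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,   // assumed scratch size for tensors + graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 1-D ramp of timesteps: ceil((4 - 0)/1) = 4 elements -> 0, 1, 2, 3
    struct ggml_tensor * t = ggml_arange(ctx, 0.0f, 4.0f, 1.0f);

    // sinusoidal embedding of each timestep: ne[0] = 320, ne[1] = 4
    struct ggml_tensor * emb = ggml_timestep_embedding(ctx, t, 320, 10000);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, emb);             // pulls in the arange node as a dependency
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    // first element is cos(t_0 * f_0) = cos(0) = 1
    printf("emb[0][0] = %f\n", ggml_get_f32_1d(emb, 0));

    ggml_free(ctx);
    return 0;
}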
@@ -1412,6 +1412,50 @@ struct test_pad : public test_case {
}
};

+ // GGML_OP_ARANGE
+ struct test_arange : public test_case {
+ const ggml_type type;
+ const float start;
+ const float stop;
+ const float step;
+
+ std::string vars() override {
+ return VARS_TO_STR4(type, start, stop, step);
+ }
+
+ test_arange(ggml_type type = GGML_TYPE_F32,
+ float start = 0.f, float stop = 10.f, float step = 1.f)
+ : type(type), start(start), stop(stop), step(step) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * out = ggml_arange(ctx, start, stop, step);
+ return out;
+ }
+ };
+
+ // GGML_OP_TIMESTEP_EMBEDDING
+ struct test_timestep_embedding : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne_a;
+ const int dim;
+ const int max_period;
+
+ std::string vars() override {
+ return VARS_TO_STR4(type, ne_a, dim, max_period);
+ }
+
+ test_timestep_embedding(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne_a = {2, 1, 1, 1},
+ int dim = 320, int max_period=10000)
+ : type(type), ne_a(ne_a), dim(dim), max_period(max_period) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+ ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);
+ return out;
+ }
+ };
+
// GGML_OP_LEAKY_RELU
struct test_leaky_relu : public test_case {
const ggml_type type;

@@ -2126,6 +2170,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_group_norm());
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
+ test_cases.emplace_back(new test_arange());
+ test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());

// these tests are disabled to save execution time, but they can be handy for debugging