mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
sync : ggml
This commit is contained in:
parent
3246fe84d7
commit
231cff5f6f
@ -63,6 +63,7 @@ extern "C" {
|
||||
GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
|
||||
// "offset" refers to the offset of the tensor data for setting/getting data
|
||||
GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
|
||||
GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
|
||||
|
||||
|
@ -220,7 +220,7 @@
|
||||
#include <stdio.h>
|
||||
|
||||
#define GGML_FILE_MAGIC 0x67676d6c // "ggml"
|
||||
#define GGML_FILE_VERSION 1
|
||||
#define GGML_FILE_VERSION 2
|
||||
|
||||
#define GGML_QNT_VERSION 2 // bump this on quantization format changes
|
||||
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
|
||||
@ -453,6 +453,8 @@ extern "C" {
|
||||
GGML_OP_SQR,
|
||||
GGML_OP_SQRT,
|
||||
GGML_OP_LOG,
|
||||
GGML_OP_SIN,
|
||||
GGML_OP_COS,
|
||||
GGML_OP_SUM,
|
||||
GGML_OP_SUM_ROWS,
|
||||
GGML_OP_MEAN,
|
||||
@ -490,9 +492,11 @@ extern "C" {
|
||||
GGML_OP_CLAMP,
|
||||
GGML_OP_CONV_TRANSPOSE_1D,
|
||||
GGML_OP_IM2COL,
|
||||
GGML_OP_IM2COL_BACK,
|
||||
GGML_OP_CONV_TRANSPOSE_2D,
|
||||
GGML_OP_POOL_1D,
|
||||
GGML_OP_POOL_2D,
|
||||
GGML_OP_POOL_2D_BACK,
|
||||
GGML_OP_UPSCALE, // nearest interpolate
|
||||
GGML_OP_PAD,
|
||||
GGML_OP_ARANGE,
|
||||
@ -969,6 +973,22 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_sin(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_sin_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cos(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cos_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// return scalar
|
||||
GGML_API struct ggml_tensor * ggml_sum(
|
||||
struct ggml_context * ctx,
|
||||
@ -1566,34 +1586,49 @@ extern "C" {
|
||||
float min,
|
||||
float max);
|
||||
|
||||
// im2col
|
||||
// converts data into a format that effectively results in a convolution when combined with matrix multiplication
|
||||
GGML_API struct ggml_tensor * ggml_im2col(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1,
|
||||
bool is_2D,
|
||||
enum ggml_type dst_type);
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s0, // stride dimension 0
|
||||
int s1, // stride dimension 1
|
||||
int p0, // padding dimension 0
|
||||
int p1, // padding dimension 1
|
||||
int d0, // dilation dimension 0
|
||||
int d1, // dilation dimension 1
|
||||
bool is_2D,
|
||||
enum ggml_type dst_type);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_im2col_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // gradient of im2col output
|
||||
int64_t * ne, // shape of im2col input
|
||||
int s0, // stride dimension 0
|
||||
int s1, // stride dimension 1
|
||||
int p0, // padding dimension 0
|
||||
int p1, // padding dimension 1
|
||||
int d0, // dilation dimension 0
|
||||
int d1, // dilation dimension 1
|
||||
bool is_2D);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1);
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s0, // stride dimension 0
|
||||
int s1, // stride dimension 1
|
||||
int p0, // padding dimension 0
|
||||
int p1, // padding dimension 1
|
||||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s0, // stride
|
||||
int p0, // padding
|
||||
int d0); // dilation
|
||||
@ -1602,29 +1637,29 @@ extern "C" {
|
||||
// alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
|
||||
GGML_API struct ggml_tensor* ggml_conv_1d_ph(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int s,
|
||||
int d);
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s, // stride
|
||||
int d); // dilation
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int s0,
|
||||
int p0,
|
||||
int d0);
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s0, // stride
|
||||
int p0, // padding
|
||||
int d0); // dilation
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_conv_2d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1);
|
||||
struct ggml_tensor * a, // convolution kernel
|
||||
struct ggml_tensor * b, // data
|
||||
int s0, // stride dimension 0
|
||||
int s1, // stride dimension 1
|
||||
int p0, // padding dimension 0
|
||||
int p1, // padding dimension 1
|
||||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
|
||||
// kernel size is a->ne[0] x a->ne[1]
|
||||
@ -1686,6 +1721,18 @@ extern "C" {
|
||||
float p0,
|
||||
float p1);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_pool_2d_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * af, // "a"/input used in forward pass
|
||||
enum ggml_op_pool op,
|
||||
int k0,
|
||||
int k1,
|
||||
int s0,
|
||||
int s1,
|
||||
float p0,
|
||||
float p1);
|
||||
|
||||
// nearest interpolate
|
||||
// multiplies ne0 and ne1 by scale factor
|
||||
// used in stable-diffusion
|
||||
|
@ -9,8 +9,10 @@
|
||||
#include "ggml-cuda/binbcast.cuh"
|
||||
#include "ggml-cuda/clamp.cuh"
|
||||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/convert.cuh"
|
||||
#include "ggml-cuda/cpy.cuh"
|
||||
#include "ggml-cuda/cross-entropy-loss.cuh"
|
||||
#include "ggml-cuda/diagmask.cuh"
|
||||
#include "ggml-cuda/dmmv.cuh"
|
||||
#include "ggml-cuda/fattn.cuh"
|
||||
@ -29,7 +31,6 @@
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
@ -2181,6 +2182,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_ADD:
|
||||
ggml_cuda_op_add(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SUB:
|
||||
ggml_cuda_op_sub(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ACC:
|
||||
ggml_cuda_op_acc(ctx, dst);
|
||||
break;
|
||||
@ -2267,6 +2271,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_SQRT:
|
||||
ggml_cuda_op_sqrt(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SIN:
|
||||
ggml_cuda_op_sin(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_COS:
|
||||
ggml_cuda_op_cos(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
ggml_cuda_op_clamp(ctx, dst);
|
||||
break;
|
||||
@ -2303,6 +2313,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
ggml_cuda_flash_attn_ext(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -2610,6 +2623,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
||||
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
if (node->src[j] != nullptr) {
|
||||
assert(node->src[j]->buffer);
|
||||
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
|
||||
}
|
||||
}
|
||||
@ -2853,12 +2867,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
@ -2890,6 +2907,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
}
|
||||
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
|
||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
return true;
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
default:
|
||||
return false;
|
||||
|
@ -9,6 +9,10 @@ static __device__ __forceinline__ float op_add(const float a, const float b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_sub(const float a, const float b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_mul(const float a, const float b) {
|
||||
return a * b;
|
||||
}
|
||||
@ -271,6 +275,10 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
|
||||
}
|
||||
|
||||
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
|
||||
}
|
||||
|
@ -2,5 +2,6 @@
|
||||
|
||||
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
106
ggml/src/ggml-cuda/cross-entropy-loss.cu
Normal file
106
ggml/src/ggml-cuda/cross-entropy-loss.cu
Normal file
@ -0,0 +1,106 @@
|
||||
#include "common.cuh"
|
||||
#include "cross-entropy-loss.cuh"
|
||||
#include "sumrows.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
|
||||
static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
|
||||
|
||||
const int ne_tmp = WARP_SIZE*nclasses;
|
||||
|
||||
extern __shared__ float tmp_all[];
|
||||
float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
|
||||
float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
|
||||
|
||||
// Each warp first loads ne_tmp logits/labels into shared memory:
|
||||
for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
|
||||
const int ig = i0*nclasses + i; // ig == i global
|
||||
|
||||
tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
|
||||
tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
|
||||
}
|
||||
|
||||
// Each thread in the warp then calculates the cross entropy loss for a single row.
|
||||
// TODO: pad in order to avoid shared memory bank conflicts.
|
||||
|
||||
// Find maximum for softmax:
|
||||
float max = -INFINITY;
|
||||
for (int i = 0; i < nclasses; ++i) {
|
||||
max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
|
||||
}
|
||||
|
||||
// Calculate log(softmax(logits)) which is just logits - max:
|
||||
float sum = 0.0f;
|
||||
for (int i = 0; i < nclasses; ++i) {
|
||||
float val = tmp_logits[lane_id*nclasses + i] - max;
|
||||
sum += expf(val);
|
||||
tmp_logits[lane_id*nclasses + i] = val;
|
||||
}
|
||||
sum = logf(sum);
|
||||
|
||||
// log(exp(logits - max) / sum) = (logits - max) - log(sum)
|
||||
float loss = 0.0f;
|
||||
for (int i = 0; i < nclasses; ++i) {
|
||||
loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
|
||||
}
|
||||
loss = -warp_reduce_sum(loss) / (float)k;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (lane_id == 0) {
|
||||
tmp_all[warp_id] = loss;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (warp_id != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
|
||||
loss = warp_reduce_sum(loss);
|
||||
|
||||
if (lane_id != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[blockIdx.x] = loss;
|
||||
}
|
||||
|
||||
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
|
||||
const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
|
||||
const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
|
||||
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
|
||||
|
||||
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
|
||||
// Combine results from individual blocks:
|
||||
sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
|
||||
}
|
5
ggml/src/ggml-cuda/cross-entropy-loss.cuh
Normal file
5
ggml/src/ggml-cuda/cross-entropy-loss.cuh
Normal file
@ -0,0 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
@ -16,7 +16,7 @@ static __global__ void k_sum_rows_f32(const float * x, float * dst, const int nc
|
||||
}
|
||||
}
|
||||
|
||||
static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
@ -32,7 +32,6 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
|
||||
dst[i] = sqrtf(x[i]);
|
||||
}
|
||||
|
||||
static __global__ void sin_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = sinf(x[i]);
|
||||
}
|
||||
|
||||
static __global__ void cos_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
dst[i] = cosf(x[i]);
|
||||
}
|
||||
|
||||
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
||||
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
||||
sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
|
||||
sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
|
||||
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
@ -9,6 +9,8 @@
|
||||
#define CUDA_HARDSWISH_BLOCK_SIZE 256
|
||||
#define CUDA_SQR_BLOCK_SIZE 256
|
||||
#define CUDA_SQRT_BLOCK_SIZE 256
|
||||
#define CUDA_SIN_BLOCK_SIZE 256
|
||||
#define CUDA_COS_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
@ -31,6 +31,8 @@ struct ggml_metal_kernel {
|
||||
enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ADD,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_SUB,
|
||||
GGML_METAL_KERNEL_TYPE_SUB_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_MUL,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_ROW,
|
||||
GGML_METAL_KERNEL_TYPE_DIV,
|
||||
@ -207,6 +209,9 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_CONCAT,
|
||||
GGML_METAL_KERNEL_TYPE_SQR,
|
||||
GGML_METAL_KERNEL_TYPE_SQRT,
|
||||
GGML_METAL_KERNEL_TYPE_SIN,
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
|
||||
GGML_METAL_KERNEL_TYPE_COUNT
|
||||
@ -493,6 +498,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
||||
@ -669,6 +676,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
}
|
||||
|
||||
@ -769,15 +779,20 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
return true;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
@ -1057,6 +1072,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
{
|
||||
@ -1080,6 +1096,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
nb = ne00 / 4;
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
||||
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
|
||||
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
|
||||
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
@ -1089,6 +1106,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
} else {
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
|
||||
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
||||
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
||||
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
@ -1416,6 +1434,48 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SQRT:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SIN:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_COS:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
|
@ -17,7 +17,7 @@ enum ggml_sort_order {
|
||||
GGML_SORT_ORDER_DESC,
|
||||
};
|
||||
|
||||
// general-purpose kernel for addition, multiplication and division of two tensors
|
||||
// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
|
||||
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
||||
// cons: not very efficient
|
||||
kernel void kernel_add(
|
||||
@ -70,6 +70,56 @@ kernel void kernel_add(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_sub(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
constant int64_t & offs,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig.z;
|
||||
const int64_t i02 = tgpig.y;
|
||||
const int64_t i01 = tgpig.x;
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||
const int i10 = i0 % ne10;
|
||||
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@ -226,6 +276,15 @@ kernel void kernel_add_row(
|
||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_sub_row(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant uint64_t & nb [[buffer(28)]],
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] - src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_mul_row(
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
@ -358,6 +417,27 @@ kernel void kernel_sqr(
|
||||
dst[tpig] = src0[tpig] * src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sqrt(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sqrt(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sin(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sin(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_cos(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sum_rows(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
@ -3644,7 +3644,7 @@ void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
|
||||
quantize_row_q8_K_ref(x, y, k);
|
||||
}
|
||||
|
||||
//===================================== Dot ptoducts =================================
|
||||
//===================================== Dot products =================================
|
||||
|
||||
//
|
||||
// Helper functions
|
||||
|
@ -188,6 +188,8 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_upscale_f32;
|
||||
vk_pipeline pipeline_scale_f32;
|
||||
vk_pipeline pipeline_sqr_f32;
|
||||
vk_pipeline pipeline_sin_f32;
|
||||
vk_pipeline pipeline_cos_f32;
|
||||
vk_pipeline pipeline_clamp_f32;
|
||||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_repeat_f32;
|
||||
@ -1702,6 +1704,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
@ -4023,6 +4027,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_sqr_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SIN:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sin_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_COS:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_cos_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CLAMP:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_clamp_f32;
|
||||
@ -4171,6 +4185,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_REPEAT:
|
||||
@ -4381,6 +4397,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_REPEAT:
|
||||
@ -4598,6 +4616,32 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
@ -5658,6 +5702,8 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_CPY:
|
||||
@ -5735,6 +5781,14 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||
case GGML_OP_SQR:
|
||||
ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
|
||||
|
||||
break;
|
||||
case GGML_OP_SIN:
|
||||
ggml_vk_sin(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_COS:
|
||||
ggml_vk_cos(ctx, compute_ctx, src0, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CLAMP:
|
||||
ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun);
|
||||
@ -5851,6 +5905,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_CPY:
|
||||
@ -6582,6 +6638,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_CONT:
|
||||
@ -7024,6 +7082,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
tensor_clone = ggml_scale(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0]);
|
||||
} else if (tensor->op == GGML_OP_SQR) {
|
||||
tensor_clone = ggml_sqr(ggml_ctx, src0_clone);
|
||||
} else if (tensor->op == GGML_OP_SIN) {
|
||||
tensor_clone = ggml_sin(ggml_ctx, src0_clone);
|
||||
} else if (tensor->op == GGML_OP_COS) {
|
||||
tensor_clone = ggml_cos(ggml_ctx, src0_clone);
|
||||
} else if (tensor->op == GGML_OP_CLAMP) {
|
||||
tensor_clone = ggml_clamp(ggml_ctx, src0_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
||||
} else if (tensor->op == GGML_OP_PAD) {
|
||||
|
704
ggml/src/ggml.c
704
ggml/src/ggml.c
File diff suppressed because it is too large
Load Diff
15
ggml/src/vulkan-shaders/cos.comp
Normal file
15
ggml/src/vulkan-shaders/cos.comp
Normal file
@ -0,0 +1,15 @@
|
||||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
|
||||
}
|
15
ggml/src/vulkan-shaders/sin.comp
Normal file
15
ggml/src/vulkan-shaders/sin.comp
Normal file
@ -0,0 +1,15 @@
|
||||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_unary_head.comp"
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
|
||||
}
|
@ -1 +1 @@
|
||||
797faa25af14126eb30134d4033139ae3c5428ed
|
||||
28b7633d733bbeef0026570fbc61c79c5e9aa5ae
|
||||
|
@ -1160,6 +1160,58 @@ struct test_sqrt : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SIN
|
||||
struct test_sin : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
}
|
||||
|
||||
test_sin(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 10})
|
||||
: type(type), ne(ne) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * out = ggml_sin(ctx, a);
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t, -100.0f, 100.0f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_COS
|
||||
struct test_cos : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
}
|
||||
|
||||
test_cos(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 10})
|
||||
: type(type), ne(ne) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * out = ggml_cos(ctx, a);
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t, -100.0f, 100.0f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_CLAMP
|
||||
struct test_clamp : public test_case {
|
||||
const ggml_type type;
|
||||
@ -1731,6 +1783,27 @@ struct test_flash_attn_ext : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_CROSS_ENTROPY_LOSS
|
||||
struct test_cross_entropy_loss : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
}
|
||||
|
||||
test_cross_entropy_loss(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 10})
|
||||
: type(type), ne(ne) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * out = ggml_cross_entropy_loss(ctx, logits, labels);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
enum llm_norm_type {
|
||||
LLM_NORM,
|
||||
LLM_NORM_RMS,
|
||||
@ -2393,6 +2466,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
|
||||
test_cases.emplace_back(new test_sqr());
|
||||
test_cases.emplace_back(new test_sqrt());
|
||||
test_cases.emplace_back(new test_sin());
|
||||
test_cases.emplace_back(new test_cos());
|
||||
test_cases.emplace_back(new test_clamp());
|
||||
|
||||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
|
||||
@ -2512,6 +2587,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_cross_entropy_loss());
|
||||
|
||||
// these tests are disabled to save execution time, but they can be handy for debugging
|
||||
#if 0
|
||||
test_cases.emplace_back(new test_llama(1));
|
||||
|
@ -1,10 +1,14 @@
|
||||
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
@ -217,7 +221,8 @@ static bool check_gradient(
|
||||
int nargs,
|
||||
float eps,
|
||||
float max_error_abs,
|
||||
float max_error_rel) {
|
||||
float max_error_rel,
|
||||
std::vector<double> expected_vals) {
|
||||
|
||||
static int n_threads = -1;
|
||||
if (n_threads < 0) {
|
||||
@ -248,9 +253,10 @@ static bool check_gradient(
|
||||
// ggml_graph_dump_dot(gb, gf, "test-grad0-backward.dot");
|
||||
|
||||
for (int i = 0; i < nargs; ++i) {
|
||||
bool all_g0_bad = true;
|
||||
const int nelements = ggml_nelements(x[i]);
|
||||
for (int k = 0; k < nelements; ++k) {
|
||||
// compute gradient using finite differences
|
||||
// Calculate gradient numerically:
|
||||
const float x0 = ggml_get_f32_1d(x[i], k);
|
||||
const float xm = x0 - eps;
|
||||
const float xp = x0 + eps;
|
||||
@ -267,6 +273,28 @@ static bool check_gradient(
|
||||
const double f1 = ggml_get_f32_1d(f, 0);
|
||||
const double g0 = (f0 - f1)/(2.0*(double) eps);
|
||||
|
||||
// The numerical calculation of the gradient fails around noncontinuities (e.g. 0 for ReLU).
|
||||
// In such cases, provide a vector of expected values and skip the comparison for failed calculations.
|
||||
if (!expected_vals.empty()) {
|
||||
bool matches_any = false;
|
||||
for (const double & ev : expected_vals) {
|
||||
const double error_abs = std::fabs(g0 - ev);
|
||||
if (error_abs > max_error_abs) {
|
||||
continue;
|
||||
}
|
||||
const double error_rel = g0 != 0.0 ? fabs(g0 - ev)/fabs(g0) : 0.0;
|
||||
if (error_rel > max_error_rel) {
|
||||
continue;
|
||||
}
|
||||
matches_any = true;
|
||||
break;
|
||||
}
|
||||
if (!matches_any) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
all_g0_bad = false;
|
||||
|
||||
ggml_set_f32_1d(x[i], k, x0);
|
||||
|
||||
// compute gradient using backward graph
|
||||
@ -278,7 +306,7 @@ static bool check_gradient(
|
||||
const double g1 = ggml_get_f32_1d(x[i]->grad, k);
|
||||
|
||||
const double error_abs = fabs(g0 - g1);
|
||||
const double error_rel = g0 != 0 ? fabs(g0 - g1)/fabs(g0) : 0;
|
||||
const double error_rel = g0 != 0.0 ? fabs(g0 - g1)/fabs(g0) : 0.0;
|
||||
|
||||
if (error_abs > max_error_abs || error_rel > max_error_rel) {
|
||||
printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
|
||||
@ -287,6 +315,10 @@ static bool check_gradient(
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (all_g0_bad) {
|
||||
printf("%s: numerical calculation of the gradient failed for all values\n", op_name);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -404,7 +436,7 @@ int main(int argc, const char ** argv) {
|
||||
seed_iter = rand();
|
||||
unsigned seed = rand();
|
||||
|
||||
printf("test-grad0: iter:%d/%d\n", iter, niter);
|
||||
printf("test-grad0: iter:%d/%d\n", (iter+1), niter);
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
get_random_dims(ne, 4);
|
||||
@ -424,7 +456,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
|
||||
|
||||
check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
|
||||
check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -441,7 +473,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
|
||||
|
||||
check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f);
|
||||
check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -458,7 +490,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1]));
|
||||
|
||||
check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -475,7 +507,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1]));
|
||||
|
||||
check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -492,7 +524,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
|
||||
|
||||
check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f);
|
||||
check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -509,7 +541,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0]));
|
||||
|
||||
check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -526,7 +558,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
|
||||
|
||||
check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f);
|
||||
check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -543,7 +575,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0]));
|
||||
|
||||
check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f);
|
||||
check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -560,7 +592,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, x[0]);
|
||||
|
||||
check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -578,7 +610,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0])));
|
||||
|
||||
check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
|
||||
check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -596,7 +628,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
|
||||
|
||||
check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -614,7 +646,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
|
||||
|
||||
check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -637,7 +669,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
|
||||
|
||||
check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
|
||||
check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -660,25 +692,25 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
|
||||
|
||||
check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
|
||||
check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
// abs (finite differences do not work)
|
||||
//{
|
||||
// const int nargs = 1;
|
||||
// abs
|
||||
{
|
||||
const int nargs = 1;
|
||||
|
||||
// for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||
// for (int i = 0; i < nargs; ++i) {
|
||||
// x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||
// ggml_set_param(ctx0, x[i]);
|
||||
// }
|
||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||
for (int i = 0; i < nargs; ++i) {
|
||||
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||
ggml_set_param(ctx0, x[i]);
|
||||
}
|
||||
|
||||
// struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
|
||||
|
||||
// check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f);
|
||||
// }
|
||||
//}
|
||||
check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f, {-1.0, 1.0});
|
||||
}
|
||||
}
|
||||
|
||||
// sgn
|
||||
{
|
||||
@ -693,7 +725,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
|
||||
|
||||
check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
|
||||
}
|
||||
}
|
||||
|
||||
@ -710,7 +742,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
|
||||
|
||||
check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -727,7 +759,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
|
||||
|
||||
check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
|
||||
}
|
||||
}
|
||||
|
||||
@ -745,7 +777,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
|
||||
|
||||
check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -776,7 +808,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
|
||||
|
||||
check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
if (ndims == 2) {
|
||||
// check_mat_mul does not support ndims > 2
|
||||
check_mat_mul(m, x[1], x[0]);
|
||||
@ -800,7 +832,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
|
||||
|
||||
check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -817,7 +849,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
|
||||
|
||||
check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {0.0, 1.0});
|
||||
}
|
||||
}
|
||||
|
||||
@ -835,7 +867,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
|
||||
|
||||
check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||
check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -854,9 +886,9 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
#ifdef GGML_SILU_FP16
|
||||
// due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
|
||||
check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY);
|
||||
check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY, {});
|
||||
#else
|
||||
check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@ -874,7 +906,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
|
||||
|
||||
check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
|
||||
check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -892,7 +924,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s));
|
||||
|
||||
check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -910,7 +942,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
|
||||
|
||||
check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -928,7 +960,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
|
||||
|
||||
check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
|
||||
check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -952,7 +984,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
|
||||
check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -976,7 +1008,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
|
||||
check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1004,7 +1036,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
|
||||
|
||||
check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1037,7 +1069,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
|
||||
|
||||
check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1072,7 +1104,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
|
||||
|
||||
check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1109,7 +1141,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
|
||||
|
||||
check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1137,7 +1169,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset));
|
||||
|
||||
check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1170,7 +1202,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset));
|
||||
|
||||
check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1194,7 +1226,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset));
|
||||
|
||||
check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1225,7 +1257,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset));
|
||||
|
||||
check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1257,7 +1289,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset));
|
||||
|
||||
check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1291,7 +1323,7 @@ int main(int argc, const char ** argv) {
|
||||
// sum requires contiguous tensor rows
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, x[0], ax0, ax1, ax2, ax3)));
|
||||
|
||||
check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("permute", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1319,7 +1351,7 @@ int main(int argc, const char ** argv) {
|
||||
// sum requires contiguous tensor rows
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, x[0])));
|
||||
|
||||
check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("transpose", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1337,7 +1369,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_get_rows(ctx0, x[0], x[1]));
|
||||
|
||||
check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("get_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
|
||||
// diag_mask_inf
|
||||
@ -1353,7 +1385,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_inf(ctx0, x[0], n_past));
|
||||
|
||||
check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("diag_mask_inf", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
|
||||
// diag_mask_zero
|
||||
@ -1369,7 +1401,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_diag_mask_zero(ctx0, x[0], n_past));
|
||||
|
||||
check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||
check_gradient("diag_mask_zero", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
|
||||
// softmax
|
||||
@ -1395,7 +1427,7 @@ int main(int argc, const char ** argv) {
|
||||
1.0f - eps),
|
||||
ggml_new_f32(ctx0, eps))));
|
||||
|
||||
check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
|
||||
check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY, {});
|
||||
// NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
|
||||
// this may result in different gradients too finite differences.
|
||||
// when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
|
||||
@ -1412,7 +1444,7 @@ int main(int argc, const char ** argv) {
|
||||
get_random_dims(ne2, 4);
|
||||
|
||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -0.1f, 0.1f);
|
||||
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
|
||||
// the second argument to cross_entropy_loss must sum up to 1 for each row
|
||||
int nr = ggml_nrows(x[1]);
|
||||
@ -1430,7 +1462,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]);
|
||||
|
||||
check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-4f, 1e-3f, INFINITY);
|
||||
check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1468,7 +1500,7 @@ int main(int argc, const char ** argv) {
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
|
||||
|
||||
GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
|
||||
check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
|
||||
check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1508,12 +1540,93 @@ int main(int argc, const char ** argv) {
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
|
||||
|
||||
GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
|
||||
check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
|
||||
check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// im2col f32
|
||||
{
|
||||
srand(seed);
|
||||
const int nargs = 1;
|
||||
const int ndims = 4;
|
||||
|
||||
for (const bool is_2D : {false, true}) {
|
||||
int64_t ne0[ndims];
|
||||
int64_t ne1[ndims];
|
||||
get_random_dims(ne0, ndims);
|
||||
get_random_dims(ne1, ndims);
|
||||
|
||||
// // Ensure that the output is not zero-sized:
|
||||
ne1[0] += 8;
|
||||
ne1[1] += 8;
|
||||
|
||||
if (is_2D) {
|
||||
ne1[2] = ne0[2];
|
||||
} else {
|
||||
ne1[1] = ne0[1];
|
||||
ne0[3] = 1;
|
||||
ne1[3] = 1;
|
||||
}
|
||||
|
||||
// The order of arguments is swapped because the first tensor is only used for its shape.
|
||||
x[1] = get_random_tensor_f16(ctx0, ndims, ne0, -1.0f, 1.0f);
|
||||
x[0] = get_random_tensor_f32(ctx0, ndims, ne1, -1.0f, 1.0f);
|
||||
|
||||
ggml_set_param(ctx0, x[0]);
|
||||
|
||||
const int s0 = 1 + irand(2);
|
||||
const int s1 = is_2D ? 1 + irand(2) : 0;
|
||||
const int p0 = 0 + irand(2);
|
||||
const int p1 = is_2D ? 0 + irand(2) : 0;
|
||||
const int d0 = 1 + irand(2);
|
||||
const int d1 = is_2D ? 1 + irand(2) : 0;
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_im2col(ctx0, x[1], x[0], s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F32));
|
||||
|
||||
GGML_PRINT_DEBUG("im2col f32: is_2D=%s, s0=%d, s1=%d, p0=%d, p1=%d, d0=%d, d1=%d\n", is_2D ? "yes" : "no", s0, s1, p0, p1, d0, d1);
|
||||
check_gradient("im2col f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
|
||||
}
|
||||
}
|
||||
|
||||
// pool_2d f32
|
||||
{
|
||||
srand(seed);
|
||||
const int nargs = 1;
|
||||
const int ndims = 4;
|
||||
|
||||
for (const enum ggml_op_pool op : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
|
||||
int64_t ne0[ndims];
|
||||
get_random_dims(ne0, ndims);
|
||||
|
||||
ne0[0] += 8;
|
||||
ne0[1] += 8;
|
||||
|
||||
x[0] = get_random_tensor_f32(ctx0, ndims, ne0, -1.0f, 1.0f);
|
||||
|
||||
ggml_set_param(ctx0, x[0]);
|
||||
|
||||
const int k0 = 2 + irand(2);
|
||||
const int k1 = 2 + irand(2);
|
||||
const int s0 = 2 + irand(2);
|
||||
const int s1 = 2 + irand(2);
|
||||
const int p0 = 0 + irand(2);
|
||||
const int p1 = 0 + irand(2);
|
||||
|
||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_pool_2d(ctx0, x[0], op, k0, k1, s0, s1, p0, p1));
|
||||
|
||||
GGML_PRINT_DEBUG("ggml_pool_2d f32: op=%s k0=%d, k1=%d, s0=%d, s1=%d, p0=%d, p1=%d\n",
|
||||
op == GGML_OP_POOL_MAX ? "max" : "avg", k0, k1, s0, s1, p0, p1);
|
||||
std::vector<double> expected_vals;
|
||||
if (op == GGML_OP_POOL_MAX) {
|
||||
expected_vals.push_back(0.0);
|
||||
expected_vals.push_back(1.0);
|
||||
}
|
||||
check_gradient("ggml_pool_2d f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, expected_vals);
|
||||
}
|
||||
}
|
||||
|
||||
// flash_attn f32
|
||||
// TODO: adapt to ggml_flash_attn_ext() changes
|
||||
//{
|
||||
@ -1553,7 +1666,7 @@ int main(int argc, const char ** argv) {
|
||||
|
||||
// struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||
|
||||
// check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
||||
// check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY, {});
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
Loading…
Reference in New Issue
Block a user