mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 18:34:36 +00:00
tests: add gradient tests for all backends (ggml/932)
* tests: add gradient checking to test-backend-ops * remove old comment * reorder includes * adjust SIN/COS parameters * add documentation, use supports_op if possible
This commit is contained in:
parent
dbbebcab33
commit
202084d31d
@ -1272,7 +1272,7 @@ extern "C" {
|
||||
size_t nb1,
|
||||
size_t nb2,
|
||||
size_t nb3,
|
||||
size_t offset);
|
||||
size_t offset); // in bytes
|
||||
|
||||
// b -> view(a,offset,nb1,nb2,3), return view(a)
|
||||
GGML_API struct ggml_tensor * ggml_set_inplace(
|
||||
@ -1282,19 +1282,19 @@ extern "C" {
|
||||
size_t nb1,
|
||||
size_t nb2,
|
||||
size_t nb3,
|
||||
size_t offset);
|
||||
size_t offset); // in bytes
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_set_1d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
size_t offset);
|
||||
size_t offset); // in bytes
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_set_1d_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
size_t offset);
|
||||
size_t offset); // in bytes
|
||||
|
||||
// b -> view(a,offset,nb1,nb2,3), return modified a
|
||||
GGML_API struct ggml_tensor * ggml_set_2d(
|
||||
@ -1302,7 +1302,7 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
size_t nb1,
|
||||
size_t offset);
|
||||
size_t offset); // in bytes
|
||||
|
||||
// b -> view(a,offset,nb1,nb2,3), return view(a)
|
||||
GGML_API struct ggml_tensor * ggml_set_2d_inplace(
|
||||
@ -1310,7 +1310,7 @@ extern "C" {
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
size_t nb1,
|
||||
size_t offset);
|
||||
size_t offset); // in bytes
|
||||
|
||||
// a -> b, return view(b)
|
||||
GGML_API struct ggml_tensor * ggml_cpy(
|
||||
|
@ -827,6 +827,10 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const
|
||||
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
|
||||
case GGML_OP_MUL_MAT:
|
||||
return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
|
@ -27,6 +27,7 @@
|
||||
#include "ggml-cuda/rope.cuh"
|
||||
#include "ggml-cuda/scale.cuh"
|
||||
#include "ggml-cuda/softmax.cuh"
|
||||
#include "ggml-cuda/sum.cuh"
|
||||
#include "ggml-cuda/sumrows.cuh"
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
@ -2180,6 +2181,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_dup(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD1: // TODO: more efficient implementation
|
||||
ggml_cuda_op_add(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SUB:
|
||||
@ -2196,6 +2198,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(dst)) {
|
||||
case GGML_UNARY_OP_NEG:
|
||||
ggml_cuda_op_neg(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU:
|
||||
ggml_cuda_op_gelu(ctx, dst);
|
||||
break;
|
||||
@ -2304,6 +2309,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_POOL_2D:
|
||||
ggml_cuda_op_pool2d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SUM:
|
||||
ggml_cuda_op_sum(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_cuda_op_sum_rows(ctx, dst);
|
||||
break;
|
||||
@ -2748,6 +2756,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_NEG:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
@ -2877,6 +2886,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
@ -2896,7 +2906,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_OP_ROPE:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
|
@ -1,6 +1,6 @@
|
||||
#include "common.cuh"
|
||||
#include "cross-entropy-loss.cuh"
|
||||
#include "sumrows.cuh"
|
||||
#include "sum.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
@ -102,5 +102,5 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
|
||||
|
||||
// Combine results from individual blocks:
|
||||
sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
|
||||
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
|
||||
}
|
||||
|
41
ggml/src/ggml-cuda/sum.cu
Normal file
41
ggml/src/ggml-cuda/sum.cu
Normal file
@ -0,0 +1,41 @@
|
||||
#include "sumrows.cuh"
|
||||
#include "sum.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
#include <cub/cub.cuh>
|
||||
using namespace cub;
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
|
||||
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
size_t tmp_size = 0;
|
||||
DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
|
||||
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
|
||||
DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
|
||||
#else
|
||||
// Use (inefficient) sum_rows implementation as a fallback.
|
||||
// For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
|
||||
sum_rows_f32_cuda(x, dst, ne, 1, stream);
|
||||
GGML_UNUSED(pool);
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
sum_f32_cuda(pool, src0_d, dst_d, ne, stream);
|
||||
}
|
5
ggml/src/ggml-cuda/sum.cuh
Normal file
5
ggml/src/ggml-cuda/sum.cuh
Normal file
@ -0,0 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
|
||||
|
||||
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
@ -1,5 +1,15 @@
|
||||
#include "unary.cuh"
|
||||
|
||||
static __global__ void neg_f32(const float * x, float * dst, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[i] = -x[i];
|
||||
}
|
||||
|
||||
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
@ -119,6 +129,11 @@ static __global__ void cos_f32(const float * x, float * dst, const int k) {
|
||||
dst[i] = cosf(x[i]);
|
||||
}
|
||||
|
||||
static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
|
||||
neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
||||
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
@ -184,6 +199,20 @@ static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
|
||||
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_NEG_BLOCK_SIZE 256
|
||||
#define CUDA_GELU_BLOCK_SIZE 256
|
||||
#define CUDA_SILU_BLOCK_SIZE 256
|
||||
#define CUDA_TANH_BLOCK_SIZE 256
|
||||
@ -12,6 +13,8 @@
|
||||
#define CUDA_SIN_BLOCK_SIZE 256
|
||||
#define CUDA_COS_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
@ -5267,6 +5267,7 @@ struct ggml_tensor * ggml_concat(
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad || b->grad) {
|
||||
GGML_ABORT("fatal error"); // TODO: implement
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
@ -5388,6 +5389,7 @@ struct ggml_tensor * ggml_leaky_relu(
|
||||
bool is_node = false;
|
||||
|
||||
if (!inplace && (a->grad)) {
|
||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
@ -5826,6 +5828,7 @@ static struct ggml_tensor * ggml_set_impl(
|
||||
// make a view of the destination
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
GGML_ASSERT(offset < (size_t)(1 << 30));
|
||||
int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
@ -6783,14 +6786,12 @@ struct ggml_tensor * ggml_rope_back(
|
||||
GGML_ASSERT(ggml_is_vector(b));
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[0]);
|
||||
GGML_ASSERT(c == NULL && "freq factors not implemented yet");
|
||||
|
||||
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
is_node = false; // TODO: implement backward
|
||||
GGML_ASSERT(false && "backwards pass not implemented");
|
||||
is_node = false;
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
@ -6808,6 +6809,7 @@ struct ggml_tensor * ggml_rope_back(
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = b;
|
||||
result->src[2] = c;
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -7361,6 +7363,11 @@ struct ggml_tensor * ggml_argsort(
|
||||
enum ggml_sort_order order) {
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
GGML_ABORT("fatal error"); // TODO: not implemented
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) order);
|
||||
@ -10953,9 +10960,6 @@ static void ggml_compute_forward_sum_f32(
|
||||
return;
|
||||
}
|
||||
|
||||
assert(ggml_is_scalar(dst));
|
||||
|
||||
|
||||
assert(ggml_is_scalar(dst));
|
||||
assert(src0->nb[0] == sizeof(float));
|
||||
|
||||
@ -18356,14 +18360,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
if (src0->grad || src1->grad) {
|
||||
GGML_ASSERT(src0->type == tensor->type);
|
||||
GGML_ASSERT(tensor->grad->type == tensor->type);
|
||||
GGML_ASSERT(tensor->grad->type == src1->grad->type);
|
||||
GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
|
||||
|
||||
tensor_grad_view = ggml_view_4d(ctx,
|
||||
tensor->grad,
|
||||
src1->grad->ne[0],
|
||||
src1->grad->ne[1],
|
||||
src1->grad->ne[2],
|
||||
src1->grad->ne[3],
|
||||
tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
|
||||
nb1, nb2, nb3, offset);
|
||||
}
|
||||
|
||||
@ -18432,9 +18432,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||
|
||||
memcpy(&offset, tensor->op_params, sizeof(offset));
|
||||
|
||||
size_t nb1 = tensor->nb[1];
|
||||
size_t nb2 = tensor->nb[2];
|
||||
size_t nb3 = tensor->nb[3];
|
||||
size_t nb1 = tensor->nb[1];
|
||||
size_t nb2 = tensor->nb[2];
|
||||
size_t nb3 = tensor->nb[3];
|
||||
|
||||
if (src0->type != src0->grad->type) {
|
||||
// gradient is typically F32, but src0 could be other type
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user