mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 02:14:35 +00:00
ggml: fix gradient allocation logic (ggml/966)
* ggml: fix gradient allocation logic * gradient allocation in ggml_build_backward_expand * fixup * fix test-backend-ops grad * suggestions by slaren * fix test1.c * fix legacy opt API * fix test-grad0 * remove keep arg
This commit is contained in:
parent
cad341d889
commit
7254cdf7e8
@ -577,10 +577,10 @@ extern "C" {
|
||||
|
||||
// this tensor...
|
||||
enum ggml_tensor_flag {
|
||||
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
};
|
||||
|
||||
// n-dimensional tensor
|
||||
@ -1410,14 +1410,14 @@ extern "C" {
|
||||
// supports 3D: a->ne[2] == b->ne[1]
|
||||
GGML_API struct ggml_tensor * ggml_get_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * a, // data
|
||||
struct ggml_tensor * b); // row indices
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_get_rows_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c);
|
||||
struct ggml_tensor * a, // gradients of ggml_get_rows result
|
||||
struct ggml_tensor * b, // row indices
|
||||
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_diag(
|
||||
struct ggml_context * ctx,
|
||||
@ -1568,9 +1568,9 @@ extern "C" {
|
||||
// a - dy
|
||||
GGML_API struct ggml_tensor * ggml_rope_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c,
|
||||
struct ggml_tensor * a, // gradients of ggml_rope result
|
||||
struct ggml_tensor * b, // positions
|
||||
struct ggml_tensor * c, // freq factors
|
||||
int n_dims,
|
||||
int mode,
|
||||
int n_ctx_orig,
|
||||
@ -2036,15 +2036,15 @@ extern "C" {
|
||||
// loss function
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // logits
|
||||
struct ggml_tensor * b); // labels
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c);
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // logits
|
||||
struct ggml_tensor * b, // labels
|
||||
struct ggml_tensor * c); // gradients of cross_entropy_loss result
|
||||
|
||||
// AdamW optimizer step
|
||||
// Paper: https://arxiv.org/pdf/1711.05101v3.pdf
|
||||
@ -2066,7 +2066,7 @@ extern "C" {
|
||||
GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
|
||||
|
||||
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
|
||||
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate);
|
||||
|
||||
GGML_API void ggml_build_opt_adamw(
|
||||
struct ggml_context * ctx,
|
||||
|
1464
ggml/src/ggml.c
1464
ggml/src/ggml.c
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,6 @@
|
||||
// This file defines tests for various GGML ops and backends.
|
||||
// For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent.
|
||||
// For the backwards pass it asserts that the gradients from backpropagation are consistent
|
||||
// For the backward pass it asserts that the gradients from backpropagation are consistent
|
||||
// with the gradients obtained via the method of finite differences ("grad" mode, this is optional).
|
||||
// It is also possible to check the performance ("perf" mode).
|
||||
//
|
||||
@ -740,7 +740,7 @@ struct test_case {
|
||||
|
||||
ggml_tensor * out = build_graph(ctx);
|
||||
|
||||
if (op_name != nullptr && op_desc(out) != op_name) {
|
||||
if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
|
||||
//printf(" %s: skipping\n", op_desc(out).c_str());
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
@ -749,11 +749,6 @@ struct test_case {
|
||||
printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
|
||||
fflush(stdout);
|
||||
|
||||
if (out->grad == nullptr) {
|
||||
printf("backwards pass not supported \n");
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
}
|
||||
if (out->type != GGML_TYPE_F32) {
|
||||
ggml_free(ctx);
|
||||
printf("not supported [%s->type != FP32]\n", out->name);
|
||||
@ -762,18 +757,26 @@ struct test_case {
|
||||
|
||||
// check if the backend supports the ops
|
||||
bool supported = true;
|
||||
bool any_params = false;
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (!ggml_backend_supports_op(backend, t)) {
|
||||
printf("not supported [%s] ", ggml_backend_name(backend));
|
||||
supported = false;
|
||||
break;
|
||||
}
|
||||
if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
|
||||
printf("not supported [%s->type != FP32] ", t->name);
|
||||
supported = false;
|
||||
break;
|
||||
if ((t->flags & GGML_TENSOR_FLAG_PARAM)) {
|
||||
any_params = true;
|
||||
if (t->type != GGML_TYPE_F32) {
|
||||
printf("not supported [%s->type != FP32] ", t->name);
|
||||
supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!any_params) {
|
||||
printf("not supported [%s] \n", op_name);
|
||||
supported = false;
|
||||
}
|
||||
if (!supported) {
|
||||
printf("\n");
|
||||
ggml_free(ctx);
|
||||
@ -801,7 +804,7 @@ struct test_case {
|
||||
|
||||
ggml_build_forward_expand(gf, out);
|
||||
ggml_graph_cpy(gf, gb);
|
||||
ggml_build_backward_expand(ctx, gf, gb, false, false);
|
||||
ggml_build_backward_expand(ctx, gf, gb, false);
|
||||
if (expect.size() != 1 || expect[0] != 0.0f) {
|
||||
GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
@ -984,7 +987,7 @@ struct test_example : public test_case {
|
||||
}
|
||||
// In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a)
|
||||
// immediately after you create the tensors.
|
||||
// This is optional and only makes sense if a backwards pass has actually been implemented for the new op.
|
||||
// This is optional and only makes sense if a backward pass has actually been implemented for the new op.
|
||||
};
|
||||
|
||||
|
||||
@ -1223,7 +1226,7 @@ struct test_set : public test_case {
|
||||
offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
|
||||
}
|
||||
ggml_tensor * out = ggml_set(ctx, dst, src,
|
||||
// The backwards pass requires setting a contiguous region:
|
||||
// The backward pass requires setting a contiguous region:
|
||||
src->nb[1], src->nb[2], src->nb[3], offset);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
@ -1335,7 +1338,7 @@ struct test_bin_bcast : public test_case {
|
||||
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_name(b, "b");
|
||||
|
||||
// The backwards pass supports broadcasting only for GGML_ADD:
|
||||
// The backward pass supports broadcasting only for GGML_ADD:
|
||||
const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
|
||||
if (grad_supported) {
|
||||
ggml_set_param(ctx, a);
|
||||
@ -1830,7 +1833,7 @@ struct test_log : public test_case {
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
// log(1) == 0, cluster values there to keep the sum low for better precision in the backwards pass:
|
||||
// log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass:
|
||||
init_tensor_uniform(t, 0.9f, 1.1f);
|
||||
}
|
||||
}
|
||||
@ -3257,7 +3260,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
|
||||
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
|
||||
|
||||
for (int ne3 : {1, 3}) { // CUDA backwards pass only supports ne3 == 1
|
||||
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1}));
|
||||
|
@ -240,12 +240,14 @@ static bool check_gradient(
|
||||
struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
|
||||
ggml_build_forward_expand(gf, f);
|
||||
ggml_graph_cpy(gf, gb);
|
||||
ggml_build_backward_expand(ctx0, gf, gb, false, false);
|
||||
ggml_build_backward_expand(ctx0, gf, gb, false);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
|
||||
|
||||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
ggml_graph_reset(gb);
|
||||
if (f->grad) {
|
||||
ggml_set_f32(f->grad, 1.0f);
|
||||
}
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
|
||||
|
||||
@ -298,8 +300,10 @@ static bool check_gradient(
|
||||
ggml_set_f32_1d(x[i], k, x0);
|
||||
|
||||
// compute gradient using backward graph
|
||||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
ggml_graph_reset(gb);
|
||||
if (f->grad) {
|
||||
ggml_set_f32(f->grad, 1.0f);
|
||||
}
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user