mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
tests : sync test-grad0 from ggml
This commit is contained in:
parent
fdd1860911
commit
65bdd52a86
@ -1,3 +1,4 @@
|
|||||||
|
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
@ -5,6 +6,10 @@
|
|||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
|
#if defined(_MSC_VER)
|
||||||
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
|
#endif
|
||||||
|
|
||||||
#define MAX_NARGS 3
|
#define MAX_NARGS 3
|
||||||
|
|
||||||
#undef MIN
|
#undef MIN
|
||||||
@ -197,8 +202,23 @@ bool check_gradient(
|
|||||||
float max_error_abs,
|
float max_error_abs,
|
||||||
float max_error_rel) {
|
float max_error_rel) {
|
||||||
|
|
||||||
|
static int n_threads = -1;
|
||||||
|
if (n_threads < 0) {
|
||||||
|
n_threads = GGML_DEFAULT_N_THREADS;
|
||||||
|
|
||||||
|
const char *env = getenv("GGML_N_THREADS");
|
||||||
|
if (env) {
|
||||||
|
n_threads = atoi(env);
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("GGML_N_THREADS = %d\n", n_threads);
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_cgraph gf = ggml_build_forward (f);
|
struct ggml_cgraph gf = ggml_build_forward (f);
|
||||||
|
gf.n_threads = n_threads;
|
||||||
|
|
||||||
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
|
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
|
||||||
|
gb.n_threads = n_threads;
|
||||||
|
|
||||||
ggml_graph_compute(ctx0, &gf);
|
ggml_graph_compute(ctx0, &gf);
|
||||||
ggml_graph_reset (&gf);
|
ggml_graph_reset (&gf);
|
||||||
|
Loading…
Reference in New Issue
Block a user