backend : add eval callback (#4935)

* backend : add eval callback

ggml-ci

* backend : group nodes in a single compute when user don't need them

* backend : clean-up the implementation

ggml-ci

* simple : do not perform tensor data copy if not needed

* simple : fix

* simple : no need for ggml_is_contiguous + fix bool parse

* llama : fix callback placement in llama_context_params

* backend : avoid double-ask callback calls

* simple : restore examples, imatrix will serve as a demo
This commit is contained in:
Georgi Gerganov 2024-01-17 18:39:41 +02:00 committed by GitHub
parent c918fe8dca
commit 44a1a4a41a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 2 deletions

View File

@ -802,6 +802,9 @@ struct ggml_backend_sched {
__attribute__((aligned(GGML_MEM_ALIGN))) __attribute__((aligned(GGML_MEM_ALIGN)))
#endif #endif
char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)]; char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)];
ggml_backend_sched_eval_callback callback_eval;
void * callback_eval_user_data;
}; };
#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node) #define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
@ -1324,9 +1327,38 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
ggml_graph_dump_dot(split->graph, NULL, split_filename); ggml_graph_dump_dot(split->graph, NULL, split_filename);
#endif #endif
uint64_t compute_start_us = ggml_time_us(); uint64_t compute_start_us = ggml_time_us();
if (!sched->callback_eval) {
ggml_backend_graph_compute(split_backend, &split->graph); ggml_backend_graph_compute(split_backend, &split->graph);
//ggml_backend_synchronize(split_backend); // necessary to measure compute time //ggml_backend_synchronize(split_backend); // necessary to measure compute time
} else {
// similar to ggml_backend_compare_graph_backend
for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
struct ggml_tensor * t = split->graph.nodes[j0];
// check if the user needs data from this node
bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
int j1 = j0;
// determine the range [j0, j1] of nodes that can be computed together
while (!need && j1 < split->graph.n_nodes - 1) {
t = split->graph.nodes[++j1];
need = sched->callback_eval(t, true, sched->callback_eval_user_data);
}
struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
ggml_backend_graph_compute(split_backend, &gv);
if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
break;
}
j0 = j1;
}
}
uint64_t compute_end_us = ggml_time_us(); uint64_t compute_end_us = ggml_time_us();
compute_us[split_backend_id] += compute_end_us - compute_start_us; compute_us[split_backend_id] += compute_end_us - compute_start_us;
} }
@ -1431,6 +1463,12 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
sched_reset(sched); sched_reset(sched);
} }
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
sched->callback_eval = callback;
sched->callback_eval_user_data = user_data;
}
int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) { int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
return sched->n_splits; return sched->n_splits;
} }

View File

@ -148,6 +148,14 @@ extern "C" {
struct ggml_backend_sched; struct ggml_backend_sched;
typedef struct ggml_backend_sched * ggml_backend_sched_t; typedef struct ggml_backend_sched * ggml_backend_sched_t;
// when ask == true, the scheduler wants to know if the user wants to observe this node
// this allows the scheduler to batch nodes together in order to evaluate them in a single call
//
// when ask == false, the scheduler is passing the node tensor to the user for observation
// if the user returns false, the scheduler will cancel the graph compute
//
typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
// Initialize a backend scheduler // Initialize a backend scheduler
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size); GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
@ -168,6 +176,9 @@ extern "C" {
// Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs // Reset all assignments and allocators - must be called before using the sched allocators to allocate inputs
GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);
// Set a callback to be called for each resulting node during graph compute
GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
// //
// Utils // Utils
// //

View File

@ -1393,6 +1393,9 @@ struct llama_cparams {
bool mul_mat_q; bool mul_mat_q;
bool offload_kqv; bool offload_kqv;
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
}; };
struct llama_layer { struct llama_layer {
@ -6254,6 +6257,7 @@ static int llama_decode_internal(
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
ggml_cgraph * gf = llama_build_graph(lctx, batch); ggml_cgraph * gf = llama_build_graph(lctx, batch);
@ -9276,6 +9280,8 @@ struct llama_context_params llama_context_default_params() {
/*.yarn_beta_fast =*/ 32.0f, /*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f, /*.yarn_beta_slow =*/ 1.0f,
/*.yarn_orig_ctx =*/ 0, /*.yarn_orig_ctx =*/ 0,
/*.cb_eval =*/ nullptr,
/*.cb_eval_user_data =*/ nullptr,
/*.type_k =*/ GGML_TYPE_F16, /*.type_k =*/ GGML_TYPE_F16,
/*.type_v =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16,
/*.mul_mat_q =*/ true, /*.mul_mat_q =*/ true,
@ -9416,6 +9422,9 @@ struct llama_context * llama_new_context_with_model(
hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
hparams.n_ctx_train; hparams.n_ctx_train;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
auto rope_scaling_type = params.rope_scaling_type; auto rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) { if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
rope_scaling_type = hparams.rope_scaling_type_train; rope_scaling_type = hparams.rope_scaling_type_train;

View File

@ -2,6 +2,7 @@
#define LLAMA_H #define LLAMA_H
#include "ggml.h" #include "ggml.h"
#include "ggml-backend.h"
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h" #include "ggml-cuda.h"
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES #define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
@ -231,6 +232,9 @@ extern "C" {
float yarn_beta_slow; // YaRN high correction dim float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size uint32_t yarn_orig_ctx; // YaRN original context size
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
enum ggml_type type_k; // data type for K cache enum ggml_type type_k; // data type for K cache
enum ggml_type type_v; // data type for V cache enum ggml_type type_v; // data type for V cache