llama: propagating the results of graph_compute to the user interface

commit 95ce058c2b
parent 23e0d70bac
Author: Michael Podvitskiy
Date:   2024-09-17 21:43:01 +02:00

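With this change, llama_graph_compute returns the enum ggml_status reported by the backend scheduler instead of discarding it, and llama_decode_internal / llama_encode_internal translate that status into their integer return codes: 2 for aborted, -2 for a failed allocation, -3 for any other failure. A minimal caller-side sketch of consuming the new codes (not part of this diff; ctx and batch stand for an existing llama_context * and a prepared llama_batch):

    const int ret = llama_decode(ctx, batch);
    switch (ret) {
        case 0:     // success
            break;
        case 2:     // GGML_STATUS_ABORTED: the abort callback fired, non-fatal
            // e.g. stop generation gracefully
            break;
        case -2:    // GGML_STATUS_ALLOC_FAILED: compute buffer allocation failed
        case -3:    // GGML_STATUS_FAILED: backend reported a failure
        default:
            // treat as fatal for this context
            break;
    }
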
@@ -16533,7 +16533,7 @@ static void llama_output_reorder(struct llama_context * ctx) {
     }
 }
 
-static void llama_graph_compute(
+static enum ggml_status llama_graph_compute(
         llama_context & lctx,
         ggml_cgraph * gf,
         int n_threads,
@@ -16555,9 +16555,11 @@ static void llama_graph_compute(
     }
 #endif
 
-    ggml_backend_sched_graph_compute_async(lctx.sched, gf);
+    auto status = ggml_backend_sched_graph_compute_async(lctx.sched, gf);
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
+
+    return status;
 }
 
 // decode a batch of tokens by evaluating the transformer
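For reference, the values of enum ggml_status as defined in ggml.h, which the hunks below map onto integer return codes:

    enum ggml_status {
        GGML_STATUS_ALLOC_FAILED = -2,
        GGML_STATUS_FAILED       = -1,
        GGML_STATUS_SUCCESS      =  0,
        GGML_STATUS_ABORTED      =  1,
    };
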
@@ -16739,7 +16741,18 @@ static int llama_decode_internal(
     llama_set_inputs(lctx, ubatch);
 
-    llama_graph_compute(lctx, gf, n_threads, threadpool);
+    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+    switch (compute_status) {
+        case GGML_STATUS_SUCCESS:
+            break;
+        case GGML_STATUS_ABORTED:
+            return 2;
+        case GGML_STATUS_ALLOC_FAILED:
+            return -2;
+        case GGML_STATUS_FAILED:
+        default:
+            return -3;
+    }
 
     // update the kv ring buffer
     {
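The mapping keeps llama_decode's existing convention: negative values are fatal errors, positive values are non-fatal (an abort surfaces as 2, next to the existing 1 for a failed KV-slot search). A sketch of how the GGML_STATUS_ABORTED path can be reached from user code via the abort callback exposed in llama.h; it assumes an existing llama_context * ctx, and the flag/callback names are illustrative:

    #include <atomic>

    static std::atomic<bool> abort_flag{false};

    // ggml_abort_callback: returning true aborts the in-flight graph compute
    static bool should_abort(void * /* user_data */) {
        return abort_flag.load();
    }

    // register once on the context
    llama_set_abort_callback(ctx, should_abort, nullptr);

    // later, e.g. from another thread:
    abort_flag = true;  // the next llama_decode(...) now returns 2 instead of silently succeeding
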
@@ -16959,7 +16972,18 @@ static int llama_encode_internal(
     llama_set_inputs(lctx, ubatch);
 
-    llama_graph_compute(lctx, gf, n_threads, threadpool);
+    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+    switch (compute_status) {
+        case GGML_STATUS_SUCCESS:
+            break;
+        case GGML_STATUS_ABORTED:
+            return 2;
+        case GGML_STATUS_ALLOC_FAILED:
+            return -2;
+        case GGML_STATUS_FAILED:
+        default:
+            return -3;
+    }
 
     // extract embeddings
     if (embd) {