llama : fix input allocation logic

Author: Georgi Gerganov
Date:   2023-10-31 08:23:43 +02:00
Parent: a3f80013ad
Commit: 2926ef63b1


@@ -4970,10 +4970,10 @@ static struct ggml_cgraph * llama_build_graph(
         // allocate input tensors and set input data
         //
-        if (batch.token && !alloc_inp_tokens && strcmp(name, "inp_tokens") == 0) {
+        if (!alloc_inp_tokens && strcmp(name, "inp_tokens") == 0) {
             ggml_allocr_alloc(lctx.alloc, cur);

-            if (!ggml_allocr_is_measure(lctx.alloc)) {
+            if (!ggml_allocr_is_measure(lctx.alloc) && batch.token) {
                 const int64_t n_tokens = cur->ne[0];

                 memcpy(cur->data, batch.token, n_tokens*ggml_element_size(cur));
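
The change is the same in all three hunks: the batch.token / batch.embd / batch.pos check moves out of the outer condition, which gates allocation of the input tensor, and into the inner condition, which gates only the data copy. With the old ordering, a batch that does not carry a given input (for example, no token data) skipped allocation entirely, leaving a graph node with no buffer behind it. A toy, self-contained sketch of the two orderings (plain calloc/memcpy stand-ins, not the ggml API):

    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    struct toy_tensor { int64_t n = 4; int32_t * data = nullptr; };

    // old ordering: allocation itself is gated on input presence
    void alloc_old(toy_tensor & t, const int32_t * tokens) {
        if (tokens) { // no tokens -> tensor never allocated, graph still references it
            t.data = (int32_t *) calloc(t.n, sizeof(int32_t));
            memcpy(t.data, tokens, t.n * sizeof(int32_t));
        }
    }

    // new ordering: always allocate; gate only the copy on input presence
    void alloc_new(toy_tensor & t, const int32_t * tokens) {
        t.data = (int32_t *) calloc(t.n, sizeof(int32_t));
        if (tokens) {
            memcpy(t.data, tokens, t.n * sizeof(int32_t));
        }
    }

    int main() {
        toy_tensor a, b;
        alloc_old(a, nullptr); // e.g. an embeddings batch: batch.token == NULL
        alloc_new(b, nullptr);
        printf("old ordering: data = %p\n", (void *) a.data); // (nil)
        printf("new ordering: data = %p\n", (void *) b.data); // valid buffer
        free(a.data);
        free(b.data);
        return 0;
    }
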
@@ -4982,10 +4982,10 @@ static struct ggml_cgraph * llama_build_graph(
             alloc_inp_tokens = true;
         }

-        if (batch.embd && !alloc_inp_embd && strcmp(name, "inp_embd") == 0) {
+        if (!alloc_inp_embd && strcmp(name, "inp_embd") == 0) {
             ggml_allocr_alloc(lctx.alloc, cur);

-            if (!ggml_allocr_is_measure(lctx.alloc)) {
+            if (!ggml_allocr_is_measure(lctx.alloc) && batch.embd) {
                 const int64_t n_embd   = cur->ne[0];
                 const int64_t n_tokens = cur->ne[1];
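
For inp_embd the reordering is identical; the only difference is that the tensor is two-dimensional (n_embd x n_tokens), so the copy that the hunk cuts off just before must cover n_embd*n_tokens elements. A reconstruction of the post-fix block, with the elided copy filled in on that assumption:

    // post-fix inp_embd block, reconstructed from the hunk above;
    // the final memcpy is not visible in the diff and is an assumed completion
    if (!alloc_inp_embd && strcmp(name, "inp_embd") == 0) {
        ggml_allocr_alloc(lctx.alloc, cur);

        if (!ggml_allocr_is_measure(lctx.alloc) && batch.embd) {
            const int64_t n_embd   = cur->ne[0];
            const int64_t n_tokens = cur->ne[1];

            // assumed: embeddings are contiguous floats for the whole batch
            memcpy(cur->data, batch.embd, n_embd*n_tokens*ggml_element_size(cur));
        }

        alloc_inp_embd = true;
    }
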
@@ -4995,10 +4995,10 @@ static struct ggml_cgraph * llama_build_graph(
             alloc_inp_embd = true;
         }

-        if (batch.pos && !alloc_inp_pos && strcmp(name, "inp_pos") == 0) {
+        if (!alloc_inp_pos && strcmp(name, "inp_pos") == 0) {
             ggml_allocr_alloc(lctx.alloc, cur);

-            if (!ggml_allocr_is_measure(lctx.alloc)) {
+            if (!ggml_allocr_is_measure(lctx.alloc) && batch.pos) {
                 const int64_t n_tokens = cur->ne[0];

                 int32_t * data = (int32_t *) cur->data;
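
The !ggml_allocr_is_measure(lctx.alloc) guard kept in every hunk reflects how llama.cpp drives ggml-alloc: the graph builder runs once at context setup under a measure allocator, which only sizes the compute buffer and hands out placeholder addresses, and again at each decode under the real allocator, where copying input data is actually valid. A condensed sketch of that two-pass flow (ggml-alloc API names from this period; details such as the worst-case dummy batch and allocator resets are simplified):

    // 1) context setup: a measure allocator sizes the compute buffer; the data
    //    pointers it hands out are placeholders, so inputs must not be written
    lctx.alloc = ggml_allocr_new_measure(32 /* tensor alignment */);
    ggml_cgraph * gf = llama_build_graph(lctx, batch);     // is_measure == true
    const size_t mem = ggml_allocr_alloc_graph(lctx.alloc, gf);
    ggml_allocr_free(lctx.alloc);

    // 2) per decode: a real allocator over a real buffer; the same build code
    //    now passes the is_measure check and performs the input copies
    std::vector<uint8_t> buf(mem);
    lctx.alloc = ggml_allocr_new(buf.data(), buf.size(), 32);
    gf = llama_build_graph(lctx, batch);                   // is_measure == false
    ggml_allocr_alloc_graph(lctx.alloc, gf);
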