mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-14 06:49:54 +00:00
Merge 8ceda95327
into 61408e7fad
This commit is contained in:
commit
8378989ad1
@ -541,7 +541,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
|
|||||||
<< ":pos " << std::to_string(batch.pos[i])
|
<< ":pos " << std::to_string(batch.pos[i])
|
||||||
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
|
<< ":n_seq_id " << std::to_string(batch.n_seq_id[i])
|
||||||
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
|
<< ":seq_id " << std::to_string(batch.seq_id[i][0])
|
||||||
<< ":logits " << std::to_string(batch.logits[i]);
|
<< ":output " << std::to_string(batch.output[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
buf << " ]";
|
buf << " ]";
|
||||||
@ -1467,7 +1467,7 @@ void common_batch_add(
|
|||||||
llama_token id,
|
llama_token id,
|
||||||
llama_pos pos,
|
llama_pos pos,
|
||||||
const std::vector<llama_seq_id> & seq_ids,
|
const std::vector<llama_seq_id> & seq_ids,
|
||||||
bool logits) {
|
bool output) {
|
||||||
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
||||||
|
|
||||||
batch.token [batch.n_tokens] = id;
|
batch.token [batch.n_tokens] = id;
|
||||||
@ -1476,7 +1476,7 @@ void common_batch_add(
|
|||||||
for (size_t i = 0; i < seq_ids.size(); ++i) {
|
for (size_t i = 0; i < seq_ids.size(); ++i) {
|
||||||
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
||||||
}
|
}
|
||||||
batch.logits [batch.n_tokens] = logits;
|
batch.output [batch.n_tokens] = output;
|
||||||
|
|
||||||
batch.n_tokens++;
|
batch.n_tokens++;
|
||||||
}
|
}
|
||||||
|
@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
|
|||||||
batch.pos + i,
|
batch.pos + i,
|
||||||
batch.n_seq_id + i,
|
batch.n_seq_id + i,
|
||||||
batch.seq_id + i,
|
batch.seq_id + i,
|
||||||
batch.logits + i,
|
batch.output + i,
|
||||||
};
|
};
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
|
|||||||
common_batch_add(batch, 0, i, { j }, false);
|
common_batch_add(batch, 0, i, { j }, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.output[batch.n_tokens - 1] = true;
|
||||||
|
|
||||||
const auto t_pp_start = ggml_time_us();
|
const auto t_pp_start = ggml_time_us();
|
||||||
|
|
||||||
|
@ -99,11 +99,11 @@ for (i, token) in tokens.enumerated() {
|
|||||||
if let seq_id = batch.seq_id[i] {
|
if let seq_id = batch.seq_id[i] {
|
||||||
seq_id[0] = 0
|
seq_id[0] = 0
|
||||||
}
|
}
|
||||||
batch.logits[i] = 0
|
batch.output[i] = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// llama_decode will output logits only for the last token of the prompt
|
// llama_decode will output logits only for the last token of the prompt
|
||||||
batch.logits[Int(batch.n_tokens) - 1] = 1
|
batch.output[Int(batch.n_tokens) - 1] = 1
|
||||||
|
|
||||||
if llama_decode(context, batch) != 0 {
|
if llama_decode(context, batch) != 0 {
|
||||||
print("llama_decode() failed")
|
print("llama_decode() failed")
|
||||||
@ -166,7 +166,7 @@ while n_cur <= n_len {
|
|||||||
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
|
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
|
||||||
seq_id[0] = Int32(i)
|
seq_id[0] = Int32(i)
|
||||||
}
|
}
|
||||||
batch.logits[Int(batch.n_tokens)] = 1
|
batch.output[Int(batch.n_tokens)] = 1
|
||||||
|
|
||||||
i_batch[i] = batch.n_tokens
|
i_batch[i] = batch.n_tokens
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// llama_decode will output logits only for the last token of the prompt
|
// llama_decode will output logits only for the last token of the prompt
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.output[batch.n_tokens - 1] = true;
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0) {
|
if (llama_decode(ctx, batch) != 0) {
|
||||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||||
|
@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; i++) {
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
if (!batch.logits[i]) {
|
if (!batch.output[i]) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|||||||
common_batch_add(*batch, 0, i, { 0 }, false);
|
common_batch_add(*batch, 0, i, { 0 }, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
batch->logits[batch->n_tokens - 1] = true;
|
batch->output[batch->n_tokens - 1] = true;
|
||||||
llama_kv_cache_clear(context);
|
llama_kv_cache_clear(context);
|
||||||
|
|
||||||
const auto t_pp_start = ggml_time_us();
|
const auto t_pp_start = ggml_time_us();
|
||||||
@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
|
|||||||
for (int i = 0; i < n_tokens; ++i) {
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
||||||
}
|
}
|
||||||
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
||||||
|
|
||||||
return reinterpret_cast<jlong>(batch);
|
return reinterpret_cast<jlong>(batch);
|
||||||
}
|
}
|
||||||
@ -377,7 +377,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// llama_decode will output logits only for the last token of the prompt
|
// llama_decode will output logits only for the last token of the prompt
|
||||||
batch->logits[batch->n_tokens - 1] = true;
|
batch->output[batch->n_tokens - 1] = true;
|
||||||
|
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode(context, *batch) != 0) {
|
||||||
LOGe("llama_decode() failed");
|
LOGe("llama_decode() failed");
|
||||||
|
@ -16,7 +16,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
|
|||||||
for i in 0..<seq_ids.count {
|
for i in 0..<seq_ids.count {
|
||||||
batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
|
batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
|
||||||
}
|
}
|
||||||
batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
|
batch.output [Int(batch.n_tokens)] = logits ? 1 : 0
|
||||||
|
|
||||||
batch.n_tokens += 1
|
batch.n_tokens += 1
|
||||||
}
|
}
|
||||||
@ -137,7 +137,7 @@ actor LlamaContext {
|
|||||||
let i = Int(i1)
|
let i = Int(i1)
|
||||||
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
|
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
|
||||||
}
|
}
|
||||||
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
|
batch.output[Int(batch.n_tokens) - 1] = 1 // true
|
||||||
|
|
||||||
if llama_decode(context, batch) != 0 {
|
if llama_decode(context, batch) != 0 {
|
||||||
print("llama_decode() failed")
|
print("llama_decode() failed")
|
||||||
@ -206,7 +206,7 @@ actor LlamaContext {
|
|||||||
for i in 0..<n_tokens {
|
for i in 0..<n_tokens {
|
||||||
llama_batch_add(&batch, 0, Int32(i), [0], false)
|
llama_batch_add(&batch, 0, Int32(i), [0], false)
|
||||||
}
|
}
|
||||||
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
|
batch.output[Int(batch.n_tokens) - 1] = 1 // true
|
||||||
|
|
||||||
llama_kv_cache_clear(context)
|
llama_kv_cache_clear(context)
|
||||||
|
|
||||||
|
@ -406,13 +406,13 @@ struct llava_embd_batch {
|
|||||||
std::vector<int32_t> n_seq_id;
|
std::vector<int32_t> n_seq_id;
|
||||||
std::vector<llama_seq_id> seq_id_0;
|
std::vector<llama_seq_id> seq_id_0;
|
||||||
std::vector<llama_seq_id *> seq_ids;
|
std::vector<llama_seq_id *> seq_ids;
|
||||||
std::vector<int8_t> logits;
|
std::vector<int8_t> outputs;
|
||||||
llama_batch batch;
|
llama_batch batch;
|
||||||
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||||
pos .resize(n_tokens);
|
pos .resize(n_tokens);
|
||||||
n_seq_id.resize(n_tokens);
|
n_seq_id.resize(n_tokens);
|
||||||
seq_ids .resize(n_tokens + 1);
|
seq_ids .resize(n_tokens + 1);
|
||||||
logits .resize(n_tokens);
|
outputs .resize(n_tokens);
|
||||||
seq_id_0.resize(1);
|
seq_id_0.resize(1);
|
||||||
seq_id_0[0] = seq_id;
|
seq_id_0[0] = seq_id;
|
||||||
seq_ids [n_tokens] = nullptr;
|
seq_ids [n_tokens] = nullptr;
|
||||||
@ -423,13 +423,13 @@ struct llava_embd_batch {
|
|||||||
/*pos =*/ pos.data(),
|
/*pos =*/ pos.data(),
|
||||||
/*n_seq_id =*/ n_seq_id.data(),
|
/*n_seq_id =*/ n_seq_id.data(),
|
||||||
/*seq_id =*/ seq_ids.data(),
|
/*seq_id =*/ seq_ids.data(),
|
||||||
/*logits =*/ logits.data(),
|
/*output =*/ outputs.data(),
|
||||||
};
|
};
|
||||||
for (int i = 0; i < n_tokens; i++) {
|
for (int i = 0; i < n_tokens; i++) {
|
||||||
batch.pos [i] = pos_0 + i;
|
batch.pos [i] = pos_0 + i;
|
||||||
batch.n_seq_id[i] = 1;
|
batch.n_seq_id[i] = 1;
|
||||||
batch.seq_id [i] = seq_id_0.data();
|
batch.seq_id [i] = seq_id_0.data();
|
||||||
batch.logits [i] = false;
|
batch.output [i] = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -264,7 +264,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// extract the logits only for the last token
|
// extract the logits only for the last token
|
||||||
if (batch.n_tokens > 0) {
|
if (batch.n_tokens > 0) {
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.output[batch.n_tokens - 1] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
client.n_prompt = tokens_prompt.size();
|
client.n_prompt = tokens_prompt.size();
|
||||||
@ -307,7 +307,7 @@ int main(int argc, char ** argv) {
|
|||||||
batch.pos + i,
|
batch.pos + i,
|
||||||
batch.n_seq_id + i,
|
batch.n_seq_id + i,
|
||||||
batch.seq_id + i,
|
batch.seq_id + i,
|
||||||
batch.logits + i,
|
batch.output + i,
|
||||||
};
|
};
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
|
@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (i + n_batch >= n_tokens_all) {
|
if (i + n_batch >= n_tokens_all) {
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.output[batch.n_tokens - 1] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0) {
|
if (llama_decode(ctx, batch) != 0) {
|
||||||
@ -178,7 +178,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (i + n_batch >= n_tokens_all) {
|
if (i + n_batch >= n_tokens_all) {
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.output[batch.n_tokens - 1] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0) {
|
if (llama_decode(ctx, batch) != 0) {
|
||||||
|
@ -615,9 +615,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
|||||||
batch.pos [idx] = j*n_batch + k;
|
batch.pos [idx] = j*n_batch + k;
|
||||||
batch.n_seq_id[idx] = 1;
|
batch.n_seq_id[idx] = 1;
|
||||||
batch.seq_id [idx][0] = seq;
|
batch.seq_id [idx][0] = seq;
|
||||||
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
|
batch.output [idx] = batch.pos[idx] >= first ? 1 : 0;
|
||||||
|
|
||||||
n_outputs += batch.logits[idx] != 0;
|
n_outputs += batch.output[idx] != 0;
|
||||||
}
|
}
|
||||||
batch.n_tokens += batch_size;
|
batch.n_tokens += batch_size;
|
||||||
|
|
||||||
@ -712,7 +712,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
|
|||||||
batch.pos + i,
|
batch.pos + i,
|
||||||
batch.n_seq_id + i,
|
batch.n_seq_id + i,
|
||||||
batch.seq_id + i,
|
batch.seq_id + i,
|
||||||
batch.logits + i,
|
batch.output + i,
|
||||||
};
|
};
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
@ -723,7 +723,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
|
|||||||
|
|
||||||
int n_outputs = 0;
|
int n_outputs = 0;
|
||||||
for (int i = 0; i < n_tokens; ++i) {
|
for (int i = 0; i < n_tokens; ++i) {
|
||||||
n_outputs += batch_view.logits[i] != 0;
|
n_outputs += batch_view.output[i] != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
|
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
|
||||||
@ -936,7 +936,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
|||||||
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
|
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
|
||||||
common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
|
common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
||||||
n_logits += 1;
|
n_logits += 1;
|
||||||
|
|
||||||
for (int s = 0; s < 4; ++s) {
|
for (int s = 0; s < 4; ++s) {
|
||||||
@ -1215,7 +1215,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
|
|||||||
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
|
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
|
||||||
common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
|
common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.output[batch.n_tokens - 1] = true;
|
||||||
n_logits += 1;
|
n_logits += 1;
|
||||||
|
|
||||||
for (int s = 0; s < 2; ++s) {
|
for (int s = 0; s < 2; ++s) {
|
||||||
@ -1581,7 +1581,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
|||||||
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
|
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
|
||||||
common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
|
common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
||||||
n_logits += 1;
|
n_logits += 1;
|
||||||
|
|
||||||
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
|
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
|
||||||
|
@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; i++) {
|
for (int i = 0; i < batch.n_tokens; i++) {
|
||||||
if (!batch.logits[i]) {
|
if (!batch.output[i]) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
|
|||||||
for (size_t i = 0; i < tokens.size(); i++) {
|
for (size_t i = 0; i < tokens.size(); i++) {
|
||||||
common_batch_add(batch, tokens[i], i, {0}, false);
|
common_batch_add(batch, tokens[i], i, {0}, false);
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true; // generate next token
|
batch.output[batch.n_tokens - 1] = true; // generate next token
|
||||||
|
|
||||||
// evaluate prompt
|
// evaluate prompt
|
||||||
llama_decode(ctx, batch);
|
llama_decode(ctx, batch);
|
||||||
|
@ -1295,7 +1295,7 @@ struct server_context {
|
|||||||
std::vector<float> embd_res(n_embd, 0.0f);
|
std::vector<float> embd_res(n_embd, 0.0f);
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
|
if (!batch.output[i] || batch.seq_id[i][0] != slot.id + 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1335,7 +1335,7 @@ struct server_context {
|
|||||||
res.stop = true;
|
res.stop = true;
|
||||||
|
|
||||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
|
if (!batch.output[i] || batch.seq_id[i][0] != slot.id + 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2083,7 +2083,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// extract the logits only for the last token
|
// extract the logits only for the last token
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.output[batch.n_tokens - 1] = true;
|
||||||
|
|
||||||
slot.n_decoded = 0;
|
slot.n_decoded = 0;
|
||||||
slot.i_batch = batch.n_tokens - 1;
|
slot.i_batch = batch.n_tokens - 1;
|
||||||
@ -2119,7 +2119,7 @@ struct server_context {
|
|||||||
batch.pos + i,
|
batch.pos + i,
|
||||||
batch.n_seq_id + i,
|
batch.n_seq_id + i,
|
||||||
batch.seq_id + i,
|
batch.seq_id + i,
|
||||||
batch.logits + i,
|
batch.output + i,
|
||||||
};
|
};
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch_view);
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
|
@ -247,7 +247,7 @@ extern "C" {
|
|||||||
llama_pos * pos;
|
llama_pos * pos;
|
||||||
int32_t * n_seq_id;
|
int32_t * n_seq_id;
|
||||||
llama_seq_id ** seq_id;
|
llama_seq_id ** seq_id;
|
||||||
int8_t * logits; // TODO: rename this to "output"
|
int8_t * output;
|
||||||
} llama_batch;
|
} llama_batch;
|
||||||
|
|
||||||
enum llama_model_kv_override_type {
|
enum llama_model_kv_override_type {
|
||||||
|
@ -3072,17 +3072,17 @@ struct llama_sbatch {
|
|||||||
ubatch.output[ubatch.n_tokens + i] = 1;
|
ubatch.output[ubatch.n_tokens + i] = 1;
|
||||||
out_ids.push_back(ids[seq.offset + i]);
|
out_ids.push_back(ids[seq.offset + i]);
|
||||||
}
|
}
|
||||||
} else if (batch->logits) {
|
} else if (batch->output) {
|
||||||
if (ubatch.equal_seqs) {
|
if (ubatch.equal_seqs) {
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
size_t id = ids[seq.offset + i];
|
size_t id = ids[seq.offset + i];
|
||||||
int8_t is_output = batch->logits[id];
|
int8_t is_output = batch->output[id];
|
||||||
ubatch.output[ubatch.n_tokens + i] = is_output;
|
ubatch.output[ubatch.n_tokens + i] = is_output;
|
||||||
if (is_output) { out_ids.push_back(id); }
|
if (is_output) { out_ids.push_back(id); }
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// simple split
|
// simple split
|
||||||
ubatch.output = batch->logits + seq.offset;
|
ubatch.output = batch->output + seq.offset;
|
||||||
for (size_t i = 0; i < length; ++i) {
|
for (size_t i = 0; i < length; ++i) {
|
||||||
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
|
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
|
||||||
}
|
}
|
||||||
@ -5139,7 +5139,7 @@ struct llama_batch_allocr {
|
|||||||
std::vector<llama_pos> pos;
|
std::vector<llama_pos> pos;
|
||||||
std::vector<int32_t> n_seq_id;
|
std::vector<int32_t> n_seq_id;
|
||||||
std::vector<llama_seq_id *> seq_id;
|
std::vector<llama_seq_id *> seq_id;
|
||||||
std::vector<int8_t> logits;
|
std::vector<int8_t> outputs;
|
||||||
struct llama_batch batch;
|
struct llama_batch batch;
|
||||||
// optionally fulfill the batch returned by llama_batch_get_one
|
// optionally fulfill the batch returned by llama_batch_get_one
|
||||||
llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
|
llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
|
||||||
@ -5175,10 +5175,10 @@ struct llama_batch_allocr {
|
|||||||
}
|
}
|
||||||
batch.seq_id = seq_id.data();
|
batch.seq_id = seq_id.data();
|
||||||
}
|
}
|
||||||
if (!batch.logits) {
|
if (!batch.output) {
|
||||||
logits.resize(batch.n_tokens);
|
outputs.resize(batch.n_tokens);
|
||||||
logits[logits.size() - 1] = true;
|
outputs[outputs.size() - 1] = true;
|
||||||
batch.logits = logits.data();
|
batch.output = outputs.data();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -17307,9 +17307,9 @@ static int llama_decode_internal(
|
|||||||
lctx.embd_seq.clear();
|
lctx.embd_seq.clear();
|
||||||
|
|
||||||
// count outputs
|
// count outputs
|
||||||
if (batch.logits && !embd_pooled) {
|
if (batch.output && !embd_pooled) {
|
||||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||||
n_outputs += batch.logits[i] != 0;
|
n_outputs += batch.output[i] != 0;
|
||||||
}
|
}
|
||||||
} else if (lctx.logits_all || embd_pooled) {
|
} else if (lctx.logits_all || embd_pooled) {
|
||||||
n_outputs = n_tokens_all;
|
n_outputs = n_tokens_all;
|
||||||
@ -21234,7 +21234,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
|
|||||||
}
|
}
|
||||||
batch.seq_id[n_tokens_alloc] = nullptr;
|
batch.seq_id[n_tokens_alloc] = nullptr;
|
||||||
|
|
||||||
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
|
batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
|
||||||
|
|
||||||
return batch;
|
return batch;
|
||||||
}
|
}
|
||||||
@ -21250,7 +21250,7 @@ void llama_batch_free(struct llama_batch batch) {
|
|||||||
}
|
}
|
||||||
free(batch.seq_id);
|
free(batch.seq_id);
|
||||||
}
|
}
|
||||||
if (batch.logits) free(batch.logits);
|
if (batch.output) free(batch.output);
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_encode(
|
int32_t llama_encode(
|
||||||
|
Loading…
Reference in New Issue
Block a user