perplexity : add support for batch size to --perplexity (#407)
* Add support to batch size for perplexity
* Revert "Fix memory allocation issues and seg faults"
This reverts commit 4870e455b3.
* update from merge
* Remove perplexity from main
* updates
* Update batch size for efficiency
This commit is contained in:
parent 0e07e6a839
commit be87b6ed20
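The core of the change, visible in the first hunk below, is that each context-window chunk is now evaluated in params.n_batch-sized pieces instead of a single 511-token call. The following is a minimal standalone sketch of that batch-splitting arithmetic; the sizes and the printout are illustrative only and are not part of the patch.

#include <algorithm>
#include <cstdio>

int main() {
    // Illustrative values: a 2048-token context evaluated in 512-token batches.
    const int n_ctx   = 2048;
    const int n_batch = 512;

    const int start = 0;              // first token of the chunk
    const int end   = start + n_ctx;  // one past the last token

    // Round up so a partial final batch is still evaluated.
    const int num_batches = (n_ctx + n_batch - 1) / n_batch;

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = start + j * n_batch;
        const int batch_size  = std::min(end - batch_start, n_batch);
        // The patch calls llama_eval() here with n_past = j * n_batch,
        // so each batch continues from the tokens already evaluated.
        printf("batch %d: tokens [%d, %d), n_past = %d\n",
               j, batch_start, batch_start + batch_size, j * n_batch);
    }
    return 0;
}

With n_ctx = 2048 and n_batch = 512 this yields four evaluation calls with n_past advancing by 512 each time, so the logits for the whole window are still produced exactly once.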
@@ -27,20 +27,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     int count = 0;
     int seq_count = tokens.size() / params.n_ctx;
+    int n_vocab = llama_n_vocab(ctx);
 
     double nll = 0.0;
-
-    fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
+    fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
 
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
-        int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
-                                            // it is better to always be power of 2 for better performance
-        std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
+        int end = start + params.n_ctx;
+
+        std::vector<float> logits;
+        int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
         auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return;
+        for (int j = 0; j < num_batches; ++j) {
+            int batch_start = start + j * params.n_batch;
+            int batch_size = std::min(end - batch_start, params.n_batch);
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+            auto batch_logits = llama_get_logits(ctx);
+            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
         }
         auto end_t = std::chrono::high_resolution_clock::now();
         if (i == 0) {
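Because every batch's logits are appended to a single flat std::vector<float>, the chunk ends up with n_ctx * n_vocab floats in row-major order, and the next hunk can slice out position j with logits.begin() + j * n_vocab. Below is a small sketch of that indexing, with made-up sizes standing in for llama_n_vocab() and params.n_ctx.

#include <cstdio>
#include <vector>

int main() {
    // Illustrative sizes; the real values come from llama_n_vocab() and params.n_ctx.
    const int n_vocab = 8;
    const int n_ctx   = 4;

    // Flat row-major buffer: logits[j * n_vocab + k] is the logit of
    // vocabulary entry k after seeing tokens 0..j of the chunk.
    std::vector<float> logits(n_ctx * n_vocab, 0.0f);

    const int j = 2; // position whose next-token distribution we want
    std::vector<float> tok_logits(logits.begin() + j * n_vocab,
                                  logits.begin() + (j + 1) * n_vocab);

    printf("row %d holds %zu logits\n", j, tok_logits.size());
    return 0;
}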
@@ -59,15 +66,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // Example, we have a context window of 512, we will compute perplexity for each of the
         // last 256 tokens.  Then, we split the input up into context window size chunks to
         // process the entire prompt.
-
-        auto logits = llama_get_logits(ctx);
-        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+        for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
             // Calculate probability of next token, given the previous ones.
-            int n_vocab = llama_n_vocab(ctx);
             std::vector<float> tok_logits(
-                logits + j * n_vocab,
-                logits + (j + 1) * n_vocab);
-            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            float prob = softmax(tok_logits)[tokens[start + j + 1]];
             nll += -std::log(prob);
             ++count;
         }
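For reference, the value accumulated here is the negative log-likelihood of each reference token under the softmax of its logit row, and the number finally reported is exp(nll / count). The sketch below is a self-contained version of that computation; the softmax helper is an illustrative rewrite, not the one defined elsewhere in this file, and the toy logits are made up.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one row of logits (illustrative rewrite).
static std::vector<float> softmax(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    float max_logit = logits[0];
    for (float v : logits) max_logit = std::max(max_logit, v);
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    for (float & p : probs) p /= sum;
    return probs;
}

int main() {
    // Toy data: two positions over a 4-entry vocabulary, plus the
    // reference next token for each position.
    std::vector<std::vector<float>> rows = {{2.0f, 0.5f, 0.1f, -1.0f},
                                            {0.0f, 1.5f, 0.2f,  0.3f}};
    std::vector<int> next_token = {0, 1};

    double nll = 0.0;
    int count = 0;
    for (size_t j = 0; j < rows.size(); ++j) {
        const float prob = softmax(rows[j])[next_token[j]];
        nll += -std::log(prob);
        ++count;
    }
    printf("perplexity = %.4f\n", std::exp(nll / count));
    return 0;
}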
@@ -82,11 +86,13 @@ int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
 
+    params.n_batch = 512;
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
 
     params.perplexity = true;
+    params.n_batch = std::min(params.n_batch, params.n_ctx);
 
     if (params.n_ctx > 2048) {
         fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
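The two lines added in main() give perplexity runs a 512-token batch by default and clamp it so a batch can never be longer than the context window, which keeps every evaluation position inside n_ctx. A tiny sketch of the effect, with made-up context sizes:

#include <algorithm>
#include <cstdio>

int main() {
    // Illustrative context sizes; params.n_ctx normally comes from the command line.
    const int ctx_sizes[] = {256, 512, 2048};

    for (int n_ctx : ctx_sizes) {
        int n_batch = 512;                   // default set before argument parsing
        n_batch = std::min(n_batch, n_ctx);  // clamp after parsing, as in the patch
        printf("n_ctx = %4d -> n_batch = %d\n", n_ctx, n_batch);
    }
    return 0;
}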