// llama.cpp/examples/imatrix/imatrix.cpp

#include "common.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <sstream>
#include <thread>
#include <mutex>
#include <vector>
#include <fstream>
#include <unordered_map>
#include <algorithm>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Logs an example invocation of the tool. The first argument (argc) is not used;
// argv[0] supplies the program name shown in the example.
static void print_usage(int /*argc*/, char ** argv) {
    const char * program = argv[0];

    LOG_TEE("\nexample usage:\n");
    LOG_TEE("\n %s \\\n"
            " -m model.gguf -f some-text.txt [-o imatrix.dat] [--process-output] [--verbosity 1] \\\n"
            " [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n"
            " [--in-file imatrix-prev-0.dat --in-file imatrix-prev-1.dat ...]\n",
            program);
    LOG_TEE("\n");
}
// Running statistics kept for one weight tensor.
struct Stats {
    // Per-column accumulators; the collector adds squared activation values here.
    std::vector<float> values;
    // Per-column number of samples that were added to `values`.
    std::vector<int>   counts;
    // Number of collection calls recorded for this tensor.
    int                ncall{0};
};
// Gathers per-tensor activation statistics through the graph-evaluation callback
// and serializes / deserializes them in the binary imatrix file layout.
class IMatrixCollector {
public:
    IMatrixCollector() = default;

    // Keeps a private copy of the run parameters (output file, save frequencies, verbosity, ...).
    void set_params(gpt_params params) { m_params = std::move(params); }

    // Evaluation callback body. With ask == true it only answers whether data for
    // tensor `t` is wanted; with ask == false it accumulates the statistics.
    bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);

    // Writes the collected statistics; a positive `ncall` appends ".at_<ncall>" to the file name.
    void save_imatrix(int ncall = -1) const;

    // Merges the statistics stored in `file_name` into the current state; returns false on failure.
    bool load_imatrix(const char * file_name);

private:
    std::unordered_map<std::string, Stats> m_stats;      // statistics keyed by filtered tensor name
    gpt_params                             m_params;
    std::mutex                             m_mutex;      // serializes the collection path
    int                                    m_last_call{0};
    std::vector<float>                     m_src1_data;  // staging buffer for activations that are not in host memory
    std::vector<char>                      m_ids;        // the expert ids from ggml_mul_mat_id
};
// Strips the '#'-delimited decoration around a tensor name and returns the bare name.
//   CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
// A name without any '#' is returned unchanged; with a single '#', everything after it is returned.
static std::string filter_tensor_name(const char * name) {
    const char * first_hash = strchr(name, '#');
    if (first_hash == NULL) {
        return std::string(name);
    }

    const char * begin       = first_hash + 1;
    const char * second_hash = strchr(begin, '#');

    return second_hash == NULL ? std::string(begin) : std::string(begin, second_hash);
}
bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
GGML_UNUSED(user_data);
const struct ggml_tensor * src0 = t->src[0];
const struct ggml_tensor * src1 = t->src[1];
std::string wname = filter_tensor_name(src0->name);
// when ask is true, the scheduler wants to know if we are interested in data from this tensor
// if we return true, a follow-up call will be made with ask=false in which we can do the actual collection
if (ask) {
if (t->op == GGML_OP_MUL_MAT_ID) return true; // collect all indirect matrix multiplications
if (t->op != GGML_OP_MUL_MAT) return false;
// why are small batches ignored (<16 tokens)?
if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false;
if (!(wname.substr(0, 4) == "blk." || (m_params.process_output && wname == "output.weight"))) return false;
return true;
}
std::lock_guard<std::mutex> lock(m_mutex);
// copy the data from the GPU memory if needed
const bool is_host = ggml_backend_buffer_is_host(src1->buffer);
if (!is_host) {
m_src1_data.resize(ggml_nelements(src1));
ggml_backend_tensor_get(src1, m_src1_data.data(), 0, ggml_nbytes(src1));
}
const float * data = is_host ? (const float *) src1->data : m_src1_data.data();
ggml : mul_mat_id use the same tensor for all the experts (#6387) * ggml : update mul_mat_id to use the same tensor for all the experts * update cuda * minor * update metal * update test-backend-ops * fix cuda * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update convert.py * update convert-hf-to-gguf.py * update convert.py for mixtral hf models * Update convert-hf-to-gguf.py Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * cuda : support non-pow-2 number of experts * allow quantize to work for split and merged experts models in the same way * cleanup + disable mmap automatically with split tensors models * update imatrix * test-backend-ops : test qwen argsort * update grok model loading * llama : add merged experts tensors to the grok tensor map * minor * gguf : bump version * fix quantizing of merged experts * convert-hf-to-gguf.py : update grok (untested) * make linter happy * cuda/argsort : use shared memory instead of pool memory * convert : fix grok tensor names * metal : add support for non-pow-2 argsort * llama : more loader cleanup, better error checking * cuda : fix warning * llama : still use mmap for loading old models, but copy the data to a host buffer * add review note * llama : remove ffn tensor counting + add sanity check ggml-ci * convert : fix handling of n_experts == None ggml-ci * imatrix : fix ncall counters * llama : produce error if imatrix size does not match * quantize : terminate on errors + trace logs ggml-ci * metal : pad shared memory to 16 bytes --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-04-03 13:07:05 +00:00
// this has been adapted to the new format of storing merged experts in a single 3d tensor
// ref: https://github.com/ggerganov/llama.cpp/pull/6387
if (t->op == GGML_OP_MUL_MAT_ID) {
// ids -> [n_experts_used, n_tokens]
// src1 -> [cols, n_expert_used, n_tokens]
ggml : mul_mat_id use the same tensor for all the experts (#6387) * ggml : update mul_mat_id to use the same tensor for all the experts * update cuda * minor * update metal * update test-backend-ops * fix cuda * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update convert.py * update convert-hf-to-gguf.py * update convert.py for mixtral hf models * Update convert-hf-to-gguf.py Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * cuda : support non-pow-2 number of experts * allow quantize to work for split and merged experts models in the same way * cleanup + disable mmap automatically with split tensors models * update imatrix * test-backend-ops : test qwen argsort * update grok model loading * llama : add merged experts tensors to the grok tensor map * minor * gguf : bump version * fix quantizing of merged experts * convert-hf-to-gguf.py : update grok (untested) * make linter happy * cuda/argsort : use shared memory instead of pool memory * convert : fix grok tensor names * metal : add support for non-pow-2 argsort * llama : more loader cleanup, better error checking * cuda : fix warning * llama : still use mmap for loading old models, but copy the data to a host buffer * add review note * llama : remove ffn tensor counting + add sanity check ggml-ci * convert : fix handling of n_experts == None ggml-ci * imatrix : fix ncall counters * llama : produce error if imatrix size does not match * quantize : terminate on errors + trace logs ggml-ci * metal : pad shared memory to 16 bytes --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-04-03 13:07:05 +00:00
const ggml_tensor * ids = t->src[2];
const int n_as = src0->ne[2];
const int n_ids = ids->ne[0];
ggml : mul_mat_id use the same tensor for all the experts (#6387) * ggml : update mul_mat_id to use the same tensor for all the experts * update cuda * minor * update metal * update test-backend-ops * fix cuda * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update convert.py * update convert-hf-to-gguf.py * update convert.py for mixtral hf models * Update convert-hf-to-gguf.py Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * cuda : support non-pow-2 number of experts * allow quantize to work for split and merged experts models in the same way * cleanup + disable mmap automatically with split tensors models * update imatrix * test-backend-ops : test qwen argsort * update grok model loading * llama : add merged experts tensors to the grok tensor map * minor * gguf : bump version * fix quantizing of merged experts * convert-hf-to-gguf.py : update grok (untested) * make linter happy * cuda/argsort : use shared memory instead of pool memory * convert : fix grok tensor names * metal : add support for non-pow-2 argsort * llama : more loader cleanup, better error checking * cuda : fix warning * llama : still use mmap for loading old models, but copy the data to a host buffer * add review note * llama : remove ffn tensor counting + add sanity check ggml-ci * convert : fix handling of n_experts == None ggml-ci * imatrix : fix ncall counters * llama : produce error if imatrix size does not match * quantize : terminate on errors + trace logs ggml-ci * metal : pad shared memory to 16 bytes --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-04-03 13:07:05 +00:00
// the top-k selected expert ids are stored in the ids tensor
// for simplicity, always copy ids to host, because it is small
// take into account that ids is not contiguous!
GGML_ASSERT(ids->ne[1] == src1->ne[2]);
m_ids.resize(ggml_nbytes(ids));
ggml : mul_mat_id use the same tensor for all the experts (#6387) * ggml : update mul_mat_id to use the same tensor for all the experts * update cuda * minor * update metal * update test-backend-ops * fix cuda * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update convert.py * update convert-hf-to-gguf.py * update convert.py for mixtral hf models * Update convert-hf-to-gguf.py Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * cuda : support non-pow-2 number of experts * allow quantize to work for split and merged experts models in the same way * cleanup + disable mmap automatically with split tensors models * update imatrix * test-backend-ops : test qwen argsort * update grok model loading * llama : add merged experts tensors to the grok tensor map * minor * gguf : bump version * fix quantizing of merged experts * convert-hf-to-gguf.py : update grok (untested) * make linter happy * cuda/argsort : use shared memory instead of pool memory * convert : fix grok tensor names * metal : add support for non-pow-2 argsort * llama : more loader cleanup, better error checking * cuda : fix warning * llama : still use mmap for loading old models, but copy the data to a host buffer * add review note * llama : remove ffn tensor counting + add sanity check ggml-ci * convert : fix handling of n_experts == None ggml-ci * imatrix : fix ncall counters * llama : produce error if imatrix size does not match * quantize : terminate on errors + trace logs ggml-ci * metal : pad shared memory to 16 bytes --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-04-03 13:07:05 +00:00
ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));
auto & e = m_stats[wname];
++e.ncall;
if (e.values.empty()) {
e.values.resize(src1->ne[0]*n_as, 0);
e.counts.resize(src1->ne[0]*n_as, 0);
}
else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
exit(1); //GGML_ABORT("fatal error");
}
if (m_params.verbosity > 1) {
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
}
// loop over all possible experts, regardless if they are used or not in the batch
for (int ex = 0; ex < n_as; ++ex) {
ggml : mul_mat_id use the same tensor for all the experts (#6387) * ggml : update mul_mat_id to use the same tensor for all the experts * update cuda * minor * update metal * update test-backend-ops * fix cuda * Update ggml-metal.m Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update convert.py * update convert-hf-to-gguf.py * update convert.py for mixtral hf models * Update convert-hf-to-gguf.py Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * cuda : support non-pow-2 number of experts * allow quantize to work for split and merged experts models in the same way * cleanup + disable mmap automatically with split tensors models * update imatrix * test-backend-ops : test qwen argsort * update grok model loading * llama : add merged experts tensors to the grok tensor map * minor * gguf : bump version * fix quantizing of merged experts * convert-hf-to-gguf.py : update grok (untested) * make linter happy * cuda/argsort : use shared memory instead of pool memory * convert : fix grok tensor names * metal : add support for non-pow-2 argsort * llama : more loader cleanup, better error checking * cuda : fix warning * llama : still use mmap for loading old models, but copy the data to a host buffer * add review note * llama : remove ffn tensor counting + add sanity check ggml-ci * convert : fix handling of n_experts == None ggml-ci * imatrix : fix ncall counters * llama : produce error if imatrix size does not match * quantize : terminate on errors + trace logs ggml-ci * metal : pad shared memory to 16 bytes --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-04-03 13:07:05 +00:00
size_t e_start = ex*src1->ne[0];
for (int idx = 0; idx < n_ids; ++idx) {
for (int row = 0; row < (int)src1->ne[2]; ++row) {
const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]);
GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check
if (excur != ex) continue;
const int64_t i11 = idx % src1->ne[1];
const int64_t i12 = row;
const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);
for (int j = 0; j < (int)src1->ne[0]; ++j) {
e.values[e_start + j] += x[j]*x[j];
e.counts[e_start + j]++;
if (!std::isfinite(e.values[e_start + j])) {
fprintf(stderr, "%f detected in %s\n", e.values[e_start + j], wname.c_str());
exit(1);
}
}
}
}
if (e.ncall > m_last_call) {
m_last_call = e.ncall;
if (m_last_call % m_params.n_out_freq == 0) {
save_imatrix();
}
if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
save_imatrix(m_last_call);
}
}
}
} else {
auto & e = m_stats[wname];
if (e.values.empty()) {
e.values.resize(src1->ne[0], 0);
e.counts.resize(src1->ne[0], 0);
}
else if (e.values.size() != (size_t)src1->ne[0]) {
fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
exit(1); //GGML_ABORT("fatal error");
}
++e.ncall;
if (m_params.verbosity > 1) {
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
}
for (int row = 0; row < (int)src1->ne[1]; ++row) {
const float * x = data + row * src1->ne[0];
for (int j = 0; j < (int)src1->ne[0]; ++j) {
e.values[j] += x[j]*x[j];
e.counts[j]++;
if (!std::isfinite(e.values[j])) {
fprintf(stderr, "%f detected in %s\n", e.values[j], wname.c_str());
exit(1);
}
}
}
if (e.ncall > m_last_call) {
m_last_call = e.ncall;
if (m_last_call % m_params.n_out_freq == 0) {
save_imatrix();
}
if (m_params.n_save_freq > 0 && m_last_call%m_params.n_save_freq == 0) {
save_imatrix(m_last_call);
}
}
}
return true;
}
// Writes the collected statistics to disk.
//
// The file name is m_params.out_file (or "imatrix.dat" when that is empty); a positive
// `ncall` appends ".at_<ncall>" to it. Layout, all integers written as raw `int`:
//   n_entries
//   per entry: name length, name bytes, ncall, nval, nval floats
//   m_last_call
//   length of m_params.prompt_file, followed by its bytes
// Each stored float is (values[i] / counts[i]) * ncall.
//
// Entries with no data or only partial data are skipped with a message. Failures to open
// or write the file are reported on stderr.
void IMatrixCollector::save_imatrix(int ncall) const {
    auto fname = m_params.out_file;
    if (fname.empty()) {
        fname = "imatrix.dat";
    }

    if (ncall > 0) {
        fname += ".at_";
        fname += std::to_string(ncall);
    }

    // avoid writing imatrix entries that do not have full data
    // this can happen with MoE models where some of the experts end up not being exercised by the provided training data
    std::vector<std::string> to_store;

    bool is_first = true; // for printing
    for (const auto & kv : m_stats) {
        const auto & counts = kv.second.counts;

        const int n_all = counts.size();
        if (n_all == 0) {
            continue;
        }

        const int n_zeros = (int) std::count(counts.begin(), counts.end(), 0);

        if (n_zeros != 0 && is_first) {
            fprintf(stderr, "\n");
            is_first = false;
        }

        if (n_zeros == n_all) {
            fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
            continue;
        }

        if (n_zeros > 0) {
            fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
            continue;
        }

        to_store.push_back(kv.first);
    }

    if (to_store.size() < m_stats.size()) {
        fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
    }

    // the iteration order of the unordered map is unspecified; sort the names so that
    // the same statistics always produce the same file
    std::sort(to_store.begin(), to_store.end());

    std::ofstream out(fname, std::ios::binary);
    if (!out) {
        fprintf(stderr, "%s: failed to open %s for writing\n", __func__, fname.c_str());
        return;
    }

    const int n_entries = (int) to_store.size();
    out.write((const char *) &n_entries, sizeof(n_entries));

    for (const auto & name : to_store) {
        const auto & stat = m_stats.at(name);

        const int len = name.size();
        out.write((const char *) &len, sizeof(len));
        out.write(name.c_str(), len);
        out.write((const char *) &stat.ncall, sizeof(stat.ncall));

        const int nval = stat.values.size();
        out.write((const char *) &nval, sizeof(nval));
        if (nval > 0) {
            // counts[i] is non-zero here: entries with any zero count were filtered out above
            std::vector<float> tmp(nval);
            for (int i = 0; i < nval; i++) {
                tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
            }
            out.write((const char*)tmp.data(), nval*sizeof(float));
        }
    }

    // Write the number of call the matrix was computed with
    out.write((const char *) &m_last_call, sizeof(m_last_call));

    // Write the input filename at the end of the file to later on specify it in quantize
    {
        const int len = m_params.prompt_file.size();
        out.write((const char *) &len, sizeof(len));
        out.write(m_params.prompt_file.c_str(), len);
    }

    // closing flushes the buffered data, so a failed write becomes visible in the stream state
    out.close();
    if (!out) {
        fprintf(stderr, "%s: error while writing %s\n", __func__, fname.c_str());
        return;
    }

    if (m_params.verbosity > 0) {
        fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_call, fname.c_str());
    }
}
// Merges the statistics stored in file `fname` into the current state.
//
// Reads the layout written by save_imatrix(): an entry count followed, per entry, by
// name length, name bytes, ncall, nval and nval floats. Data of an entry that already
// exists is accumulated, which allows several files to be combined.
//
// Returns false when the file cannot be opened, contains no entries, is truncated or
// malformed, or when an entry does not have the same number of values as the entry of
// the same name that is already loaded. Once entries are being read, any failure also
// discards all statistics held by the collector.
bool IMatrixCollector::load_imatrix(const char * fname) {
    std::ifstream in(fname, std::ios::binary);
    if (!in) {
        printf("%s: failed to open %s\n",__func__, fname);
        return false;
    }
    int n_entries;
    in.read((char*)&n_entries, sizeof(n_entries));
    if (in.fail() || n_entries < 1) {
        printf("%s: no data in file %s\n", __func__, fname);
        return false;
    }
    for (int i = 0; i < n_entries; ++i) {
        int len = 0;
        in.read((char *)&len, sizeof(len));
        // a negative length would otherwise turn into a bogus buffer size below
        if (in.fail() || len < 0) {
            printf("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname);
            m_stats = {};
            return false;
        }
        std::vector<char> name_as_vec((size_t) len + 1);
        in.read((char *)name_as_vec.data(), len);
        if (in.fail()) {
            printf("%s: failed reading name for entry %d from %s\n",__func__,i+1, fname);
            m_stats = {};
            return false;
        }
        name_as_vec[len] = 0;
        std::string name{name_as_vec.data()};

        int ncall;
        in.read((char*)&ncall, sizeof(ncall));
        int nval;
        in.read((char *)&nval, sizeof(nval));
        if (in.fail() || nval < 1) {
            printf("%s: failed reading number of values for entry %d\n",__func__,i);
            m_stats = {};
            return false;
        }

        std::vector<float> tmp(nval);
        in.read((char*)tmp.data(), nval*sizeof(float));
        if (in.fail()) {
            printf("%s: failed reading data for entry %d\n",__func__,i);
            m_stats = {};
            return false;
        }

        auto & e = m_stats[name];
        if (e.values.empty()) {
            e.values.resize(nval, 0);
            e.counts.resize(nval, 0);
        }
        else if (e.values.size() != (size_t) nval) {
            // accumulating into a smaller buffer would write out of bounds
            printf("%s: mismatched number of values for entry '%s' in %s (%zu vs %d)\n", __func__, name.c_str(), fname, e.values.size(), nval);
            m_stats = {};
            return false;
        }

        // Recreate the state as expected by save_imatrix(), and correct for weighted sum:
        // the file stores (values/counts)*ncall, so adding ncall to every count makes
        // save_imatrix() reproduce the stored numbers.
        for (int j = 0; j < nval; j++) {
            e.values[j] += tmp[j];
            e.counts[j] += ncall;
        }
        e.ncall += ncall;
    }
    return true;
}
// File-local collector instance shared by the evaluation callback and main().
static IMatrixCollector g_collector;
// Free-function adapter around IMatrixCollector::collect_imatrix: forwards all three
// arguments unchanged to g_collector and returns its result.
static bool ik_collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
return g_collector.collect_imatrix(t, ask, user_data);
}
// Result triple produced by log_softmax() for one target token.
// The member order matters: the struct is filled through aggregate initialization.
struct results_log_softmax {
    double log_softmax; // log-probability of the target token
    float  logit;       // raw logit of the target token
    float  prob;        // softmax probability of the target token
};
// Converts a vector of logits into probabilities.
//
// The largest logit is subtracted before exponentiation for numerical stability.
// An empty input yields an empty result.
static std::vector<float> softmax(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    if (logits.empty()) {
        // nothing to normalize; also avoids reading logits[0] of an empty vector
        return probs;
    }

    const float max_logit = *std::max_element(logits.begin(), logits.end());

    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        // Subtract the maximum logit value from the current logit value for numerical stability
        const float exp_logit = expf(logits[i] - max_logit);
        sum_exp += exp_logit;
        probs[i] = exp_logit;
    }
    for (float & p : probs) {
        p /= sum_exp;
    }
    return probs;
}
// Evaluates the softmax of `logits` (n_vocab entries) at index `tok` and returns its
// log-probability, the raw logit and the probability. The largest logit is subtracted
// before exponentiation.
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
    // largest logit
    float peak = logits[0];
    for (int k = 1; k < n_vocab; ++k) {
        if (peak < logits[k]) {
            peak = logits[k];
        }
    }

    // normalization term, accumulated in double precision
    double denom = 0.0;
    for (int k = 0; k < n_vocab; ++k) {
        denom += expf(logits[k] - peak);
    }

    const float target = logits[tok];

    results_log_softmax out;
    out.log_softmax = target - peak - log(denom);
    out.logit       = target;
    out.prob        = expf(target - peak) / (float) denom;
    return out;
}
static void process_logits(
int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
double & nll, double & nll2, float * logit_history, float * prob_history) {
std::mutex mutex;
int counter = 0;
auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
double local_nll = 0;
double local_nll2 = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
int i = counter++;
if (i >= n_token) {
nll += local_nll; nll2 += local_nll2;
break;
}
lock.unlock();
const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
const double v = -results.log_softmax;
local_nll += v;
local_nll2 += v*v;
logit_history[i] = results.logit;
prob_history[i] = results.prob;
}
};
for (auto & w : workers) {
w = std::thread(compute);
}
compute();
for (auto & w : workers) {
w.join();
}
}
static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
const int n_ctx = llama_n_ctx(ctx);
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
if (params.i_chunk > 0) {
if (size_t((params.i_chunk + 2)*n_ctx) >= tokens.size()) {
fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, params.i_chunk);
return false;
}
fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, params.i_chunk, params.i_chunk*n_ctx);
tokens.erase(tokens.begin(), tokens.begin() + params.i_chunk*n_ctx);
}
if (int(tokens.size()) < 2*n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return false;
}
std::vector<float> logit_history;
std::vector<float> prob_history;
if (params.compute_ppl) {
logit_history.resize(tokens.size());
prob_history.resize(tokens.size());
}
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
int count = 0;
double nll = 0.0;
double nll2 = 0.0;
fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
if (params.compute_ppl && num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
}
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
const int end = start + n_ctx;
std::vector<float> logits;
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
// save original token and restore it after eval
const auto token_org = tokens[batch_start];
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}
llama : greatly reduce output buffer memory usage (#6122) * llama : greatly reduce logits memory usage * llama : more compact state saving and reloading * llama : fix lctx.n_outputs not being set before building graph * perplexity : adapt to the logits API changes * perplexity : fix Winogrande, use correct logits for second choice start The first logits used to evaluate the second choice were not from the end of the common prefix; instead, they were the logits from the end of the first choice. This has been corrected. The previous implementation sometimes had outliers in the scores of choices for some tasks, and the logic to skip choices words in the log-likelihood evaluation probably was an attempt to reduce those, but it was complex and didn't quite seem to be the right thing. This is simpler now, and the outlier scores aren't there anymore. * perplexity : normalize spaces and punctuation in Winogrande sentences * llama : fix embedding conditions * llama : fix llama_get_embeddings_ith when the resulting id is 0 * llama : fix wrong n_outputs in llama_set_inputs A mismatch happened when using a smaller n_ubatch than n_batch and then using llama_batch_get_one(). The decision of what n_outputs should be now almost fully depends on how lctx.n_outputs is set in llama_decode_internal. The conditions are simpler this way. * llama : when saving the state, recalculate n_outputs This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch. * llama : fix not-skipping outputs of non-causal models * llama : fix running a batch with n_outputs == 0 It previously worked because lctx.inp_out_ids was not initialized, so it pointed to some garbage address which was somehow still valid when I ran my tests. 
* llama : keep same graph topology even when n_outputs == 0 * ggml : saner ggml_can_repeat with empty tensors * ggml : future-proof ggml_is_empty by using GGML_MAX_DIMS - 1 * ggml : do not multi-thread ops returning empty tensors * ggml : make ggml_is_empty public and work with views * llama : use a vector for ctx->output_ids * llama : rework reallocation logic for llama_output_reserve Now comparing the actual size with the new total size of the output buffer to allow more efficient enabling and disabling of the embeddings and/or logits output in the future. * ggml : skip empty tensors in all backends * llama : fix llama_output_reserve nullptr deref when new_size is 0 * perplexity : make Winogrande work as it does on master The problems with the Winogrande implementation will need to be fixed in a separate PR to ease review. * llama : clearer error messages for invalid logits or embeddings ids * llama : assert all models that can have inp_out_ids Since the graph topology is now constant, this presence check can be done even when there are no outputs. * llama : assert logits and embd buffers exist before writing to them * llama : handle errors from llama_output_reserve at call sites * perplexity : make hellaswag and multiple-choice outputs identical to master Due to how the KV cache is updated, the logprobs for tokens in a batch are very slightly affected by the other tokens present in the batch, so to make hellaswag and multiple-choice return exactly the same results as on master, the last token of each sequence needs to be evaluated even though its output is not used at all. This will probably be changed back in the future to make these benchmarks a tiny bit faster. 
* perplexity : fix division by zero when using less than 100 multiple-choice tasks * llama : allow loading state saved with a different ctx size When loading a session file, the context size is now only required to be at least enough to load the KV cells contained in that session file, instead of requiring to use exactly the same context size as when saving. Doing this enables the use-case of extending or shrinking the context size of a saved session. This breaks existing session files because the meaning of kv_buf_size is slightly changed (previously it was the size of the whole KV cache, now it's only the size of the saved part of it). This allows for finer-grained sanity checks when loading in an effort to keep kv_buf_size useful even when the kv_size is changed. * llama : minor ggml-ci * readme : update recent API changes, and warn about Vulkan --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-03-26 14:46:41 +00:00
// TODO: use batch.logits to save computations instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
if (params.compute_ppl && num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
}
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
if (params.compute_ppl) {
const int first = n_ctx/2;
const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += n_ctx - first - 1;
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
fflush(stdout);
logits.clear();
}
}
printf("\n");
if (params.compute_ppl) {
nll2 /= count;
nll /= count;
const double ppl = exp(nll);
nll2 -= nll * nll;
if (nll2 > 0) {
nll2 = sqrt(nll2/(count-1));
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
} else {
printf("Unexpected negative standard deviation of log(prob)\n");
}
}
return true;
}
int main(int argc, char ** argv) {
gpt_params params;
params.n_ctx = 512;
params.logits_all = true;
params.verbosity = 1;
auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON, print_usage);
if (!gpt_params_parse(argc, argv, params, options)) {
return 1;
}
params.n_batch = std::min(params.n_batch, params.n_ctx);
g_collector.set_params(params);
for (const auto & in_file : params.in_files) {
printf("%s : loading imatrix from '%s'\n", __func__, in_file.c_str());
if (!g_collector.load_imatrix(in_file.c_str())) {
fprintf(stderr, "%s : failed to load %s\n", __func__, in_file.c_str());
return 1;
}
}
if (params.in_files.size() > 1) {
printf("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str());
g_collector.save_imatrix();
}
ggml : add numa options (#5377) * Added numa options to allow finer grained control as well as plumbing for a new mirror mode that will require numa.h * Reverted Makefile * Fixed include * Removed sched.h from ggml.h, moved ggml_get_numa_affinity into ggml.c, removed trailing whitespace and fixed up a few inconsistent variables * removed trailing whitespace * Added numa options to allow finer grained control as well as plumbing for a new mirror mode that will require numa.h * Reverting Makefile * Fixed a number of issues with the move from BOOL to ggml_numa_strategies. Added a note about mirror mode note being implemented yet * Removing MIRROR_MODE code for this PR * Removing last bit of MIRROR_MODE code for this PR * Removing unneeded branch in server.cpp example and moving get_numa_affinity and making it static * Fixed lingering init_llama_backend() bool calls in tests and examples * Remote enum llama_numa_strategies * Revert bad merge with dynatemp flags * add missing enum ggml_numa_strategies declaration and revert sync problem with master * add missing enum ggml_numa_strategies declaration * fixed ggml_init_numa variable * Update ggml.h Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * Update READMEs with info about numa flags, change INTERLEAVE strategy name to DISTRIBUTE everywhere, implement the improved distribution strategy from @rankaiyx, fix a spelling mistake and un-merge some bad merges * split numa init out from llama_backend_init and created llama_numa_init. 
Updated all code paths and samples * Fix up some boolean vs enum comparisons * Added #ifdefs for non-Linux OS that don't have cpu_set_t datatype * Update ggml.h Align enum values Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml.c Remove whitespace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update ggml.c align paremeters Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update examples/server/server.cpp remove whitespace and align brace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update common/common.cpp Remove whitespace and align brace Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * unified ggml_numa_strategy enum and fixed text alignment in server.cpp example * Update ggml.c simplified return for platforms without NUMA support Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> * removed redundant else from cli argument processing of --numa * whitespace --------- Co-authored-by: root <root@nenya.lothlorien.ca> Co-authored-by: Jared Van Bortel <cebtenzzre@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Jared Van Bortel <jared@nomic.ai>
2024-02-16 09:31:07 +00:00
llama_backend_init();
llama_numa_init(params.numa);
// pass the callback to the backend scheduler
// it will be executed for each node during the graph computation
params.cb_eval = ik_collect_imatrix;
params.cb_eval_user_data = NULL;
params.warmup = false;
// init
llama_init_result llama_init = llama_init_from_gpt_params(params);
llama_model * model = llama_init.model;
llama_context * ctx = llama_init.context;
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);
return 1;
}
const int n_ctx_train = llama_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
}
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str());
}
if (!compute_imatrix(ctx, params)) {
return 1;
}
g_collector.save_imatrix();
LOG_TEE("\n");
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
return 0;
}