mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
6381d4e110
* gguf : first API pass
* gguf : read header + meta data
* gguf : read tensor info
* gguf : initial model loading - not tested
* gguf : add gguf_get_tensor_name()
* gguf : do not support passing existing ggml_context to gguf_init
* gguf : simplify gguf_get_val
* gguf : gguf.c is now part of ggml.c
* gguf : read / write sample models
* gguf : add comments
* refactor : reduce code duplication and better API (#2415)
* gguf : expose the gguf_type enum through the API for now
* gguf : add array support
* gguf.py : some code style changes
* convert.py : start a new simplified implementation by removing old stuff
* convert.py : remove GGML vocab + other obsolete stuff
* GGUF : write tensor (#2426)
* WIP: Write tensor
* GGUF : Support writing tensors in Python
* refactor : rm unused import and upd todos
* fix : fix errors upd writing example
* rm example.gguf
* gitignore *.gguf
* undo formatting
* gguf : add gguf_find_key (#2438)
* gguf.cpp : find key example
* ggml.h : add gguf_find_key
* ggml.c : add gguf_find_key
* gguf : fix writing tensors
* gguf : do not hardcode tensor names to read
* gguf : write sample tensors to read
* gguf : add tokenization constants
* quick and dirty conversion example
* gguf : fix writing gguf arrays
* gguf : write tensors one by one and code reuse
* gguf : fix writing gguf arrays
* gguf : write tensors one by one
* gguf : write tensors one by one
* gguf : write tokenizer data
* gguf : upd gguf conversion script
* Update convert-llama-h5-to-gguf.py
* gguf : handle already encoded string
* ggml.h : get array str and f32
* ggml.c : get arr str and f32
* gguf.py : support any type
* Update convert-llama-h5-to-gguf.py
* gguf : fix set is not subscriptable
* gguf : update convert-llama-h5-to-gguf.py
* constants.py : add layer norm eps
* gguf.py : add layer norm eps and merges
* ggml.h : increase GGML_MAX_NAME to 64
* ggml.c : add gguf_get_arr_n
* Update convert-llama-h5-to-gguf.py
* add gptneox gguf example
* Makefile : add gptneox gguf example
* Update convert-llama-h5-to-gguf.py
* add gptneox gguf example
* Update convert-llama-h5-to-gguf.py
* Update convert-gptneox-h5-to-gguf.py
* Update convert-gptneox-h5-to-gguf.py
* Update convert-llama-h5-to-gguf.py
* gguf : support custom alignment value
* gguf : fix typo in function call
* gguf : mmap tensor data example
* fix : update convert-llama-h5-to-gguf.py
* Update convert-llama-h5-to-gguf.py
* convert-gptneox-h5-to-gguf.py : Special tokens
* gptneox-main.cpp : special tokens
* Update gptneox-main.cpp
* constants.py : special tokens
* gguf.py : accumulate kv and tensor info data + special tokens
* convert-gptneox-h5-to-gguf.py : accumulate kv and ti + special tokens
* gguf : gguf counterpart of llama-util.h
* gguf-util.h : update note
* convert-llama-h5-to-gguf.py : accumulate kv / ti + special tokens
* convert-llama-h5-to-gguf.py : special tokens
* Delete gptneox-common.cpp
* Delete gptneox-common.h
* convert-gptneox-h5-to-gguf.py : gpt2bpe tokenizer
* gptneox-main.cpp : gpt2 bpe tokenizer
* gpt2 bpe tokenizer (handles merges and unicode)
* Makefile : remove gptneox-common
* gguf.py : bytesarray for gpt2bpe tokenizer
* cmpnct_gpt2bpe.hpp : comments
* gguf.py : use custom alignment if present
* gguf : minor stuff
* Update gptneox-main.cpp
* map tensor names
* convert-gptneox-h5-to-gguf.py : map tensor names
* convert-llama-h5-to-gguf.py : map tensor names
* gptneox-main.cpp : map tensor names
* gguf : start implementing libllama in GGUF (WIP)
* gguf : start implementing libllama in GGUF (WIP)
* rm binary committed by mistake
* upd .gitignore
* gguf : calculate n_mult
* gguf : inference with 7B model working (WIP)
* gguf : rm deprecated function
* gguf : start implementing gguf_file_saver (WIP)
* gguf : start implementing gguf_file_saver (WIP)
* gguf : start implementing gguf_file_saver (WIP)
* gguf : add gguf_get_kv_type
* gguf : add gguf_get_kv_type
* gguf : write metadata in gguf_file_saver (WIP)
* gguf : write metadata in gguf_file_saver (WIP)
* gguf : write metadata in gguf_file_saver
* gguf : rm references to old file formats
* gguf : shorter name for member variable
* gguf : rm redundant method
* gguf : get rid of n_mult, read n_ff from file
* Update gguf_tensor_map.py
* Update gptneox-main.cpp
* gguf : rm references to old file magics
* gguf : start implementing quantization (WIP)
* gguf : start implementing quantization (WIP)
* gguf : start implementing quantization (WIP)
* gguf : start implementing quantization (WIP)
* gguf : start implementing quantization (WIP)
* gguf : start implementing quantization (WIP)
* gguf : quantization is working
* gguf : proper closing of file
* gguf.py : no need to convert tensors twice
* convert-gptneox-h5-to-gguf.py : no need to convert tensors twice
* convert-llama-h5-to-gguf.py : no need to convert tensors twice
* convert-gptneox-h5-to-gguf.py : simplify nbytes
* convert-llama-h5-to-gguf.py : simplify nbytes
* gptneox-main.cpp : n_layer --> n_block
* constants.py : n_layer --> n_block
* gguf.py : n_layer --> n_block
* convert-gptneox-h5-to-gguf.py : n_layer --> n_block
* convert-llama-h5-to-gguf.py : n_layer --> n_block
* gptneox-main.cpp : n_layer --> n_block
* Update gguf_tensor_map.py
* convert-gptneox-h5-to-gguf.py : load model in parts to save memory
* convert-llama-h5-to-gguf.py : load model in parts to save memory
* convert : write more metadata for LLaMA
* convert : rm quantization version
* convert-gptneox-h5-to-gguf.py : add file_type key
* gptneox-main.cpp : add file_type key
* fix conflicts
* gguf : add todos and comments
* convert-gptneox-h5-to-gguf.py : tensor name map changes
* Create gguf_namemap.py : tensor name map changes
* Delete gguf_tensor_map.py
* gptneox-main.cpp : tensor name map changes
* convert-llama-h5-to-gguf.py : fixes
* gguf.py : dont add empty strings
* simple : minor style changes
* gguf : use UNIX line ending
* Create convert-llama-7b-pth-to-gguf.py
* llama : sync gguf-llama.cpp with latest llama.cpp (#2608)
* llama : sync gguf-llama.cpp with latest llama.cpp
* minor : indentation + assert
* llama : refactor gguf_buffer and gguf_ctx_buffer
* llama : minor
* gitignore : add gptneox-main
* llama : tokenizer fixes (#2549)
* Merge tokenizer fixes into the gguf branch.
* Add test vocabularies
* convert : update convert-new.py with tokenizer fixes (#2614)
* Merge tokenizer fixes into the gguf branch.
* Add test vocabularies
* Adapt convert-new.py (and fix a clang-cl compiler error on windows)
* llama : sync gguf-llama with llama (#2613)
* llama : sync gguf-llama with llama
* tests : fix build + warnings (test-tokenizer-1 still fails)
* tests : fix wstring_convert
* convert : fix layer names
* llama : sync gguf-llama.cpp
* convert : update HF converter to new tokenizer voodoo magics
* llama : update tokenizer style
* convert-llama-h5-to-gguf.py : add token types
* constants.py : add token types
* gguf.py : add token types
* convert-llama-7b-pth-to-gguf.py : add token types
* gguf-llama.cpp : fix n_head_kv
* convert-llama-h5-to-gguf.py : add 70b gqa support
* gguf.py : add tensor data layout
* convert-llama-h5-to-gguf.py : add tensor data layout
* convert-llama-7b-pth-to-gguf.py : add tensor data layout
* gptneox-main.cpp : add tensor data layout
* convert-llama-h5-to-gguf.py : clarify the reverse permute
* llama : refactor model loading code (#2620)
* llama : style formatting + remove helper methods
* llama : fix quantization using gguf tool
* llama : simplify gguf_file_saver
* llama : fix method names
* llama : simplify write_header()
* llama : no need to pass full file loader to the file saver
just gguf_ctx
* llama : gguf_file_saver write I32
* llama : refactor tensor names (#2622)
* gguf: update tensor names searched in quantization
* gguf : define tensor names as constants
* gguf : initial write API (not tested yet)
* gguf : write to file API (not tested)
* gguf : initial write API ready + example
* gguf : fix header write
* gguf : fixes + simplify example + add ggml_nbytes_pad()
* gguf : minor
* llama : replace gguf_file_saver with new gguf write API
* gguf : streaming support when writing files
* gguf : remove obsolete write methods
* gguf : remove obsolete gguf_get_arr_xxx API
* llama : simplify gguf_file_loader
* llama : move hparams and vocab from gguf_file_loader to llama_model_loader
* llama : merge gguf-util.h in llama.cpp
* llama : reorder definitions in .cpp to match .h
* llama : minor simplifications
* llama : refactor llama_model_loader (WIP)
wip : remove ggml_ctx from llama_model_loader
wip : merge gguf_file_loader in llama_model_loader
* llama : fix shape prints
* llama : fix Windows build + fix norm_rms_eps key
* llama : throw error on missing KV pairs in model meta data
* llama : improve printing + log meta data
* llama : switch print order of meta data
---------
Co-authored-by: M. Yusuf Sarıgöz <yusufsarigoz@gmail.com>
* gguf : deduplicate (#2629)
* gguf : better type names
* dedup : CPU + Metal is working
* ggml : fix warnings about unused results
* llama.cpp : fix line feed and compiler warning
* llama : fix strncpy warning + note token_to_str does not write null
* llama : restore the original load/save session implementation
Will migrate this to GGUF in the future
* convert-llama-h5-to-gguf.py : support alt ctx param name
* ggml : assert when using ggml_mul with non-F32 src1
* examples : dedup simple
---------
Co-authored-by: klosax <131523366+klosax@users.noreply.github.com>
* gguf.py : merge all files in gguf.py
* convert-new.py : pick #2427 for HF 70B support
* examples/gguf : no need to keep q option for quantization any more
* llama.cpp : print actual model size
* llama.cpp : use ggml_elements()
* convert-new.py : output gguf (#2635)
* convert-new.py : output gguf (WIP)
* convert-new.py : add gguf key-value pairs
* llama : add hparams.ctx_train + no longer print ftype
* convert-new.py : minor fixes
* convert-new.py : vocab-only option should work now
* llama : fix tokenizer to use llama_char_to_byte
* tests : add new ggml-vocab-llama.gguf
* convert-new.py : tensor name mapping
* convert-new.py : add map for skipping tensor serialization
* convert-new.py : convert script now works
* gguf.py : pick some of the refactoring from #2644
* convert-new.py : minor fixes
* convert.py : update to support GGUF output
* Revert "ci : disable CI temporary to not waste energy"
This reverts commit 7e82d25f40
.
* convert.py : n_head_kv optional and .gguf file extension
* convert.py : better always have n_head_kv and default it to n_head
* llama : sync with recent PRs on master
* editorconfig : ignore models folder
ggml-ci
* ci : update ".bin" to ".gguf" extension
ggml-ci
* llama : fix llama_model_loader memory leak
* gptneox : move as a WIP example
* llama : fix lambda capture
ggml-ci
* ggml : fix bug in gguf_set_kv
ggml-ci
* common.h : .bin --> .gguf
* quantize-stats.cpp : .bin --> .gguf
* convert.py : fix HF tensor permuting / unpacking
ggml-ci
* llama.cpp : typo
* llama : throw error if gguf fails to init from file
ggml-ci
* llama : fix tensor name grepping during quantization
ggml-ci
* gguf.py : write tensors in a single pass (#2644)
* gguf : single pass for writing tensors + refactoring writer
* gguf : single pass for writing tensors + refactoring writer
* gguf : single pass for writing tensors + refactoring writer
* gguf : style fixes in simple conversion script
* gguf : refactor gptneox conversion script
* gguf : rename h5 to hf (for HuggingFace)
* gguf : refactor pth to gguf conversion script
* gguf : rm file_type key and method
* gguf.py : fix vertical alignment
* gguf.py : indentation
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* convert-gptneox-hf-to-gguf.py : fixes
* gguf.py : gptneox mapping
* convert-llama-hf-to-gguf.py : fixes
* convert-llama-7b-pth-to-gguf.py : fixes
* ggml.h : reverse GGUF_MAGIC
* gguf.py : reverse GGUF_MAGIC
* test-tokenizer-0.cpp : fix warning
* llama.cpp : print kv general.name
* llama.cpp : get special token kv and linefeed token id
* llama : print number of tensors per type + print arch + style
* tests : update vocab file with new magic
* editorconfig : fix whitespaces
* llama : re-order functions
* llama : remove C++ API + reorganize common source in /common dir
* llama : minor API updates
* llama : avoid hardcoded special tokens
* llama : fix MPI build
ggml-ci
* llama : introduce enum llama_vocab_type + remove hardcoded string constants
* convert-falcon-hf-to-gguf.py : falcon HF --> gguf conversion, not tested
* falcon-main.cpp : falcon inference example
* convert-falcon-hf-to-gguf.py : remove extra kv
* convert-gptneox-hf-to-gguf.py : remove extra kv
* convert-llama-7b-pth-to-gguf.py : remove extra kv
* convert-llama-hf-to-gguf.py : remove extra kv
* gguf.py : fix for falcon 40b
* falcon-main.cpp : fix for falcon 40b
* convert-falcon-hf-to-gguf.py : update ref
* convert-falcon-hf-to-gguf.py : add tensor data layout
* cmpnct_gpt2bpe.hpp : fixes
* falcon-main.cpp : fixes
* gptneox-main.cpp : fixes
* cmpnct_gpt2bpe.hpp : remove non-general stuff
* Update examples/server/README.md
Co-authored-by: slaren <slarengh@gmail.com>
* cmpnct_gpt2bpe.hpp : cleanup
* convert-llama-hf-to-gguf.py : special tokens
* convert-llama-7b-pth-to-gguf.py : special tokens
* convert-permute-debug.py : permute debug print
* convert-permute-debug-master.py : permute debug for master
* convert-permute-debug.py : change permute type of attn_q
* convert.py : 70b model working (change attn_q permute)
* Delete convert-permute-debug-master.py
* Delete convert-permute-debug.py
* convert-llama-hf-to-gguf.py : fix attn_q permute
* gguf.py : fix rope scale kv
* convert-llama-hf-to-gguf.py : rope scale and added tokens
* convert-llama-7b-pth-to-gguf.py : rope scale and added tokens
* llama.cpp : use rope scale kv
* convert-llama-7b-pth-to-gguf.py : rope scale fix
* convert-llama-hf-to-gguf.py : rope scale fix
* py : fix whitespace
* gguf : add Python script to convert GGMLv3 LLaMA models to GGUF (#2682)
* First pass at converting GGMLv3 LLaMA models to GGUF
* Cleanups, better output during conversion
* Fix vocab space conversion logic
* More vocab conversion fixes
* Add description to converted GGUF files
* Improve help text, expand warning
* Allow specifying name and description for output GGUF
* Allow overriding vocab and hyperparams from original model metadata
* Use correct params override var name
* Fix wrong type size for Q8_K
Better handling of original style metadata
* Set default value for gguf add_tensor raw_shape KW arg
* llama : improve token type support (#2668)
* Merge tokenizer fixes into the gguf branch.
* Add test vocabularies
* Adapt convert-new.py (and fix a clang-cl compiler error on windows)
* Improved tokenizer test
But does it work on MacOS?
* Improve token type support
- Added @klosax code to convert.py
- Improved token type support in vocabulary
* Exclude platform dependent tests
* More sentencepiece compatibility by eliminating magic numbers
* Restored accidentally removed comment
* llama : add API for token type
ggml-ci
* tests : use new tokenizer type API (#2692)
* Merge tokenizer fixes into the gguf branch.
* Add test vocabularies
* Adapt convert-new.py (and fix a clang-cl compiler error on windows)
* Improved tokenizer test
But does it work on MacOS?
* Improve token type support
- Added @klosax code to convert.py
- Improved token type support in vocabulary
* Exclude platform dependent tests
* More sentencepiece compatibility by eliminating magic numbers
* Restored accidentally removed comment
* Improve commentary
* Use token type API in test-tokenizer-1.cpp
* py : cosmetics
* readme : add notice about new file format
ggml-ci
---------
Co-authored-by: M. Yusuf Sarıgöz <yusufsarigoz@gmail.com>
Co-authored-by: klosax <131523366+klosax@users.noreply.github.com>
Co-authored-by: goerch <jhr.walter@t-online.de>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
423 lines
16 KiB
C++
423 lines
16 KiB
C++
#include "common.h"
|
|
#include "llama.h"
|
|
#include "build-info.h"
|
|
|
|
#include <cmath>
|
|
#include <ctime>
|
|
#include <sstream>
|
|
#include <cstring>
|
|
|
|
#if defined(_MSC_VER)
|
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
#endif
|
|
|
|
// Converts a vector of raw logits into a probability distribution.
//
// The largest logit is subtracted from every element before exponentiation so
// that expf() never overflows; this does not change the resulting ratios.
// The exponentials are accumulated in double precision and each element is
// then divided by that sum.
//
// Returns a vector of the same length as `logits`. An empty input yields an
// empty output (the original code read logits[0] unconditionally, which is
// undefined behaviour for an empty vector).
std::vector<float> softmax(const std::vector<float>& logits) {
    std::vector<float> probs(logits.size());
    if (logits.empty()) {
        return probs;
    }

    float max_logit = logits[0];
    for (float v : logits) {
        max_logit = std::max(max_logit, v);
    }

    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        // Subtract the maximum logit value from the current logit value for numerical stability
        const float logit = logits[i] - max_logit;
        const float exp_logit = expf(logit);
        sum_exp += exp_logit;
        probs[i] = exp_logit;
    }

    // Normalize so the entries sum to one.
    for (size_t i = 0; i < probs.size(); i++) {
        probs[i] /= sum_exp;
    }
    return probs;
}
|
|
|
|
// Computes the perplexity of params.prompt and prints a running value to stdout.
//
// The prompt is tokenized and cut into consecutive chunks of params.n_ctx
// tokens (at most params.n_chunks of them when that value is non-negative).
// Each chunk is evaluated in batches of params.n_batch tokens with a BOS token
// temporarily substituted at the chunk start. The negative log-likelihood of
// the tokens in the scored part of the window is accumulated across chunks,
// and after every chunk "[i]ppl," is printed, where ppl = exp(nll / count).
//
// Progress and error messages go to stderr. The function returns early,
// without printing a result, if the sizes are unusable, if there is not even
// one complete chunk, or if llama_eval fails.
void perplexity(llama_context * ctx, const gpt_params & params) {
    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`

    // Both values are used as divisors below, so reject non-positive sizes up front.
    if (params.n_ctx <= 0 || params.n_batch <= 0) {
        fprintf(stderr, "%s : n_ctx (%d) and n_batch (%d) must be positive\n", __func__, params.n_ctx, params.n_batch);
        return;
    }

    // BOS tokens will be added for each chunk before eval
    auto tokens = ::llama_tokenize(ctx, params.prompt, true);

    const int n_chunk_max = tokens.size() / params.n_ctx;

    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
    const int n_vocab = llama_n_vocab(ctx);
    const int n_batch = params.n_batch;

    // Without at least one complete chunk there is nothing to score; say so
    // instead of silently printing an empty result line.
    if (n_chunk <= 0) {
        fprintf(stderr, "%s : no complete chunk to evaluate (%zu tokens, n_ctx = %d)\n", __func__, tokens.size(), params.n_ctx);
        return;
    }

    int count = 0;
    double nll = 0.0;

    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);

    for (int i = 0; i < n_chunk; ++i) {
        const int start = i * params.n_ctx;
        const int end   = start + params.n_ctx;

        const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;

        // Logits for every token position of this chunk, n_vocab floats per position.
        std::vector<float> logits;

        const auto t_start = std::chrono::high_resolution_clock::now();

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
            const int batch_size  = std::min(end - batch_start, n_batch);

            // save original token and restore it after eval
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (j == 0) {
                tokens[batch_start] = llama_token_bos(ctx);
            }

            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return;
            }

            // restore the original token in case it was set to BOS
            tokens[batch_start] = token_org;

            const auto batch_logits = llama_get_logits(ctx);
            logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
        }

        const auto t_end = std::chrono::high_resolution_clock::now();

        // After the first chunk, extrapolate its duration to an ETA for the whole run.
        if (i == 0) {
            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
            int total_seconds = (int)(t_total * n_chunk);
            if (total_seconds >= 60*60) {
                fprintf(stderr, "%d hours ", total_seconds / (60*60));
                total_seconds = total_seconds % (60*60);
            }
            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
        }

        // We get the logits for all the tokens in the context window (params.n_ctx)
        // from llama_eval above.  Now, based on https://huggingface.co/docs/transformers/perplexity,
        // calculate the perplexity over the last half of the window (so the model always has
        // some context to predict the token).
        //
        // We rely on the fact that attention in the forward pass only looks at previous
        // tokens here, so the logits returned for each token are an accurate representation
        // of what the model would have predicted at that point.
        //
        // Example, we have a context window of 512, we will compute perplexity for each of the
        // last 256 tokens.  Then, we split the input up into context window size chunks to
        // process the entire prompt.
        for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
            // Calculate probability of next token, given the previous ones.
            const std::vector<float> tok_logits(
                logits.begin() + (j + 0) * n_vocab,
                logits.begin() + (j + 1) * n_vocab);

            const float prob = softmax(tok_logits)[tokens[start + j + 1]];

            nll += -std::log(prob);
            ++count;
        }
        // perplexity is e^(average negative log-likelihood)
        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
        fflush(stdout);
    }
    printf("\n");
}
|
|
|
|
// Evaluates `tokens` in batches of at most n_batch tokens and gathers the
// logits that llama_get_logits() reports after each batch.
//
//   ctx      - context passed to llama_eval / llama_get_logits
//   tokens   - token ids to evaluate, in order
//   n_past   - number of tokens already evaluated in ctx; advanced locally
//              after every batch
//   n_batch  - maximum number of tokens handed to one llama_eval call
//   n_vocab  - number of logits copied per evaluated token
//   n_thread - thread count forwarded to llama_eval
//
// Returns tokens.size() * n_vocab floats on success. If llama_eval reports a
// failure, a message is written to stderr and an empty vector is returned.
std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
                                             int n_vocab, int n_thread) {
    std::vector<float> all_logits;
    all_logits.reserve(tokens.size() * n_vocab);

    // Round up so a trailing partial batch is evaluated as well.
    const size_t batch_count = (tokens.size() + n_batch - 1) / n_batch;

    for (size_t batch_idx = 0; batch_idx < batch_count; ++batch_idx) {
        const size_t offset     = batch_idx * n_batch;
        const size_t batch_size = std::min(tokens.size() - offset, size_t(n_batch));

        const bool eval_failed = llama_eval(ctx, tokens.data() + offset, batch_size, n_past, n_thread) != 0;
        if (eval_failed) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return {};
        }

        const auto batch_logits = llama_get_logits(ctx);
        all_logits.insert(all_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);

        n_past += batch_size;
    }

    return all_logits;
}
|
|
|
|
// Calculates the HellaSwag score (acc_norm) from the tasks in params.prompt
// and prints the running accuracy to stdout after every task.
//
// For each selected task the four endings are scored by the mean log
// probability of their tokens given the context; the task counts as correct
// when the best-scoring ending is the gold ending. The first ending is
// evaluated together with the context, the remaining three are evaluated on
// their own with n_past set to the context length.
//
// Diagnostics go to stderr. The function returns early if the prompt does not
// contain a multiple of 6 lines, if a query does not fit into params.n_ctx,
// or if evaluation fails. Task storage is held in a std::vector so that it is
// released on every one of those paths (the original new[] buffer was only
// freed when the function ran to completion).
void hellaswag_score(llama_context * ctx, const gpt_params & params) {
    // Calculates hellaswag score (acc_norm) from prompt
    //
    // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
    // All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
    //
    // All 10042 tasks should be extracted to keep the results standardized like other implementations.
    //
    // Datafile layout:
    // ['??'] denotes json fields
    // 6 lines per task:
    // ['activity_label'] + ": " +['ctx']  - The first part of the query, the context
    // ['label'] - The index the best common sense ending aka gold ending
    // ['endings'][0] - Endings added to the first part of the query
    // ['endings'][1]
    // ['endings'][2]
    // ['endings'][3]

    std::vector<std::string> prompt_lines;
    std::istringstream strstream(params.prompt);
    std::string line;

    while (std::getline(strstream,line,'\n')) {
        prompt_lines.push_back(line);
    }

    if( prompt_lines.size() % 6 != 0) {
        fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
        return;
    }

    size_t hs_task_count = prompt_lines.size()/6;
    fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);

    // This is needed as usual for LLaMA models
    bool prepend_bos = true;

    // Number of tasks to use when computing the score
    if ( params.hellaswag_tasks < hs_task_count  ) {
        hs_task_count = params.hellaswag_tasks;
    }

    // The tasks should be randomized so the score stabilizes quickly.
    bool randomize_tasks = true;

    // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
    std::mt19937 rng(1);

    // Dataholder for hellaswag tasks
    struct hs_data_t {
        std::string context;
        size_t gold_ending_idx;
        std::string ending[4];
        size_t ending_logprob_count[4];
        double ending_logprob[4];
    };

    fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first")  );

    // Select and read data from prompt lines
    // Owned by a vector so the early returns below cannot leak it.
    std::vector<hs_data_t> hs_data(hs_task_count);
    for (size_t i=0; i < hs_task_count; i++) {
        size_t idx = i;

        // Select a random example of those left in the prompt
        if (randomize_tasks) {
            std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
            idx = dist(rng);
        }

        hs_data[i].context = prompt_lines[idx*6];
        hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
        for (size_t j=0; j < 4; j++) {
            hs_data[i].ending[j] = " " + prompt_lines[idx*6+2+j];
        }

        // Delete the selected random example from the prompt
        if (randomize_tasks) {
            prompt_lines.erase( std::next(prompt_lines.begin(),idx*6)  , std::next(prompt_lines.begin(),idx*6+6) );
        }
    }

    fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
    printf("\ntask\tacc_norm\n");

    double acc = 0.0f;
    const int n_vocab = llama_n_vocab(ctx);

    // Scratch buffer holding the logits of a single token position.
    std::vector<float> tok_logits(n_vocab);

    for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {

        // Tokenize the context to count tokens
        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
        size_t context_size = context_embd.size();

        // Do the 1st ending
        // In this case we include the context when evaluating
        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
        auto query_size = query_embd.size();
        //printf("First query: %d\n",(int)query_size);

        // Stop if query wont fit the ctx window
        if (query_size > (size_t)params.n_ctx) {
            fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
            return;
        }

        // Speedup small evaluations by evaluating atleast 32 tokens
        // (query_size keeps the real length; the padding positions are never scored)
        if (query_size < 32) {
            query_embd.resize(32);
        }

        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
        if (logits.empty()) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return;
        }

        // Distribution predicted after the last context token. It is kept for the
        // whole task because it also scores the first token of endings 1..3.
        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
        const auto first_probs = softmax(tok_logits);

        hs_data[task_idx].ending_logprob_count[0] = 1;
        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);

        // Calculate the logprobs over the ending
        for (size_t j = context_size; j < query_size - 1; j++) {

            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));

            const float prob = softmax(tok_logits)[query_embd[j + 1]];

            hs_data[task_idx].ending_logprob[0] += std::log(prob);
            hs_data[task_idx].ending_logprob_count[0]++;
        }

        // Calculate the mean token logprob for acc_norm
        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];

        // Do the remaining endings
        // For these, we use the bare ending with n_past = context_size
        //
        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {

            // Tokenize the query
            query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
            query_size = query_embd.size();

            // Stop if query wont fit the ctx window
            if (context_size + query_size > (size_t)params.n_ctx) {
                fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
                return;
            }

            // Speedup small evaluations by evaluating atleast 32 tokens
            // No, resizing to 32 is actually slightly slower (at least on CUDA)
            //if (query_size < 32) {
            //    query_embd.resize(32);
            //}

            // Evaluate the query
            logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
            if (logits.empty()) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
                return;
            }

            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);

            // Calculate the logprobs over the ending
            for (size_t j = 0; j < query_size - 1; j++) {
                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));

                const float prob = softmax(tok_logits)[query_embd[j + 1]];

                hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
                hs_data[task_idx].ending_logprob_count[ending_idx]++;
            }

            // Calculate the mean token logprob for acc_norm
            hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];


//            printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
//                task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
        }

        // Find the ending with maximum logprob
        size_t ending_logprob_max_idx = 0;
        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
        for (size_t j = 1; j < 4; j++) {
            if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
                ending_logprob_max_idx = j;
                ending_logprob_max_val =  hs_data[task_idx].ending_logprob[j];
            }
        }

//        printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);

        // If the gold ending got the maximum logprobe add one accuracy point
        if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
            acc += 1.0;
        }

        // Print the accumulated accuracy mean x 100
        printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
        fflush(stdout);
    }

    printf("\n");
}
|
|
|
|
int main(int argc, char ** argv) {
|
|
gpt_params params;
|
|
|
|
params.n_batch = 512;
|
|
if (gpt_params_parse(argc, argv, params) == false) {
|
|
return 1;
|
|
}
|
|
|
|
params.perplexity = true;
|
|
params.n_batch = std::min(params.n_batch, params.n_ctx);
|
|
|
|
if (params.n_ctx > 2048) {
|
|
fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
|
|
"expect poor results\n", __func__, params.n_ctx);
|
|
}
|
|
|
|
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
|
|
|
|
if (params.seed == LLAMA_DEFAULT_SEED) {
|
|
params.seed = time(NULL);
|
|
}
|
|
|
|
fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
|
|
|
|
std::mt19937 rng(params.seed);
|
|
if (params.random_prompt) {
|
|
params.prompt = gpt_random_prompt(rng);
|
|
}
|
|
|
|
llama_backend_init(params.numa);
|
|
|
|
llama_model * model;
|
|
llama_context * ctx;
|
|
|
|
// load the model and apply lora adapter, if any
|
|
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
|
if (model == NULL) {
|
|
fprintf(stderr, "%s: error: unable to load model\n", __func__);
|
|
return 1;
|
|
}
|
|
|
|
// print system information
|
|
{
|
|
fprintf(stderr, "\n");
|
|
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
|
|
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
|
}
|
|
|
|
if (params.hellaswag) {
|
|
hellaswag_score(ctx, params);
|
|
} else {
|
|
perplexity(ctx, params);
|
|
}
|
|
|
|
llama_print_timings(ctx);
|
|
llama_free(ctx);
|
|
llama_free_model(model);
|
|
|
|
llama_backend_free();
|
|
|
|
return 0;
|
|
}
|