sampling : refactor init to use llama_sampling_params (#3696)

* sampling : refactor init to use llama_sampling_params

* llama : combine repetition, frequency and presence penalties in 1 call

* examples : remove embd-input and gptneox-wip

* sampling : rename penalty params + reduce size of "prev" vector

* sampling : add llama_sampling_print helper

* sampling : hide prev behind API and apply #3661

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-20 21:07:23 +03:00 committed by GitHub
parent 8cf19d60dc
commit d1031cf49c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 365 additions and 4502 deletions

View File

@ -1,7 +1,7 @@
# Define the default target now so that it is always the first target # Define the default target now so that it is always the first target
BUILD_TARGETS = \ BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \ main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server embd-input-test gguf llama-bench llava baby-llama beam-search \ simple batched batched-bench save-load-state server gguf llama-bench llava baby-llama beam-search \
speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o
# Binaries only useful for tests # Binaries only useful for tests
@ -608,13 +608,6 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput
gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS) gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

View File

@ -962,7 +962,6 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /
- [main](./examples/main/README.md) - [main](./examples/main/README.md)
- [server](./examples/server/README.md) - [server](./examples/server/README.md)
- [embd-input](./examples/embd-input/README.md)
- [jeopardy](./examples/jeopardy/README.md) - [jeopardy](./examples/jeopardy/README.md)
- [BLIS](./docs/BLIS.md) - [BLIS](./docs/BLIS.md)
- [Performance troubleshooting](./docs/token_generation_performance_tips.md) - [Performance troubleshooting](./docs/token_generation_performance_tips.md)

View File

@ -107,7 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::string arg; std::string arg;
gpt_params default_params; gpt_params default_params;
const std::string arg_prefix = "--"; const std::string arg_prefix = "--";
llama_sampling_params & sparams = params.sampling_params; llama_sampling_params & sparams = params.sparams;
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
arg = argv[i]; arg = argv[i];
@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.repeat_last_n = std::stoi(argv[i]); sparams.penalty_last_n = std::stoi(argv[i]);
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
} else if (arg == "--repeat-penalty") { } else if (arg == "--repeat-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.repeat_penalty = std::stof(argv[i]); sparams.penalty_repeat = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") { } else if (arg == "--frequency-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.frequency_penalty = std::stof(argv[i]); sparams.penalty_freq = std::stof(argv[i]);
} else if (arg == "--presence-penalty") { } else if (arg == "--presence-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.presence_penalty = std::stof(argv[i]); sparams.penalty_present = std::stof(argv[i]);
} else if (arg == "--mirostat") { } else if (arg == "--mirostat") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -572,7 +573,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.grammar = argv[i]; sparams.grammar = argv[i];
} else if (arg == "--grammar-file") { } else if (arg == "--grammar-file") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -587,7 +588,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::copy( std::copy(
std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(), std::istreambuf_iterator<char>(),
std::back_inserter(params.grammar) std::back_inserter(sparams.grammar)
); );
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters // Parse args for logging parameters
@ -640,7 +641,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} }
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
const llama_sampling_params & sparams = params.sampling_params; const llama_sampling_params & sparams = params.sparams;
printf("usage: %s [options]\n", argv[0]); printf("usage: %s [options]\n", argv[0]);
printf("\n"); printf("\n");
@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty); printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty); printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty); printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
printf(" --mirostat N use Mirostat sampling.\n"); printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
@ -878,7 +879,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
} }
if (params.ignore_eos) { if (params.ignore_eos) {
params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY; params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
} }
{ {
@ -1123,28 +1124,28 @@ std::string get_sortable_timestamp() {
void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) { const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
const llama_sampling_params & sparams = params.sampling_params; const llama_sampling_params & sparams = params.sparams;
fprintf(stream, "build_commit: %s\n", BUILD_COMMIT); fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
fprintf(stream, "build_number: %d\n", BUILD_NUMBER); fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false");
fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false");
fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false"); fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false"); fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false"); fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");
#ifdef NDEBUG #ifdef NDEBUG
fprintf(stream, "debug: false\n"); fprintf(stream, "debug: false\n");
@ -1178,8 +1179,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false"); fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty); fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty); fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
fprintf(stream, "reverse_prompt:\n"); fprintf(stream, "reverse_prompt:\n");
for (std::string ap : params.antiprompt) { for (std::string ap : params.antiprompt) {

View File

@ -56,7 +56,7 @@ struct gpt_params {
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
// // sampling parameters // // sampling parameters
struct llama_sampling_params sampling_params; struct llama_sampling_params sparams;
std::string model = "models/7B/ggml-model-f16.gguf"; // model path std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_draft = ""; // draft model for speculative decoding std::string model_draft = ""; // draft model for speculative decoding
@ -66,7 +66,6 @@ struct gpt_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with std::string input_suffix = ""; // string to suffix user inputs with
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files std::string logdir = ""; // directory in which to save YAML log files

View File

@ -1,9 +1,9 @@
#include "sampling.h" #include "sampling.h"
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) { struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context(); struct llama_sampling_context * result = new llama_sampling_context();
result->params = params.sampling_params; result->params = params;
result->grammar = nullptr; result->grammar = nullptr;
// if there is a grammar, parse it // if there is a grammar, parse it
@ -23,7 +23,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_params & pa
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
} }
result->prev.resize(params.n_ctx); result->prev.resize(params.n_prev);
return result; return result;
} }
@ -66,25 +66,56 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
dst->prev = src->prev; dst->prev = src->prev;
} }
llama_token llama_sampling_last(llama_sampling_context * ctx) {
return ctx->prev.back();
}
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
const int size = ctx_sampling->prev.size();
n = std::min(n, size);
std::string result;
for (int i = size - n; i < size; i++) {
result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
}
return result;
}
std::string llama_sampling_print(const llama_sampling_params & params) {
char result[1024];
snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau);
return std::string(result);
}
llama_token llama_sampling_sample( llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
struct llama_context * ctx_cfg, struct llama_context * ctx_cfg,
const int idx) { const int idx) {
const int n_ctx = llama_n_ctx(ctx_main);
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const llama_sampling_params & params = ctx_sampling->params; const llama_sampling_params & params = ctx_sampling->params;
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
const float temp = params.temp; const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p; const float top_p = params.top_p;
const float tfs_z = params.tfs_z; const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p; const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float repeat_penalty = params.repeat_penalty; const float penalty_repeat = params.penalty_repeat;
const float alpha_presence = params.presence_penalty; const float penalty_freq = params.penalty_freq;
const float alpha_frequency = params.frequency_penalty; const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat; const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau; const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta; const float mirostat_eta = params.mirostat_eta;
@ -97,7 +128,7 @@ llama_token llama_sampling_sample(
float * logits = llama_get_logits_ith(ctx_main, idx); float * logits = llama_get_logits_ith(ctx_main, idx);
// Apply params.logit_bias map // apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second; logits[it->first] += it->second;
} }
@ -117,14 +148,10 @@ llama_token llama_sampling_sample(
// apply penalties // apply penalties
if (!prev.empty()) { if (!prev.empty()) {
const float nl_logit = logits[llama_token_nl(ctx_main)]; const float nl_logit = logits[llama_token_nl(ctx_main)];
const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx_main, &cur_p, llama_sample_repetition_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat, prev.data() + prev.size() - penalty_last_n,
last_n_repeat, repeat_penalty); penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) { if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) { for (size_t idx = 0; idx < cur_p.size; idx++) {
@ -141,7 +168,7 @@ llama_token llama_sampling_sample(
} }
if (temp <= 0) { if (temp <= 0) {
// Greedy sampling // greedy sampling
id = llama_sample_token_greedy(ctx_main, &cur_p); id = llama_sample_token_greedy(ctx_main, &cur_p);
} else { } else {
if (mirostat == 1) { if (mirostat == 1) {
@ -152,8 +179,9 @@ llama_token llama_sampling_sample(
llama_sample_temp(ctx_main, &cur_p, temp); llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else { } else {
// Temperature sampling // temperature sampling
size_t min_keep = std::max(1, params.n_probs); size_t min_keep = std::max(1, params.n_probs);
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
@ -183,11 +211,12 @@ llama_token llama_sampling_sample(
void llama_sampling_accept( void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
llama_token id) { llama_token id,
bool apply_grammar) {
ctx_sampling->prev.erase(ctx_sampling->prev.begin()); ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(id); ctx_sampling->prev.push_back(id);
if (ctx_sampling->grammar != NULL) { if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
} }
} }

View File

@ -10,30 +10,30 @@
// sampling parameters // sampling parameters
typedef struct llama_sampling_params { typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t top_k = 40; // <= 0 to use vocab size int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled float top_p = 0.95f; // 1.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled
float repeat_penalty = 1.10f; // 1.0 = disabled int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) float penalty_repeat = 1.10f; // 1.0 = disabled
float frequency_penalty = 0.00f; // 0.0 = disabled float penalty_freq = 0.00f; // 0.0 = disabled
float presence_penalty = 0.00f; // 0.0 = disabled float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = true; // consider newlines as a repeatable token bool penalize_nl = true; // consider newlines as a repeatable token
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. std::string grammar; // optional BNF-like grammar to constrain sampling
// Classifier-Free Guidance // Classifier-Free Guidance
// https://arxiv.org/abs/2306.17806 // https://arxiv.org/abs/2306.17806
std::string cfg_negative_prompt; // string to help guidance std::string cfg_negative_prompt; // string to help guidance
float cfg_scale = 1.f; // How strong is guidance float cfg_scale = 1.f; // how strong is guidance
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
} llama_sampling_params; } llama_sampling_params;
// general sampler context // general sampler context
@ -58,7 +58,7 @@ struct llama_sampling_context {
#include "common.h" #include "common.h"
// Create a new sampling context instance. // Create a new sampling context instance.
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params); struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
void llama_sampling_free(struct llama_sampling_context * ctx); void llama_sampling_free(struct llama_sampling_context * ctx);
@ -70,6 +70,15 @@ void llama_sampling_reset(llama_sampling_context * ctx);
// Copy the sampler context // Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
// Get the last sampled token
llama_token llama_sampling_last(llama_sampling_context * ctx);
// Get a string representation of the last sampled tokens
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
// Print sampling parameters into a string
std::string llama_sampling_print(const llama_sampling_params & params);
// this is a common sampling function used across the examples for convenience // this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function // it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call // Note: When using multiple sequences, it is the caller's responsibility to call
@ -96,4 +105,5 @@ llama_token llama_sampling_sample(
void llama_sampling_accept( void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling, struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main, struct llama_context * ctx_main,
llama_token id); llama_token id,
bool apply_grammar);

View File

@ -12,26 +12,26 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if (EMSCRIPTEN) if (EMSCRIPTEN)
else() else()
add_subdirectory(main)
add_subdirectory(quantize)
add_subdirectory(quantize-stats)
add_subdirectory(perplexity)
add_subdirectory(embedding)
add_subdirectory(save-load-state)
add_subdirectory(benchmark)
add_subdirectory(baby-llama) add_subdirectory(baby-llama)
add_subdirectory(train-text-from-scratch)
add_subdirectory(finetune)
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(simple)
add_subdirectory(batched) add_subdirectory(batched)
add_subdirectory(batched-bench) add_subdirectory(batched-bench)
add_subdirectory(speculative)
add_subdirectory(parallel)
add_subdirectory(embd-input)
add_subdirectory(llava)
add_subdirectory(llama-bench)
add_subdirectory(beam-search) add_subdirectory(beam-search)
add_subdirectory(benchmark)
add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(embedding)
add_subdirectory(finetune)
add_subdirectory(infill)
add_subdirectory(llama-bench)
add_subdirectory(llava)
add_subdirectory(main)
add_subdirectory(parallel)
add_subdirectory(perplexity)
add_subdirectory(quantize)
add_subdirectory(quantize-stats)
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(speculative)
add_subdirectory(train-text-from-scratch)
if (LLAMA_METAL) if (LLAMA_METAL)
add_subdirectory(metal) add_subdirectory(metal)
endif() endif()

View File

@ -1,4 +0,0 @@
PandaGPT
MiniGPT-4
*.pth

View File

@ -1,17 +0,0 @@
set(TARGET embdinput)
add_library(${TARGET} embd-input-lib.cpp embd-input.h)
install(TARGETS ${TARGET} LIBRARY)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()
set(TARGET embd-input-test)
add_executable(${TARGET} embd-input-test.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama embdinput ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

View File

@ -1,63 +0,0 @@
### Examples for input embedding directly
## Requirement
build `libembdinput.so`
run the following comman in main dir (../../).
```
make
```
## [LLaVA](https://github.com/haotian-liu/LLaVA/) example (llava.py)
1. Obtian LLaVA model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/).
2. Convert it to ggml format.
3. `llava_projection.pth` is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin).
```
import torch
bin_path = "../LLaVA-13b-delta-v1-1/pytorch_model-00003-of-00003.bin"
pth_path = "./examples/embd-input/llava_projection.pth"
dic = torch.load(bin_path)
used_key = ["model.mm_projector.weight","model.mm_projector.bias"]
torch.save({k: dic[k] for k in used_key}, pth_path)
```
4. Check the path of LLaVA model and `llava_projection.pth` in `llava.py`.
## [PandaGPT](https://github.com/yxuansu/PandaGPT) example (panda_gpt.py)
1. Obtian PandaGPT lora model from https://github.com/yxuansu/PandaGPT. Rename the file to `adapter_model.bin`. Use [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) to convert it to ggml format.
The `adapter_config.json` is
```
{
"peft_type": "LORA",
"fan_in_fan_out": false,
"bias": null,
"modules_to_save": null,
"r": 32,
"lora_alpha": 32,
"lora_dropout": 0.1,
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]
}
```
2. Papare the `vicuna` v0 model.
3. Obtain the [ImageBind](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) model.
4. Clone the PandaGPT source.
```
git clone https://github.com/yxuansu/PandaGPT
```
5. Install the requirement of PandaGPT.
6. Check the path of PandaGPT source, ImageBind model, lora model and vicuna model in panda_gpt.py.
## [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4/) example (minigpt4.py)
1. Obtain MiniGPT-4 model from https://github.com/Vision-CAIR/MiniGPT-4/ and put it in `embd-input`.
2. Clone the MiniGPT-4 source.
```
git clone https://github.com/Vision-CAIR/MiniGPT-4/
```
3. Install the requirement of PandaGPT.
4. Papare the `vicuna` v0 model.
5. Check the path of MiniGPT-4 source, MiniGPT-4 model and vicuna model in `minigpt4.py`.

View File

@ -1,221 +0,0 @@
#include "build-info.h"
#include "common.h"
#include "embd-input.h"
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
static llama_context ** g_ctx;
extern "C" {
struct MyModel* create_mymodel(int argc, char ** argv) {
gpt_params params;
if (!gpt_params_parse(argc, argv, params)) {
return nullptr;
}
print_build_info();
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = uint32_t(time(NULL));
}
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
llama_backend_init(params.numa);
llama_model * model;
llama_context * ctx;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return nullptr;
}
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
struct MyModel * ret = new MyModel();
ret->ctx = ctx;
ret->params = params;
ret->n_past = 0;
// printf("ctx: %d\n", ret->ctx);
return ret;
}
void free_mymodel(struct MyModel * mymodel) {
llama_context * ctx = mymodel->ctx;
llama_print_timings(ctx);
llama_free(ctx);
delete mymodel;
}
bool eval_float(void * model, float * input, int N){
MyModel * mymodel = (MyModel*)model;
llama_context * ctx = mymodel->ctx;
gpt_params params = mymodel->params;
int n_emb = llama_n_embd(llama_get_model(ctx));
int n_past = mymodel->n_past;
int n_batch = N; // params.n_batch;
for (int i = 0; i < (int) N; i += n_batch) {
int n_eval = (int) N - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}
mymodel->n_past = n_past;
return true;
}
bool eval_tokens(void * model, std::vector<llama_token> tokens) {
MyModel * mymodel = (MyModel* )model;
llama_context * ctx;
ctx = mymodel->ctx;
gpt_params params = mymodel->params;
int n_past = mymodel->n_past;
for (int i = 0; i < (int) tokens.size(); i += params.n_batch) {
int n_eval = (int) tokens.size() - i;
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}
mymodel->n_past = n_past;
return true;
}
bool eval_id(struct MyModel* mymodel, int id) {
std::vector<llama_token> tokens;
tokens.push_back(id);
return eval_tokens(mymodel, tokens);
}
bool eval_string(struct MyModel * mymodel,const char* str){
llama_context * ctx = mymodel->ctx;
std::string str2 = str;
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, str2, true);
eval_tokens(mymodel, embd_inp);
return true;
}
llama_token sampling_id(struct MyModel* mymodel) {
llama_context* ctx = mymodel->ctx;
gpt_params params = mymodel->params;
llama_sampling_params & sparams = params.sampling_params;
// int n_ctx = llama_n_ctx(ctx);
// out of user input, sample next token
const float temp = sparams.temp;
const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : sparams.top_k;
const float top_p = sparams.top_p;
const float tfs_z = sparams.tfs_z;
const float typical_p = sparams.typical_p;
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
// const float repeat_penalty = params.repeat_penalty;
// const float alpha_presence = params.presence_penalty;
// const float alpha_frequency = params.frequency_penalty;
const int mirostat = sparams.mirostat;
const float mirostat_tau = sparams.mirostat_tau;
const float mirostat_eta = sparams.mirostat_eta;
// const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Apply params.logit_bias map
for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// TODO: Apply penalties
// float nl_logit = logits[llama_token_nl(ctx)];
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
// llama_sample_repetition_penalty(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, repeat_penalty);
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, alpha_frequency, alpha_presence);
// if (!penalize_nl) {
// logits[llama_token_nl(ctx)] = nl_logit;
// }
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temp(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temp(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_temp(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
}
}
}
return id;
}
const char * sampling(struct MyModel * mymodel) {
llama_context * ctx = mymodel->ctx;
int id = sampling_id(mymodel);
static std::string ret;
if (id == llama_token_eos(ctx)) {
ret = "</s>";
} else {
ret = llama_token_to_piece(ctx, id);
}
eval_id(mymodel, id);
return ret.c_str();
}
}

View File

@ -1,35 +0,0 @@
#include "embd-input.h"
#include <stdlib.h>
#include <random>
#include <string.h>
int main(int argc, char** argv) {
auto mymodel = create_mymodel(argc, argv);
int N = 10;
int max_tgt_len = 500;
int n_embd = llama_n_embd(llama_get_model(mymodel->ctx));
// add random float embd to test evaluation
float * data = new float[N*n_embd];
std::default_random_engine e;
std::uniform_real_distribution<float> u(0,1);
for (int i=0;i<N*n_embd;i++) {
data[i] = u(e);
}
eval_string(mymodel, "user: what is the color of the flag of UN?");
eval_float(mymodel, data, N);
eval_string(mymodel, "assistant:");
eval_string(mymodel, mymodel->params.prompt.c_str());
const char* tmp;
for (int i=0; i<max_tgt_len; i++) {
tmp = sampling(mymodel);
if (strcmp(tmp, "</s>")==0) break;
printf("%s", tmp);
fflush(stdout);
}
printf("\n");
free_mymodel(mymodel);
return 0;
}

View File

@ -1,27 +0,0 @@
#ifndef _EMBD_INPUT_H_
#define _EMBD_INPUT_H_ 1
#include "common.h"
#include "llama.h"
extern "C" {
typedef struct MyModel {
llama_context* ctx;
gpt_params params;
int n_past = 0;
} MyModel;
struct MyModel* create_mymodel(int argc, char ** argv);
bool eval_float(void* model, float* input, int N);
bool eval_tokens(void* model, std::vector<llama_token> tokens);
bool eval_id(struct MyModel* mymodel, int id);
bool eval_string(struct MyModel* mymodel, const char* str);
const char * sampling(struct MyModel* mymodel);
llama_token sampling_id(struct MyModel* mymodel);
void free_mymodel(struct MyModel* mymodel);
}
#endif

View File

@ -1,72 +0,0 @@
#!/usr/bin/env python3
import ctypes
from ctypes import cdll, c_char_p, c_void_p, POINTER, c_float, c_int
import numpy as np
import os
libc = cdll.LoadLibrary("./libembdinput.so")
libc.sampling.restype=c_char_p
libc.create_mymodel.restype=c_void_p
libc.eval_string.argtypes=[c_void_p, c_char_p]
libc.sampling.argtypes=[c_void_p]
libc.eval_float.argtypes=[c_void_p, POINTER(c_float), c_int]
class MyModel:
def __init__(self, args):
argc = len(args)
c_str = [c_char_p(i.encode()) for i in args]
args_c = (c_char_p * argc)(*c_str)
self.model = c_void_p(libc.create_mymodel(argc, args_c))
self.max_tgt_len = 512
self.print_string_eval = True
def __del__(self):
libc.free_mymodel(self.model)
def eval_float(self, x):
libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[1])
def eval_string(self, x):
libc.eval_string(self.model, x.encode()) # c_char_p(x.encode()))
if self.print_string_eval:
print(x)
def eval_token(self, x):
libc.eval_id(self.model, x)
def sampling(self):
s = libc.sampling(self.model)
return s
def stream_generate(self, end="</s>"):
ret = b""
end = end.encode()
for _ in range(self.max_tgt_len):
tmp = self.sampling()
ret += tmp
yield tmp
if ret.endswith(end):
break
def generate_with_print(self, end="</s>"):
ret = b""
for i in self.stream_generate(end=end):
ret += i
print(i.decode(errors="replace"), end="", flush=True)
print("")
return ret.decode(errors="replace")
def generate(self, end="</s>"):
text = b"".join(self.stream_generate(end=end))
return text.decode(errors="replace")
if __name__ == "__main__":
model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
model.eval_string("""user: what is the color of the flag of UN?""")
x = np.random.random((5120,10))# , dtype=np.float32)
model.eval_float(x)
model.eval_string("""assistant:""")
for i in model.generate():
print(i.decode(errors="replace"), end="", flush=True)

View File

@ -1,71 +0,0 @@
#!/usr/bin/env python3
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel
import numpy as np
from torch import nn
import torch
from transformers import CLIPVisionModel, CLIPImageProcessor
from PIL import Image
# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'
vision_tower = "openai/clip-vit-large-patch14"
select_hidden_state_layer = -2
# (vision_config.image_size // vision_config.patch_size) ** 2
image_token_len = (224//14)**2
class Llava:
def __init__(self, args):
self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
self.mm_projector = nn.Linear(1024, 5120)
self.model = MyModel(["main", *args])
def load_projection(self, path):
state = torch.load(path)
self.mm_projector.load_state_dict({
"weight": state["model.mm_projector.weight"],
"bias": state["model.mm_projector.bias"]})
def chat(self, question):
self.model.eval_string("user: ")
self.model.eval_string(question)
self.model.eval_string("\nassistant: ")
return self.model.generate_with_print()
def chat_with_image(self, image, question):
with torch.no_grad():
embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
image_feature = select_hidden_state[:, 1:]
embd_image = self.mm_projector(image_feature)
embd_image = embd_image.cpu().numpy()[0]
self.model.eval_string("user: ")
self.model.eval_token(32003-2) # im_start
self.model.eval_float(embd_image.T)
for i in range(image_token_len-embd_image.shape[0]):
self.model.eval_token(32003-3) # im_patch
self.model.eval_token(32003-1) # im_end
self.model.eval_string(question)
self.model.eval_string("\nassistant: ")
return self.model.generate_with_print()
if __name__=="__main__":
# model form liuhaotian/LLaVA-13b-delta-v1-1
a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
# Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
# Also here can use pytorch_model-00003-of-00003.bin directly.
a.load_projection(os.path.join(
os.path.dirname(__file__) ,
"llava_projection.pth"))
respose = a.chat_with_image(
Image.open("./media/llama1-logo.png").convert('RGB'),
"what is the text in the picture?")
respose
a.chat("what is the color of it?")

View File

@ -1,129 +0,0 @@
#!/usr/bin/env python3
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel
import numpy as np
from torch import nn
import torch
from PIL import Image
minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
sys.path.insert(0, minigpt4_path)
from minigpt4.models.blip2 import Blip2Base
from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
class MiniGPT4(Blip2Base):
"""
MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4
"""
def __init__(self,
args,
vit_model="eva_clip_g",
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp32",
freeze_vit=True,
freeze_qformer=True,
num_query_token=32,
llama_model="",
prompt_path="",
prompt_template="",
max_txt_len=32,
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0
):
super().__init__()
self.img_size = img_size
self.low_resource = low_resource
self.preprocessor = Blip2ImageEvalProcessor(img_size)
print('Loading VIT')
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
)
print('Loading VIT Done')
print('Loading Q-Former')
self.Qformer, self.query_tokens = self.init_Qformer(
num_query_token, self.visual_encoder.num_features
)
self.Qformer.cls = None
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
for layer in self.Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
self.load_from_pretrained(url_or_filename=q_former_model)
print('Loading Q-Former Done')
self.llama_proj = nn.Linear(
self.Qformer.config.hidden_size, 5120 # self.llama_model.config.hidden_size
)
self.max_txt_len = max_txt_len
self.end_sym = end_sym
self.model = MyModel(["main", *args])
# system prompt
self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
"You will be able to see the image once I provide it to you. Please answer my questions."
"###")
def encode_img(self, image):
image = self.preprocessor(image)
image = image.unsqueeze(0)
device = image.device
if self.low_resource:
self.vit_to_cpu()
image = image.to("cpu")
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
inputs_llama = self.llama_proj(query_output.last_hidden_state)
# atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
return inputs_llama
def load_projection(self, path):
state = torch.load(path)["model"]
self.llama_proj.load_state_dict({
"weight": state["llama_proj.weight"],
"bias": state["llama_proj.bias"]})
def chat(self, question):
self.model.eval_string("Human: ")
self.model.eval_string(question)
self.model.eval_string("\n### Assistant:")
return self.model.generate_with_print(end="###")
def chat_with_image(self, image, question):
with torch.no_grad():
embd_image = self.encode_img(image)
embd_image = embd_image.cpu().numpy()[0]
self.model.eval_string("Human: <Img>")
self.model.eval_float(embd_image.T)
self.model.eval_string("</Img> ")
self.model.eval_string(question)
self.model.eval_string("\n### Assistant:")
return self.model.generate_with_print(end="###")
if __name__=="__main__":
a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
a.load_projection(os.path.join(
os.path.dirname(__file__) ,
"pretrained_minigpt4.pth"))
respose = a.chat_with_image(
Image.open("./media/llama1-logo.png").convert('RGB'),
"what is the text in the picture?")
a.chat("what is the color of it?")

View File

@ -1,99 +0,0 @@
#!/usr/bin/env python3
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from embd_input import MyModel
import numpy as np
from torch import nn
import torch
# use PandaGPT path
panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT")
imagebind_ckpt_path = "./models/panda_gpt/"
sys.path.insert(0, os.path.join(panda_gpt_path,"code","model"))
from ImageBind.models import imagebind_model
from ImageBind import data
ModalityType = imagebind_model.ModalityType
max_tgt_len = 400
class PandaGPT:
def __init__(self, args):
self.visual_encoder,_ = imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)
self.visual_encoder.eval()
self.llama_proj = nn.Linear(1024, 5120) # self.visual_hidden_size, 5120)
self.max_tgt_len = max_tgt_len
self.model = MyModel(["main", *args])
self.generated_text = ""
self.device = "cpu"
def load_projection(self, path):
state = torch.load(path, map_location="cpu")
self.llama_proj.load_state_dict({
"weight": state["llama_proj.weight"],
"bias": state["llama_proj.bias"]})
def eval_inputs(self, inputs):
self.model.eval_string("<Img>")
embds = self.extract_multimoal_feature(inputs)
for i in embds:
self.model.eval_float(i.T)
self.model.eval_string("</Img> ")
def chat(self, question):
return self.chat_with_image(None, question)
def chat_with_image(self, inputs, question):
if self.generated_text == "":
self.model.eval_string("###")
self.model.eval_string(" Human: ")
if inputs:
self.eval_inputs(inputs)
self.model.eval_string(question)
self.model.eval_string("\n### Assistant:")
ret = self.model.generate_with_print(end="###")
self.generated_text += ret
return ret
def extract_multimoal_feature(self, inputs):
features = []
for key in ["image", "audio", "video", "thermal"]:
if key + "_paths" in inputs:
embeds = self.encode_data(key, inputs[key+"_paths"])
features.append(embeds)
return features
def encode_data(self, data_type, data_paths):
type_map = {
"image": ModalityType.VISION,
"audio": ModalityType.AUDIO,
"video": ModalityType.VISION,
"thermal": ModalityType.THERMAL,
}
load_map = {
"image": data.load_and_transform_vision_data,
"audio": data.load_and_transform_audio_data,
"video": data.load_and_transform_video_data,
"thermal": data.load_and_transform_thermal_data
}
load_function = load_map[data_type]
key = type_map[data_type]
inputs = {key: load_function(data_paths, self.device)}
with torch.no_grad():
embeddings = self.visual_encoder(inputs)
embeds = embeddings[key]
embeds = self.llama_proj(embeds).cpu().numpy()
return embeds
if __name__=="__main__":
a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"])
a.load_projection("./models/panda_gpt/adapter_model.bin")
a.chat_with_image(
{"image_paths": ["./media/llama1-logo.png"]},
"what is the text in the picture? 'llama' or 'lambda'?")
a.chat("what is the color of it?")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -4,5 +4,5 @@ install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11) target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO) if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO) add_dependencies(${TARGET} BUILD_INFO)
endif() endif()

View File

@ -39,8 +39,8 @@ static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens; static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss; static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens; static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool is_interacting = false;
static void write_logfile( static void write_logfile(
const llama_context * ctx, const gpt_params & params, const llama_model * model, const llama_context * ctx, const gpt_params & params, const llama_model * model,
@ -104,7 +104,7 @@ static void sigint_handler(int signo) {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
llama_sampling_params & sparams = params.sampling_params; llama_sampling_params & sparams = params.sparams;
g_params = &params; g_params = &params;
if (!gpt_params_parse(argc, argv, params)) { if (!gpt_params_parse(argc, argv, params)) {
@ -358,36 +358,10 @@ int main(int argc, char ** argv) {
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
} }
} }
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n"); LOG_TEE("\n\n");
struct llama_grammar * grammar = NULL;
grammar_parser::parse_state parsed_grammar;
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
return 1;
}
LOG_TEE("%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, parsed_grammar);
LOG_TEE("\n");
{
auto it = sparams.logit_bias.find(llama_token_eos(ctx));
if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
LOG_TEE("\n##### Infill mode #####\n\n"); LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) { if (params.infill) {
printf("\n************\n"); printf("\n************\n");
@ -430,7 +404,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd; std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance; std::vector<llama_token> embd_guidance;
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params); struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
while (n_remain != 0 || params.interactive) { while (n_remain != 0 || params.interactive) {
// predict // predict
@ -549,7 +523,7 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
llama_sampling_accept(ctx_sampling, ctx, id); llama_sampling_accept(ctx_sampling, ctx, id, true);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
@ -567,8 +541,11 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) { while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]); embd.push_back(embd_inp[n_consumed]);
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(embd_inp[n_consumed]); // push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
++n_consumed; ++n_consumed;
if ((int) embd.size() >= params.n_batch) { if ((int) embd.size() >= params.n_batch) {
break; break;
@ -600,7 +577,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) { if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode // deal with eot token in infill mode
if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){ if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) { if(is_interacting && !params.interactive_first) {
// print an eot token // print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str()); printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@ -617,7 +594,7 @@ int main(int argc, char ** argv) {
buffer += line; buffer += line;
} while (another_line); } while (another_line);
// check if we got an empty line, if so we use the old input // check if we got an empty line, if so we use the old input
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
params.input_prefix = buffer; params.input_prefix = buffer;
} }
buffer.clear(); buffer.clear();
@ -627,7 +604,7 @@ int main(int argc, char ** argv) {
buffer += line; buffer += line;
} while (another_line); } while (another_line);
// check if we got an empty line // check if we got an empty line
if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
params.input_suffix = buffer; params.input_suffix = buffer;
} }
buffer.clear(); buffer.clear();
@ -640,7 +617,7 @@ int main(int argc, char ** argv) {
process_escapes(params.input_suffix); process_escapes(params.input_suffix);
} }
suff_rm_leading_spc = params.escape; suff_rm_leading_spc = params.escape;
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1); params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false; suff_rm_leading_spc = false;
} }
@ -667,7 +644,7 @@ int main(int argc, char ** argv) {
is_interacting = false; is_interacting = false;
} }
// deal with end of text token in interactive mode // deal with end of text token in interactive mode
else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) { else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
LOG("found EOS token\n"); LOG("found EOS token\n");
if (params.interactive) { if (params.interactive) {
@ -740,15 +717,7 @@ int main(int argc, char ** argv) {
if (n_past > 0) { if (n_past > 0) {
if (is_interacting) { if (is_interacting) {
// reset grammar state if we're restarting generation llama_sampling_reset(ctx_sampling);
if (grammar != NULL) {
llama_grammar_free(grammar);
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(),
parsed_grammar.symbol_ids.at("root"));
}
} }
is_interacting = false; is_interacting = false;
} }
@ -778,9 +747,7 @@ int main(int argc, char ** argv) {
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
if (grammar != NULL) { llama_sampling_free(ctx_sampling);
llama_grammar_free(grammar);
}
llama_backend_free(); llama_backend_free();
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS

View File

@ -58,28 +58,30 @@ inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n
// TODO: use common/sampling.h // TODO: use common/sampling.h
inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) { inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
// out of user input, sample next token auto & sparams = params.sparams;
const float temp = params.sampling_params.temp;
const int32_t top_k = params.sampling_params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.sampling_params.top_k; // out of user input, sample next token
const float top_p = params.sampling_params.top_p; const float temp = sparams.temp;
const float tfs_z = params.sampling_params.tfs_z; const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : sparams.top_k;
const float typical_p = params.sampling_params.typical_p; const float top_p = sparams.top_p;
// const int32_t repeat_last_n = params.sampling_params.repeat_last_n < 0 ? n_ctx : params.sampling_params.repeat_last_n; const float tfs_z = sparams.tfs_z;
// const float repeat_penalty = params.sampling_params.repeat_penalty; const float typical_p = sparams.typical_p;
// const float alpha_presence = params.sampling_params.presence_penalty; // const int32_t repeat_last_n = sparams.repeat_last_n < 0 ? n_ctx : sparams.repeat_last_n;
// const float alpha_frequency = params.sampling_params.frequency_penalty; // const float repeat_penalty = sparams.repeat_penalty;
const int mirostat = params.sampling_params.mirostat; // const float alpha_presence = sparams.presence_penalty;
const float mirostat_tau = params.sampling_params.mirostat_tau; // const float alpha_frequency = sparams.frequency_penalty;
const float mirostat_eta = params.sampling_params.mirostat_eta; const int mirostat = sparams.mirostat;
// const bool penalize_nl = params.sampling_params.penalize_nl; const float mirostat_tau = sparams.mirostat_tau;
const float mirostat_eta = sparams.mirostat_eta;
// const bool penalize_nl = sparams.penalize_nl;
llama_token id = 0; llama_token id = 0;
{ {
auto logits = llama_get_logits(ctx_llama); auto logits = llama_get_logits(ctx_llama);
auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama)); auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
// Apply params.logit_bias map // Apply params.logit_bias map
for (auto it = params.sampling_params.logit_bias.begin(); it != params.sampling_params.logit_bias.end(); it++) { for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
logits[it->first] += it->second; logits[it->first] += it->second;
} }
@ -91,18 +93,18 @@ inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// TODO: Apply penalties // TODO: Apply penalties
// float nl_logit = logits[llama_token_nl(ctx)]; // float nl_logit = logits[llama_token_nl(ctx)];
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
// llama_sample_repetition_penalty(ctx, &candidates_p, // llama_sample_repetition_penalty(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, repeat_penalty); // last_n_repeat, repeat_penalty);
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, // llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
// last_n_repeat, alpha_frequency, alpha_presence); // last_n_repeat, alpha_frequency, alpha_presence);
// if (!penalize_nl) { // if (!penalize_nl) {
// logits[llama_token_nl(ctx)] = nl_logit; // logits[llama_token_nl(ctx)] = nl_logit;
// } // }
if (temp <= 0) { if (temp <= 0) {
// Greedy sampling // Greedy sampling

View File

@ -108,7 +108,7 @@ int main(int argc, char ** argv) {
if (!gpt_params_parse(argc, argv, params)) { if (!gpt_params_parse(argc, argv, params)) {
return 1; return 1;
} }
llama_sampling_params & sparams = params.sampling_params; llama_sampling_params & sparams = params.sparams;
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("main", "log")); log_set_target(log_filename_generator("main", "log"));
@ -415,8 +415,7 @@ int main(int argc, char ** argv) {
} }
} }
} }
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n"); LOG_TEE("\n\n");
@ -459,7 +458,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd; std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance; std::vector<llama_token> embd_guidance;
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params); struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
while ((n_remain != 0 && !is_antiprompt) || params.interactive) { while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict // predict
@ -612,7 +611,7 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
llama_sampling_accept(ctx_sampling, ctx, id); llama_sampling_accept(ctx_sampling, ctx, id, true);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
@ -631,12 +630,9 @@ int main(int argc, char ** argv) {
while ((int) embd_inp.size() > n_consumed) { while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]); embd.push_back(embd_inp[n_consumed]);
// GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context // push the prompt in the sampling context in order to apply repetition penalties later
// Most likely will remove this in the future to avoid exposing "prev" // for the prompt, we don't apply grammar rules
// Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
// penalty will be applied only based on the tokens generated by the model.
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
++n_consumed; ++n_consumed;
if ((int) embd.size() >= params.n_batch) { if ((int) embd.size() >= params.n_batch) {
@ -667,12 +663,10 @@ int main(int argc, char ** argv) {
// if not currently processing queued inputs; // if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) { if ((int) embd_inp.size() <= n_consumed) {
// check for reverse prompt // check for reverse prompt in the last n_prev tokens
if (!params.antiprompt.empty()) { if (!params.antiprompt.empty()) {
std::string last_output; const int n_prev = 32;
for (auto id : ctx_sampling->prev) { const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
last_output += llama_token_to_piece(ctx, id);
}
is_antiprompt = false; is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output. // Check if each of the reverse prompts appears at the end of the output.
@ -699,7 +693,7 @@ int main(int argc, char ** argv) {
} }
// deal with end of text token in interactive mode // deal with end of text token in interactive mode
if (ctx_sampling->prev.back() == llama_token_eos(ctx)) { if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
LOG("found EOS token\n"); LOG("found EOS token\n");
if (params.interactive) { if (params.interactive) {

View File

@ -157,7 +157,7 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < clients.size(); ++i) { for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i]; auto & client = clients[i];
client.id = i; client.id = i;
client.ctx_sampling = llama_sampling_init(params); client.ctx_sampling = llama_sampling_init(params.sparams);
} }
std::vector<llama_token> tokens_system; std::vector<llama_token> tokens_system;
@ -330,7 +330,7 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i); const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
llama_sampling_accept(client.ctx_sampling, ctx, id); llama_sampling_accept(client.ctx_sampling, ctx, id, true);
if (client.n_decoded == 1) { if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients // start measuring generation time after the first token to make sure all concurrent clients

View File

@ -195,10 +195,12 @@ struct llama_server_context
json prompt; json prompt;
std::vector<llama_token> embd; std::vector<llama_token> embd;
gpt_params params;
llama_model *model = nullptr; llama_model *model = nullptr;
llama_context *ctx = nullptr; llama_context *ctx = nullptr;
gpt_params params;
llama_sampling_context *ctx_sampling = nullptr; llama_sampling_context *ctx_sampling = nullptr;
int n_ctx; int n_ctx;
bool truncated = false; bool truncated = false;
@ -232,7 +234,7 @@ struct llama_server_context
void rewind() void rewind()
{ {
params.antiprompt.clear(); params.antiprompt.clear();
params.grammar.clear(); params.sparams.grammar.clear();
num_prompt_tokens = 0; num_prompt_tokens = 0;
num_tokens_predicted = 0; num_tokens_predicted = 0;
generated_text = ""; generated_text = "";
@ -246,11 +248,14 @@ struct llama_server_context
multibyte_pending = 0; multibyte_pending = 0;
n_remain = 0; n_remain = 0;
n_past = 0; n_past = 0;
params.sparams.n_prev = n_ctx;
}
void initSampling() {
if (ctx_sampling != nullptr) { if (ctx_sampling != nullptr) {
llama_sampling_free(ctx_sampling); llama_sampling_free(ctx_sampling);
} }
ctx_sampling = llama_sampling_init(params); ctx_sampling = llama_sampling_init(params.sparams);
} }
bool loadModel(const gpt_params &params_) bool loadModel(const gpt_params &params_)
@ -311,16 +316,32 @@ struct llama_server_context
return prompt_tokens; return prompt_tokens;
} }
bool loadGrammar() void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
{ const int n_left = n_ctx - params.n_keep;
ctx_sampling = llama_sampling_init(params); const int n_block_size = n_left / 2;
return true; const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size;
// Keep n_keep tokens at start of prompt (at most n_ctx - 4)
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
LOG_VERBOSE("input truncated", {
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
{"num_prompt_tokens", new_tokens.size()}
});
truncated = true;
prompt_tokens = new_tokens;
} }
void loadInfill() void loadInfill()
{ {
bool suff_rm_leading_spc = true; bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
params.input_suffix.erase(0, 1); params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false; suff_rm_leading_spc = false;
} }
@ -336,6 +357,7 @@ struct llama_server_context
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(ctx)); prefix_tokens.push_back(llama_token_middle(ctx));
auto prompt_tokens = prefix_tokens; auto prompt_tokens = prefix_tokens;
num_prompt_tokens = prompt_tokens.size(); num_prompt_tokens = prompt_tokens.size();
@ -347,31 +369,18 @@ struct llama_server_context
params.n_keep = std::min(params.n_ctx - 4, params.n_keep); params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal // if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx) if (num_prompt_tokens >= (size_t) n_ctx)
{ {
printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens); truncatePrompt(prompt_tokens);
// todo we probably want to cut from both sides num_prompt_tokens = prompt_tokens.size();
const int n_left = (params.n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
LOG_VERBOSE("input truncated", { GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
{"n_ctx", params.n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
truncated = true;
prompt_tokens = new_tokens;
} }
else
// push the prompt into the sampling context (do not apply grammar)
for (auto & token : prompt_tokens)
{ {
const size_t ps = num_prompt_tokens; llama_sampling_accept(ctx_sampling, ctx, token, false);
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
} }
// compare the evaluated prompt with the new prompt // compare the evaluated prompt with the new prompt
@ -409,29 +418,18 @@ struct llama_server_context
params.n_keep = std::min(n_ctx - 4, params.n_keep); params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal // if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)n_ctx) if (num_prompt_tokens >= (size_t) n_ctx)
{ {
const int n_left = (n_ctx - params.n_keep) / 2; truncatePrompt(prompt_tokens);
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); num_prompt_tokens = prompt_tokens.size();
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
LOG_VERBOSE("input truncated", { GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
truncated = true;
prompt_tokens = new_tokens;
} }
else
// push the prompt into the sampling context (do not apply grammar)
for (auto & token : prompt_tokens)
{ {
const size_t ps = num_prompt_tokens; llama_sampling_accept(ctx_sampling, ctx, token, false);
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
} }
// compare the evaluated prompt with the new prompt // compare the evaluated prompt with the new prompt
@ -530,8 +528,8 @@ struct llama_server_context
llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false }; llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
const int32_t n_probs = params.sampling_params.n_probs; const int32_t n_probs = params.sparams.n_probs;
if (params.sampling_params.temp <= 0 && n_probs > 0) if (params.sparams.temp <= 0 && n_probs > 0)
{ {
// For llama_sample_token_greedy we need to sort candidates // For llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &cur_p); llama_sample_softmax(ctx, &cur_p);
@ -542,7 +540,7 @@ struct llama_server_context
result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
} }
llama_sampling_accept(ctx_sampling, ctx, result.tok); llama_sampling_accept(ctx_sampling, ctx, result.tok, true);
if (tg) { if (tg) {
num_tokens_predicted++; num_tokens_predicted++;
@ -606,7 +604,7 @@ struct llama_server_context
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok); const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
generated_text += token_text; generated_text += token_text;
if (params.sampling_params.n_probs > 0) if (params.sparams.n_probs > 0)
{ {
generated_token_probs.push_back(token_with_probs); generated_token_probs.push_back(token_with_probs);
} }
@ -1004,36 +1002,36 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
static json format_generation_settings(llama_server_context &llama) static json format_generation_settings(llama_server_context &llama)
{ {
const auto & sparams = llama.params.sampling_params; const auto & sparams = llama.params.sparams;
const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx)); const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
const bool ignore_eos = eos_bias != sparams.logit_bias.end() && const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second); eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{ return json{
{"n_ctx", llama.n_ctx}, {"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias}, {"model", llama.params.model_alias},
{"seed", llama.params.seed}, {"seed", llama.params.seed},
{"temp", sparams.temp}, {"temp", sparams.temp},
{"top_k", sparams.top_k}, {"top_k", sparams.top_k},
{"top_p", sparams.top_p}, {"top_p", sparams.top_p},
{"tfs_z", sparams.tfs_z}, {"tfs_z", sparams.tfs_z},
{"typical_p", sparams.typical_p}, {"typical_p", sparams.typical_p},
{"repeat_last_n", sparams.repeat_last_n}, {"repeat_last_n", sparams.penalty_last_n},
{"repeat_penalty", sparams.repeat_penalty}, {"repeat_penalty", sparams.penalty_repeat},
{"presence_penalty", sparams.presence_penalty}, {"frequency_penalty", sparams.penalty_freq},
{"frequency_penalty", sparams.frequency_penalty}, {"presence_penalty", sparams.penalty_present},
{"mirostat", sparams.mirostat}, {"mirostat", sparams.mirostat},
{"mirostat_tau", sparams.mirostat_tau}, {"mirostat_tau", sparams.mirostat_tau},
{"mirostat_eta", sparams.mirostat_eta}, {"mirostat_eta", sparams.mirostat_eta},
{"penalize_nl", sparams.penalize_nl}, {"penalize_nl", sparams.penalize_nl},
{"stop", llama.params.antiprompt}, {"stop", llama.params.antiprompt},
{"n_predict", llama.params.n_predict}, {"n_predict", llama.params.n_predict},
{"n_keep", llama.params.n_keep}, {"n_keep", llama.params.n_keep},
{"ignore_eos", ignore_eos}, {"ignore_eos", ignore_eos},
{"stream", llama.stream}, {"stream", llama.stream},
{"logit_bias", sparams.logit_bias}, {"logit_bias", sparams.logit_bias},
{"n_probs", sparams.n_probs}, {"n_probs", sparams.n_probs},
{"grammar", llama.params.grammar}, {"grammar", llama.params.sparams.grammar},
}; };
} }
@ -1081,7 +1079,7 @@ static json format_final_response(llama_server_context &llama, const std::string
{"timings", format_timings(llama)}, {"timings", format_timings(llama)},
}; };
if (llama.params.sampling_params.n_probs > 0) if (llama.params.sparams.n_probs > 0)
{ {
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
} }
@ -1097,7 +1095,7 @@ static json format_partial_response(
{"stop", false}, {"stop", false},
}; };
if (llama.params.sampling_params.n_probs > 0) if (llama.params.sparams.n_probs > 0)
{ {
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs); res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
} }
@ -1129,28 +1127,30 @@ static T json_value(const json &body, const std::string &key, const T &default_v
static void parse_options_completion(const json &body, llama_server_context &llama) static void parse_options_completion(const json &body, llama_server_context &llama)
{ {
gpt_params default_params; gpt_params default_params;
const auto & default_sparams = default_params.sampling_params; const auto & default_sparams = default_params.sparams;
auto & sparams = llama.params.sampling_params;
llama.stream = json_value(body, "stream", false); auto & params = llama.params;
llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict); auto & sparams = llama.params.sparams;
sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
sparams.top_p = json_value(body, "top_p", default_sparams.top_p); llama.stream = json_value(body, "stream", false);
sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z); params.n_predict = json_value(body, "n_predict", default_params.n_predict);
sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p); sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n); sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
sparams.temp = json_value(body, "temperature", default_sparams.temp); sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty); sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty); sparams.temp = json_value(body, "temperature", default_sparams.temp);
sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty); sparams.penalty_last_n = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat); sparams.penalty_repeat = json_value(body, "repeat_penalty", default_sparams.penalty_repeat);
sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); sparams.penalty_freq = json_value(body, "frequency_penalty", default_sparams.penalty_freq);
sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); sparams.penalty_present = json_value(body, "presence_penalty", default_sparams.penalty_present);
sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl); sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep); sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
llama.params.seed = json_value(body, "seed", default_params.seed); sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
llama.params.grammar = json_value(body, "grammar", default_params.grammar); sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs); params.n_keep = json_value(body, "n_keep", default_params.n_keep);
params.seed = json_value(body, "seed", default_params.seed);
sparams.grammar = json_value(body, "grammar", default_sparams.grammar);
sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
if (body.count("prompt") != 0) if (body.count("prompt") != 0)
{ {
@ -1204,8 +1204,6 @@ static void parse_options_completion(const json &body, llama_server_context &lla
} }
} }
llama.ctx_sampling = llama_sampling_init(llama.params);
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
} }
@ -1374,15 +1372,9 @@ int main(int argc, char **argv)
llama.rewind(); llama.rewind();
llama_reset_timings(llama.ctx); llama_reset_timings(llama.ctx);
parse_options_completion(json::parse(req.body), llama); parse_options_completion(json::parse(req.body), llama);
if (!llama.loadGrammar()) llama.initSampling();
{
res.status = 400;
return;
}
llama.loadPrompt(); llama.loadPrompt();
llama.beginCompletion(); llama.beginCompletion();
@ -1414,7 +1406,7 @@ int main(int argc, char **argv)
} }
auto probs = llama.generated_token_probs; auto probs = llama.generated_token_probs;
if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) { if (llama.params.sparams.n_probs > 0 && llama.stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false); const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size()); probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
} }
@ -1466,7 +1458,7 @@ int main(int argc, char **argv)
std::vector<completion_token_output> probs_output = {}; std::vector<completion_token_output> probs_output = {};
if (llama.params.sampling_params.n_probs > 0) { if (llama.params.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false); const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@ -1537,14 +1529,9 @@ int main(int argc, char **argv)
llama.rewind(); llama.rewind();
llama_reset_timings(llama.ctx); llama_reset_timings(llama.ctx);
parse_options_infill(json::parse(req.body), llama); parse_options_infill(json::parse(req.body), llama);
if (!llama.loadGrammar()) llama.initSampling();
{
res.status = 400;
return;
}
llama.loadInfill(); llama.loadInfill();
llama.beginCompletion(); llama.beginCompletion();
const auto chunked_content_provider = [&](size_t, DataSink & sink) { const auto chunked_content_provider = [&](size_t, DataSink & sink) {
@ -1587,7 +1574,7 @@ int main(int argc, char **argv)
std::vector<completion_token_output> probs_output = {}; std::vector<completion_token_output> probs_output = {};
if (llama.params.sampling_params.n_probs > 0) { if (llama.params.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false); const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@ -1694,7 +1681,9 @@ int main(int argc, char **argv)
const json body = json::parse(req.body); const json body = json::parse(req.body);
llama.rewind(); llama.rewind();
llama_reset_timings(llama.ctx); llama_reset_timings(llama.ctx);
if (body.count("content") != 0) if (body.count("content") != 0)
{ {
llama.prompt = body["content"]; llama.prompt = body["content"];
@ -1704,6 +1693,8 @@ int main(int argc, char **argv)
llama.prompt = ""; llama.prompt = "";
} }
llama.params.n_predict = 0; llama.params.n_predict = 0;
llama.initSampling();
llama.loadPrompt(); llama.loadPrompt();
llama.beginCompletion(); llama.beginCompletion();
llama.doCompletion(); llama.doCompletion();

View File

@ -112,16 +112,16 @@ int main(int argc, char ** argv) {
bool has_eos = false; bool has_eos = false;
// target model sampling context // target model sampling context
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params); struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
// draft sequence data // draft sequence data
std::vector<seq_draft> drafts(n_seq_dft); std::vector<seq_draft> drafts(n_seq_dft);
params.grammar.clear(); // the draft samplers will copy the target sampler's grammar params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp); params.sparams.temp = std::max(0.01f, params.sparams.temp);
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].ctx_sampling = llama_sampling_init(params); drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
} }
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
@ -154,7 +154,7 @@ int main(int argc, char ** argv) {
// sample from the target model // sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(ctx_sampling, ctx_tgt, id); llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
@ -328,7 +328,7 @@ int main(int argc, char ** argv) {
const int s = sa[is]; const int s = sa[is];
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id); llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
drafts[s].tokens.push_back(id); drafts[s].tokens.push_back(id);

View File

@ -1018,8 +1018,8 @@ enum e_model {
}; };
static const size_t kB = 1024; static const size_t kB = 1024;
static const size_t MB = kB*kB; static const size_t MB = 1024*kB;
static const size_t GB = kB*kB*kB; static const size_t GB = 1024*MB;
struct llama_hparams { struct llama_hparams {
bool vocab_only; bool vocab_only;
@ -1042,21 +1042,21 @@ struct llama_hparams {
float f_max_alibi_bias; float f_max_alibi_bias;
bool operator!=(const llama_hparams & other) const { bool operator!=(const llama_hparams & other) const {
if (this->vocab_only != other.vocab_only) return true; if (this->vocab_only != other.vocab_only) return true;
if (this->n_vocab != other.n_vocab) return true; if (this->n_vocab != other.n_vocab) return true;
if (this->n_ctx_train != other.n_ctx_train) return true; if (this->n_ctx_train != other.n_ctx_train) return true;
if (this->n_embd != other.n_embd) return true; if (this->n_embd != other.n_embd) return true;
if (this->n_head != other.n_head) return true; if (this->n_head != other.n_head) return true;
if (this->n_head_kv != other.n_head_kv) return true; if (this->n_head_kv != other.n_head_kv) return true;
if (this->n_layer != other.n_layer) return true; if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true; if (this->n_rot != other.n_rot) return true;
if (this->n_ff != other.n_ff) return true; if (this->n_ff != other.n_ff) return true;
const float EPSILON = 1e-9; const float EPSILON = 1e-9;
if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
return false; return false;
@ -1195,11 +1195,11 @@ struct llama_vocab {
id special_sep_id = -1; id special_sep_id = -1;
id special_pad_id = -1; id special_pad_id = -1;
id linefeed_id = 13; id linefeed_id = 13;
id special_prefix_id = 32007; id special_prefix_id = 32007;
id special_middle_id = 32009; id special_middle_id = 32009;
id special_suffix_id = 32008; id special_suffix_id = 32008;
id special_eot_id = 32010; id special_eot_id = 32010;
int find_bpe_rank(std::string token_left, std::string token_right) const { int find_bpe_rank(std::string token_left, std::string token_right) const {
replace_all(token_left, " ", "\u0120"); replace_all(token_left, " ", "\u0120");
@ -1359,10 +1359,7 @@ static bool llama_kv_cache_init(
cache.cells.clear(); cache.cells.clear();
cache.cells.resize(n_ctx); cache.cells.resize(n_ctx);
// TODO: this should be: cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
// cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
// change it and test that it works
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
memset(cache.buf.data, 0, cache.buf.size); memset(cache.buf.data, 0, cache.buf.size);
struct ggml_init_params params; struct ggml_init_params params;
@ -7417,37 +7414,15 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
llama_sample_temp(ctx, candidates_p, temp); llama_sample_temp(ctx, candidates_p, temp);
} }
void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { void llama_sample_repetition_penalties(
if (last_tokens_size == 0 || penalty == 1.0f) { struct llama_context * ctx,
return; llama_token_data_array * candidates,
} const llama_token * last_tokens,
size_t penalty_last_n,
const int64_t t_start_sample_us = ggml_time_us(); float penalty_repeat,
float penalty_freq,
for (size_t i = 0; i < candidates->size; ++i) { float penalty_present) {
const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
if (token_iter == last_tokens + last_tokens_size) {
continue;
}
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
if (candidates->data[i].logit <= 0) {
candidates->data[i].logit *= penalty;
} else {
candidates->data[i].logit /= penalty;
}
}
candidates->sorted = false;
if (ctx) {
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
return; return;
} }
@ -7455,19 +7430,28 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
// Create a frequency map to count occurrences of each token in last_tokens // Create a frequency map to count occurrences of each token in last_tokens
std::unordered_map<llama_token, int> token_count; std::unordered_map<llama_token, int> token_count;
for (size_t i = 0; i < last_tokens_size; ++i) { for (size_t i = 0; i < penalty_last_n; ++i) {
token_count[last_tokens_p[i]]++; token_count[last_tokens[i]]++;
} }
// Apply frequency and presence penalties to the candidates // Apply frequency and presence penalties to the candidates
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
auto token_iter = token_count.find(candidates->data[i].id); const auto token_iter = token_count.find(candidates->data[i].id);
if (token_iter == token_count.end()) { if (token_iter == token_count.end()) {
continue; continue;
} }
int count = token_iter->second; const int count = token_iter->second;
candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
if (candidates->data[i].logit <= 0) {
candidates->data[i].logit *= penalty_repeat;
} else {
candidates->data[i].logit /= penalty_repeat;
}
candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
} }
candidates->sorted = false; candidates->sorted = false;

16
llama.h
View File

@ -560,21 +560,15 @@ extern "C" {
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
LLAMA_API void llama_sample_repetition_penalty(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t last_tokens_size,
float penalty);
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_frequency_and_presence_penalties( LLAMA_API void llama_sample_repetition_penalties(
struct llama_context * ctx, struct llama_context * ctx,
llama_token_data_array * candidates, llama_token_data_array * candidates,
const llama_token * last_tokens, const llama_token * last_tokens,
size_t last_tokens_size, size_t penalty_last_n,
float alpha_frequency, float penalty_repeat,
float alpha_presence); float penalty_freq,
float penalty_present);
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.

View File

@ -8,11 +8,9 @@
#include <cmath> #include <cmath>
#include <numeric> #include <numeric>
#include <cassert> #include <cassert>
#include <iostream>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
static void dump(const llama_token_data_array * candidates) { static void dump(const llama_token_data_array * candidates) {
for (size_t i = 0; i < candidates->size; i++) { for (size_t i = 0; i < candidates->size; i++) {
printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit); printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
@ -21,7 +19,6 @@ static void dump(const llama_token_data_array * candidates) {
#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0) #define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) { static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
size_t n_vocab = probs.size(); size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
@ -37,13 +34,12 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
llama_sample_top_k(nullptr, &candidates_p, k, 1); llama_sample_top_k(nullptr, &candidates_p, k, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
assert(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) { for (size_t i = 0; i < candidates_p.size; i++) {
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5); GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
} }
} }
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
size_t n_vocab = probs.size(); size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
@ -59,13 +55,12 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
llama_sample_top_p(nullptr, &candidates_p, p, 1); llama_sample_top_p(nullptr, &candidates_p, p, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
assert(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) { for (size_t i = 0; i < candidates_p.size; i++) {
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
} }
} }
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) { static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
size_t n_vocab = probs.size(); size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
@ -80,13 +75,12 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
llama_sample_tail_free(nullptr, &candidates_p, z, 1); llama_sample_tail_free(nullptr, &candidates_p, z, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
assert(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) { for (size_t i = 0; i < candidates_p.size; i++) {
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
} }
} }
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) { static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
size_t n_vocab = probs.size(); size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
@ -101,18 +95,17 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
llama_sample_typical(nullptr, &candidates_p, p, 1); llama_sample_typical(nullptr, &candidates_p, p, 1);
DUMP(&candidates_p); DUMP(&candidates_p);
assert(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) { for (size_t i = 0; i < candidates_p.size; i++) {
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
} }
} }
static void test_repetition_penalties(
static void test_repetition_penalty(
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens, const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
const std::vector<float> & expected_probs, float penalty const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
) { ) {
assert(probs.size() == expected_probs.size()); GGML_ASSERT(probs.size() == expected_probs.size());
size_t n_vocab = probs.size(); size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
@ -125,41 +118,13 @@ static void test_repetition_penalty(
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax(nullptr, &candidates_p); llama_sample_softmax(nullptr, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
llama_sample_repetition_penalty(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), penalty); llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
llama_sample_softmax(nullptr, &candidates_p); llama_sample_softmax(nullptr, &candidates_p);
DUMP(&candidates_p); DUMP(&candidates_p);
assert(candidates_p.size == expected_probs.size()); GGML_ASSERT(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) { for (size_t i = 0; i < candidates_p.size; i++) {
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
}
}
static void test_frequency_presence_penalty(
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
const std::vector<float> & expected_probs, float alpha_frequency, float alpha_presence
) {
assert(probs.size() == expected_probs.size());
size_t n_vocab = probs.size();
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
float logit = log(probs[token_id]);
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
llama_sample_softmax(nullptr, &candidates_p);
// DUMP(&candidates_p);
llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
llama_sample_softmax(nullptr, &candidates_p);
// DUMP(&candidates_p);
assert(candidates_p.size == expected_probs.size());
for (size_t i = 0; i < candidates_p.size; i++) {
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
} }
} }
@ -181,13 +146,13 @@ int main(void) {
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f); test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f); test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f); test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 5.0f, 5.0f); test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f); test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f); test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
printf("OK\n"); printf("OK\n");