mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params * llama : combine repetition, frequency and presence penalties in 1 call * examples : remove embd-input and gptneox-wip * sampling : rename penalty params + reduce size of "prev" vector * sampling : add llama_sampling_print helper * sampling : hide prev behind API and apply #3661 ggml-ci
This commit is contained in:
parent
8cf19d60dc
commit
d1031cf49c
9
Makefile
9
Makefile
@ -1,7 +1,7 @@
|
|||||||
# Define the default target now so that it is always the first target
|
# Define the default target now so that it is always the first target
|
||||||
BUILD_TARGETS = \
|
BUILD_TARGETS = \
|
||||||
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
|
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
|
||||||
simple batched batched-bench save-load-state server embd-input-test gguf llama-bench llava baby-llama beam-search \
|
simple batched batched-bench save-load-state server gguf llama-bench llava baby-llama beam-search \
|
||||||
speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o
|
speculative infill benchmark-matmult parallel finetune export-lora tests/test-c.o
|
||||||
|
|
||||||
# Binaries only useful for tests
|
# Binaries only useful for tests
|
||||||
@ -608,13 +608,6 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.
|
|||||||
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
|
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
||||||
|
|
||||||
$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
|
||||||
$(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
|
|
||||||
|
|
||||||
|
|
||||||
embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput
|
|
||||||
|
|
||||||
gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
|
gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
|
@ -962,7 +962,6 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /
|
|||||||
|
|
||||||
- [main](./examples/main/README.md)
|
- [main](./examples/main/README.md)
|
||||||
- [server](./examples/server/README.md)
|
- [server](./examples/server/README.md)
|
||||||
- [embd-input](./examples/embd-input/README.md)
|
|
||||||
- [jeopardy](./examples/jeopardy/README.md)
|
- [jeopardy](./examples/jeopardy/README.md)
|
||||||
- [BLIS](./docs/BLIS.md)
|
- [BLIS](./docs/BLIS.md)
|
||||||
- [Performance troubleshooting](./docs/token_generation_performance_tips.md)
|
- [Performance troubleshooting](./docs/token_generation_performance_tips.md)
|
||||||
|
@ -107,7 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
std::string arg;
|
std::string arg;
|
||||||
gpt_params default_params;
|
gpt_params default_params;
|
||||||
const std::string arg_prefix = "--";
|
const std::string arg_prefix = "--";
|
||||||
llama_sampling_params & sparams = params.sampling_params;
|
llama_sampling_params & sparams = params.sparams;
|
||||||
|
|
||||||
for (int i = 1; i < argc; i++) {
|
for (int i = 1; i < argc; i++) {
|
||||||
arg = argv[i];
|
arg = argv[i];
|
||||||
@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
sparams.repeat_last_n = std::stoi(argv[i]);
|
sparams.penalty_last_n = std::stoi(argv[i]);
|
||||||
|
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
|
||||||
} else if (arg == "--repeat-penalty") {
|
} else if (arg == "--repeat-penalty") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
sparams.repeat_penalty = std::stof(argv[i]);
|
sparams.penalty_repeat = std::stof(argv[i]);
|
||||||
} else if (arg == "--frequency-penalty") {
|
} else if (arg == "--frequency-penalty") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
sparams.frequency_penalty = std::stof(argv[i]);
|
sparams.penalty_freq = std::stof(argv[i]);
|
||||||
} else if (arg == "--presence-penalty") {
|
} else if (arg == "--presence-penalty") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
sparams.presence_penalty = std::stof(argv[i]);
|
sparams.penalty_present = std::stof(argv[i]);
|
||||||
} else if (arg == "--mirostat") {
|
} else if (arg == "--mirostat") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -572,7 +573,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.grammar = argv[i];
|
sparams.grammar = argv[i];
|
||||||
} else if (arg == "--grammar-file") {
|
} else if (arg == "--grammar-file") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -587,7 +588,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
std::copy(
|
std::copy(
|
||||||
std::istreambuf_iterator<char>(file),
|
std::istreambuf_iterator<char>(file),
|
||||||
std::istreambuf_iterator<char>(),
|
std::istreambuf_iterator<char>(),
|
||||||
std::back_inserter(params.grammar)
|
std::back_inserter(sparams.grammar)
|
||||||
);
|
);
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
// Parse args for logging parameters
|
// Parse args for logging parameters
|
||||||
@ -640,7 +641,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
const llama_sampling_params & sparams = params.sampling_params;
|
const llama_sampling_params & sparams = params.sparams;
|
||||||
|
|
||||||
printf("usage: %s [options]\n", argv[0]);
|
printf("usage: %s [options]\n", argv[0]);
|
||||||
printf("\n");
|
printf("\n");
|
||||||
@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
|
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
|
||||||
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
|
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
|
||||||
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
|
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
|
||||||
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
|
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
|
||||||
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
|
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
|
||||||
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
|
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
|
||||||
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
|
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
|
||||||
printf(" --mirostat N use Mirostat sampling.\n");
|
printf(" --mirostat N use Mirostat sampling.\n");
|
||||||
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
|
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
|
||||||
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
|
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
|
||||||
@ -878,7 +879,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (params.ignore_eos) {
|
if (params.ignore_eos) {
|
||||||
params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
|
params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -1123,7 +1124,7 @@ std::string get_sortable_timestamp() {
|
|||||||
|
|
||||||
void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
|
void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
|
||||||
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
|
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
|
||||||
const llama_sampling_params & sparams = params.sampling_params;
|
const llama_sampling_params & sparams = params.sparams;
|
||||||
|
|
||||||
fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
|
fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
|
||||||
fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
|
fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
|
||||||
@ -1178,8 +1179,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||||||
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
|
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
|
||||||
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
|
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
|
||||||
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
|
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
|
||||||
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
|
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
|
||||||
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
|
dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
|
||||||
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
|
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
|
||||||
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
|
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
|
||||||
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
|
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
|
||||||
@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||||||
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
|
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
|
||||||
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
|
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
|
||||||
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
|
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
|
||||||
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
|
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
|
||||||
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
|
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
|
||||||
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
|
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
|
||||||
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
|
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
|
||||||
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
|
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
|
||||||
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
|
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
|
||||||
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
|
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
|
||||||
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
|
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
|
||||||
|
|
||||||
fprintf(stream, "reverse_prompt:\n");
|
fprintf(stream, "reverse_prompt:\n");
|
||||||
for (std::string ap : params.antiprompt) {
|
for (std::string ap : params.antiprompt) {
|
||||||
|
@ -56,7 +56,7 @@ struct gpt_params {
|
|||||||
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
|
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
|
||||||
|
|
||||||
// // sampling parameters
|
// // sampling parameters
|
||||||
struct llama_sampling_params sampling_params;
|
struct llama_sampling_params sparams;
|
||||||
|
|
||||||
std::string model = "models/7B/ggml-model-f16.gguf"; // model path
|
std::string model = "models/7B/ggml-model-f16.gguf"; // model path
|
||||||
std::string model_draft = ""; // draft model for speculative decoding
|
std::string model_draft = ""; // draft model for speculative decoding
|
||||||
@ -66,7 +66,6 @@ struct gpt_params {
|
|||||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
|
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
|
||||||
std::string input_prefix = ""; // string to prefix user inputs with
|
std::string input_prefix = ""; // string to prefix user inputs with
|
||||||
std::string input_suffix = ""; // string to suffix user inputs with
|
std::string input_suffix = ""; // string to suffix user inputs with
|
||||||
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
|
|
||||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||||
std::string logdir = ""; // directory in which to save YAML log files
|
std::string logdir = ""; // directory in which to save YAML log files
|
||||||
|
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
|
|
||||||
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
|
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||||
struct llama_sampling_context * result = new llama_sampling_context();
|
struct llama_sampling_context * result = new llama_sampling_context();
|
||||||
|
|
||||||
result->params = params.sampling_params;
|
result->params = params;
|
||||||
result->grammar = nullptr;
|
result->grammar = nullptr;
|
||||||
|
|
||||||
// if there is a grammar, parse it
|
// if there is a grammar, parse it
|
||||||
@ -23,7 +23,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_params & pa
|
|||||||
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
||||||
}
|
}
|
||||||
|
|
||||||
result->prev.resize(params.n_ctx);
|
result->prev.resize(params.n_prev);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -66,25 +66,56 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
|
|||||||
dst->prev = src->prev;
|
dst->prev = src->prev;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampling_last(llama_sampling_context * ctx) {
|
||||||
|
return ctx->prev.back();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
|
||||||
|
const int size = ctx_sampling->prev.size();
|
||||||
|
|
||||||
|
n = std::min(n, size);
|
||||||
|
|
||||||
|
std::string result;
|
||||||
|
|
||||||
|
for (int i = size - n; i < size; i++) {
|
||||||
|
result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string llama_sampling_print(const llama_sampling_params & params) {
|
||||||
|
char result[1024];
|
||||||
|
|
||||||
|
snprintf(result, sizeof(result),
|
||||||
|
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
||||||
|
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
|
||||||
|
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
||||||
|
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
|
||||||
|
params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
|
||||||
|
params.mirostat, params.mirostat_eta, params.mirostat_tau);
|
||||||
|
|
||||||
|
return std::string(result);
|
||||||
|
}
|
||||||
|
|
||||||
llama_token llama_sampling_sample(
|
llama_token llama_sampling_sample(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
const int idx) {
|
const int idx) {
|
||||||
const int n_ctx = llama_n_ctx(ctx_main);
|
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
|
||||||
|
|
||||||
const llama_sampling_params & params = ctx_sampling->params;
|
const llama_sampling_params & params = ctx_sampling->params;
|
||||||
|
|
||||||
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
|
||||||
const float temp = params.temp;
|
const float temp = params.temp;
|
||||||
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
||||||
const float top_p = params.top_p;
|
const float top_p = params.top_p;
|
||||||
const float tfs_z = params.tfs_z;
|
const float tfs_z = params.tfs_z;
|
||||||
const float typical_p = params.typical_p;
|
const float typical_p = params.typical_p;
|
||||||
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
|
||||||
const float repeat_penalty = params.repeat_penalty;
|
const float penalty_repeat = params.penalty_repeat;
|
||||||
const float alpha_presence = params.presence_penalty;
|
const float penalty_freq = params.penalty_freq;
|
||||||
const float alpha_frequency = params.frequency_penalty;
|
const float penalty_present = params.penalty_present;
|
||||||
const int mirostat = params.mirostat;
|
const int mirostat = params.mirostat;
|
||||||
const float mirostat_tau = params.mirostat_tau;
|
const float mirostat_tau = params.mirostat_tau;
|
||||||
const float mirostat_eta = params.mirostat_eta;
|
const float mirostat_eta = params.mirostat_eta;
|
||||||
@ -97,7 +128,7 @@ llama_token llama_sampling_sample(
|
|||||||
|
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
|
||||||
// Apply params.logit_bias map
|
// apply params.logit_bias map
|
||||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||||
logits[it->first] += it->second;
|
logits[it->first] += it->second;
|
||||||
}
|
}
|
||||||
@ -117,14 +148,10 @@ llama_token llama_sampling_sample(
|
|||||||
// apply penalties
|
// apply penalties
|
||||||
if (!prev.empty()) {
|
if (!prev.empty()) {
|
||||||
const float nl_logit = logits[llama_token_nl(ctx_main)];
|
const float nl_logit = logits[llama_token_nl(ctx_main)];
|
||||||
const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
|
|
||||||
|
|
||||||
llama_sample_repetition_penalty(ctx_main, &cur_p,
|
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||||
prev.data() + prev.size() - last_n_repeat,
|
prev.data() + prev.size() - penalty_last_n,
|
||||||
last_n_repeat, repeat_penalty);
|
penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
|
||||||
llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
|
|
||||||
prev.data() + prev.size() - last_n_repeat,
|
|
||||||
last_n_repeat, alpha_frequency, alpha_presence);
|
|
||||||
|
|
||||||
if (!penalize_nl) {
|
if (!penalize_nl) {
|
||||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||||
@ -141,7 +168,7 @@ llama_token llama_sampling_sample(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (temp <= 0) {
|
if (temp <= 0) {
|
||||||
// Greedy sampling
|
// greedy sampling
|
||||||
id = llama_sample_token_greedy(ctx_main, &cur_p);
|
id = llama_sample_token_greedy(ctx_main, &cur_p);
|
||||||
} else {
|
} else {
|
||||||
if (mirostat == 1) {
|
if (mirostat == 1) {
|
||||||
@ -152,8 +179,9 @@ llama_token llama_sampling_sample(
|
|||||||
llama_sample_temp(ctx_main, &cur_p, temp);
|
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||||
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
||||||
} else {
|
} else {
|
||||||
// Temperature sampling
|
// temperature sampling
|
||||||
size_t min_keep = std::max(1, params.n_probs);
|
size_t min_keep = std::max(1, params.n_probs);
|
||||||
|
|
||||||
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
||||||
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
||||||
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
||||||
@ -183,11 +211,12 @@ llama_token llama_sampling_sample(
|
|||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
llama_token id) {
|
llama_token id,
|
||||||
|
bool apply_grammar) {
|
||||||
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
|
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
|
||||||
ctx_sampling->prev.push_back(id);
|
ctx_sampling->prev.push_back(id);
|
||||||
|
|
||||||
if (ctx_sampling->grammar != NULL) {
|
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
||||||
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
|
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,30 +10,30 @@
|
|||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
typedef struct llama_sampling_params {
|
typedef struct llama_sampling_params {
|
||||||
|
int32_t n_prev = 64; // number of previous tokens to remember
|
||||||
|
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||||
int32_t top_k = 40; // <= 0 to use vocab size
|
int32_t top_k = 40; // <= 0 to use vocab size
|
||||||
float top_p = 0.95f; // 1.0 = disabled
|
float top_p = 0.95f; // 1.0 = disabled
|
||||||
float tfs_z = 1.00f; // 1.0 = disabled
|
float tfs_z = 1.00f; // 1.0 = disabled
|
||||||
float typical_p = 1.00f; // 1.0 = disabled
|
float typical_p = 1.00f; // 1.0 = disabled
|
||||||
float temp = 0.80f; // 1.0 = disabled
|
float temp = 0.80f; // 1.0 = disabled
|
||||||
float repeat_penalty = 1.10f; // 1.0 = disabled
|
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
||||||
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
float penalty_repeat = 1.10f; // 1.0 = disabled
|
||||||
float frequency_penalty = 0.00f; // 0.0 = disabled
|
float penalty_freq = 0.00f; // 0.0 = disabled
|
||||||
float presence_penalty = 0.00f; // 0.0 = disabled
|
float penalty_present = 0.00f; // 0.0 = disabled
|
||||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
float mirostat_tau = 5.00f; // target entropy
|
float mirostat_tau = 5.00f; // target entropy
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
|
|
||||||
bool penalize_nl = true; // consider newlines as a repeatable token
|
bool penalize_nl = true; // consider newlines as a repeatable token
|
||||||
|
|
||||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||||
|
|
||||||
// Classifier-Free Guidance
|
// Classifier-Free Guidance
|
||||||
// https://arxiv.org/abs/2306.17806
|
// https://arxiv.org/abs/2306.17806
|
||||||
std::string cfg_negative_prompt; // string to help guidance
|
std::string cfg_negative_prompt; // string to help guidance
|
||||||
float cfg_scale = 1.f; // How strong is guidance
|
float cfg_scale = 1.f; // how strong is guidance
|
||||||
|
|
||||||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||||
|
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
@ -58,7 +58,7 @@ struct llama_sampling_context {
|
|||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
// Create a new sampling context instance.
|
// Create a new sampling context instance.
|
||||||
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
|
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
|
||||||
|
|
||||||
void llama_sampling_free(struct llama_sampling_context * ctx);
|
void llama_sampling_free(struct llama_sampling_context * ctx);
|
||||||
|
|
||||||
@ -70,6 +70,15 @@ void llama_sampling_reset(llama_sampling_context * ctx);
|
|||||||
// Copy the sampler context
|
// Copy the sampler context
|
||||||
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
|
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
|
||||||
|
|
||||||
|
// Get the last sampled token
|
||||||
|
llama_token llama_sampling_last(llama_sampling_context * ctx);
|
||||||
|
|
||||||
|
// Get a string representation of the last sampled tokens
|
||||||
|
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
|
||||||
|
|
||||||
|
// Print sampling parameters into a string
|
||||||
|
std::string llama_sampling_print(const llama_sampling_params & params);
|
||||||
|
|
||||||
// this is a common sampling function used across the examples for convenience
|
// this is a common sampling function used across the examples for convenience
|
||||||
// it can serve as a starting point for implementing your own sampling function
|
// it can serve as a starting point for implementing your own sampling function
|
||||||
// Note: When using multiple sequences, it is the caller's responsibility to call
|
// Note: When using multiple sequences, it is the caller's responsibility to call
|
||||||
@ -96,4 +105,5 @@ llama_token llama_sampling_sample(
|
|||||||
void llama_sampling_accept(
|
void llama_sampling_accept(
|
||||||
struct llama_sampling_context * ctx_sampling,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
llama_token id);
|
llama_token id,
|
||||||
|
bool apply_grammar);
|
||||||
|
@ -12,26 +12,26 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
|||||||
|
|
||||||
if (EMSCRIPTEN)
|
if (EMSCRIPTEN)
|
||||||
else()
|
else()
|
||||||
add_subdirectory(main)
|
|
||||||
add_subdirectory(quantize)
|
|
||||||
add_subdirectory(quantize-stats)
|
|
||||||
add_subdirectory(perplexity)
|
|
||||||
add_subdirectory(embedding)
|
|
||||||
add_subdirectory(save-load-state)
|
|
||||||
add_subdirectory(benchmark)
|
|
||||||
add_subdirectory(baby-llama)
|
add_subdirectory(baby-llama)
|
||||||
add_subdirectory(train-text-from-scratch)
|
|
||||||
add_subdirectory(finetune)
|
|
||||||
add_subdirectory(convert-llama2c-to-ggml)
|
|
||||||
add_subdirectory(simple)
|
|
||||||
add_subdirectory(batched)
|
add_subdirectory(batched)
|
||||||
add_subdirectory(batched-bench)
|
add_subdirectory(batched-bench)
|
||||||
add_subdirectory(speculative)
|
|
||||||
add_subdirectory(parallel)
|
|
||||||
add_subdirectory(embd-input)
|
|
||||||
add_subdirectory(llava)
|
|
||||||
add_subdirectory(llama-bench)
|
|
||||||
add_subdirectory(beam-search)
|
add_subdirectory(beam-search)
|
||||||
|
add_subdirectory(benchmark)
|
||||||
|
add_subdirectory(convert-llama2c-to-ggml)
|
||||||
|
add_subdirectory(embedding)
|
||||||
|
add_subdirectory(finetune)
|
||||||
|
add_subdirectory(infill)
|
||||||
|
add_subdirectory(llama-bench)
|
||||||
|
add_subdirectory(llava)
|
||||||
|
add_subdirectory(main)
|
||||||
|
add_subdirectory(parallel)
|
||||||
|
add_subdirectory(perplexity)
|
||||||
|
add_subdirectory(quantize)
|
||||||
|
add_subdirectory(quantize-stats)
|
||||||
|
add_subdirectory(save-load-state)
|
||||||
|
add_subdirectory(simple)
|
||||||
|
add_subdirectory(speculative)
|
||||||
|
add_subdirectory(train-text-from-scratch)
|
||||||
if (LLAMA_METAL)
|
if (LLAMA_METAL)
|
||||||
add_subdirectory(metal)
|
add_subdirectory(metal)
|
||||||
endif()
|
endif()
|
||||||
|
4
examples/embd-input/.gitignore
vendored
4
examples/embd-input/.gitignore
vendored
@ -1,4 +0,0 @@
|
|||||||
PandaGPT
|
|
||||||
MiniGPT-4
|
|
||||||
*.pth
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
|||||||
set(TARGET embdinput)
|
|
||||||
add_library(${TARGET} embd-input-lib.cpp embd-input.h)
|
|
||||||
install(TARGETS ${TARGET} LIBRARY)
|
|
||||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
|
||||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
|
||||||
if(TARGET BUILD_INFO)
|
|
||||||
add_dependencies(${TARGET} BUILD_INFO)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(TARGET embd-input-test)
|
|
||||||
add_executable(${TARGET} embd-input-test.cpp)
|
|
||||||
install(TARGETS ${TARGET} RUNTIME)
|
|
||||||
target_link_libraries(${TARGET} PRIVATE common llama embdinput ${CMAKE_THREAD_LIBS_INIT})
|
|
||||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
|
||||||
if(TARGET BUILD_INFO)
|
|
||||||
add_dependencies(${TARGET} BUILD_INFO)
|
|
||||||
endif()
|
|
@ -1,63 +0,0 @@
|
|||||||
### Examples for input embedding directly
|
|
||||||
|
|
||||||
## Requirement
|
|
||||||
build `libembdinput.so`
|
|
||||||
run the following comman in main dir (../../).
|
|
||||||
```
|
|
||||||
make
|
|
||||||
```
|
|
||||||
|
|
||||||
## [LLaVA](https://github.com/haotian-liu/LLaVA/) example (llava.py)
|
|
||||||
|
|
||||||
1. Obtian LLaVA model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/).
|
|
||||||
2. Convert it to ggml format.
|
|
||||||
3. `llava_projection.pth` is [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin).
|
|
||||||
|
|
||||||
```
|
|
||||||
import torch
|
|
||||||
|
|
||||||
bin_path = "../LLaVA-13b-delta-v1-1/pytorch_model-00003-of-00003.bin"
|
|
||||||
pth_path = "./examples/embd-input/llava_projection.pth"
|
|
||||||
|
|
||||||
dic = torch.load(bin_path)
|
|
||||||
used_key = ["model.mm_projector.weight","model.mm_projector.bias"]
|
|
||||||
torch.save({k: dic[k] for k in used_key}, pth_path)
|
|
||||||
```
|
|
||||||
4. Check the path of LLaVA model and `llava_projection.pth` in `llava.py`.
|
|
||||||
|
|
||||||
|
|
||||||
## [PandaGPT](https://github.com/yxuansu/PandaGPT) example (panda_gpt.py)
|
|
||||||
|
|
||||||
1. Obtian PandaGPT lora model from https://github.com/yxuansu/PandaGPT. Rename the file to `adapter_model.bin`. Use [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) to convert it to ggml format.
|
|
||||||
The `adapter_config.json` is
|
|
||||||
```
|
|
||||||
{
|
|
||||||
"peft_type": "LORA",
|
|
||||||
"fan_in_fan_out": false,
|
|
||||||
"bias": null,
|
|
||||||
"modules_to_save": null,
|
|
||||||
"r": 32,
|
|
||||||
"lora_alpha": 32,
|
|
||||||
"lora_dropout": 0.1,
|
|
||||||
"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
2. Papare the `vicuna` v0 model.
|
|
||||||
3. Obtain the [ImageBind](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) model.
|
|
||||||
4. Clone the PandaGPT source.
|
|
||||||
```
|
|
||||||
git clone https://github.com/yxuansu/PandaGPT
|
|
||||||
```
|
|
||||||
5. Install the requirement of PandaGPT.
|
|
||||||
6. Check the path of PandaGPT source, ImageBind model, lora model and vicuna model in panda_gpt.py.
|
|
||||||
|
|
||||||
## [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4/) example (minigpt4.py)
|
|
||||||
|
|
||||||
1. Obtain MiniGPT-4 model from https://github.com/Vision-CAIR/MiniGPT-4/ and put it in `embd-input`.
|
|
||||||
2. Clone the MiniGPT-4 source.
|
|
||||||
```
|
|
||||||
git clone https://github.com/Vision-CAIR/MiniGPT-4/
|
|
||||||
```
|
|
||||||
3. Install the requirement of PandaGPT.
|
|
||||||
4. Papare the `vicuna` v0 model.
|
|
||||||
5. Check the path of MiniGPT-4 source, MiniGPT-4 model and vicuna model in `minigpt4.py`.
|
|
@ -1,221 +0,0 @@
|
|||||||
#include "build-info.h"
|
|
||||||
#include "common.h"
|
|
||||||
#include "embd-input.h"
|
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <cinttypes>
|
|
||||||
#include <cmath>
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstring>
|
|
||||||
#include <ctime>
|
|
||||||
#include <fstream>
|
|
||||||
#include <iostream>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
static llama_context ** g_ctx;
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
|
|
||||||
struct MyModel* create_mymodel(int argc, char ** argv) {
|
|
||||||
gpt_params params;
|
|
||||||
|
|
||||||
if (!gpt_params_parse(argc, argv, params)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
print_build_info();
|
|
||||||
|
|
||||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
|
||||||
params.seed = uint32_t(time(NULL));
|
|
||||||
}
|
|
||||||
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
|
|
||||||
|
|
||||||
llama_backend_init(params.numa);
|
|
||||||
|
|
||||||
llama_model * model;
|
|
||||||
llama_context * ctx;
|
|
||||||
|
|
||||||
g_ctx = &ctx;
|
|
||||||
|
|
||||||
// load the model and apply lora adapter, if any
|
|
||||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
|
||||||
if (model == NULL) {
|
|
||||||
fprintf(stderr, "%s: error: unable to load model\n", __func__);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
// print system information
|
|
||||||
{
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "%s\n", get_system_info(params).c_str());
|
|
||||||
}
|
|
||||||
struct MyModel * ret = new MyModel();
|
|
||||||
ret->ctx = ctx;
|
|
||||||
ret->params = params;
|
|
||||||
ret->n_past = 0;
|
|
||||||
// printf("ctx: %d\n", ret->ctx);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
void free_mymodel(struct MyModel * mymodel) {
|
|
||||||
llama_context * ctx = mymodel->ctx;
|
|
||||||
llama_print_timings(ctx);
|
|
||||||
llama_free(ctx);
|
|
||||||
delete mymodel;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
bool eval_float(void * model, float * input, int N){
|
|
||||||
MyModel * mymodel = (MyModel*)model;
|
|
||||||
llama_context * ctx = mymodel->ctx;
|
|
||||||
gpt_params params = mymodel->params;
|
|
||||||
int n_emb = llama_n_embd(llama_get_model(ctx));
|
|
||||||
int n_past = mymodel->n_past;
|
|
||||||
int n_batch = N; // params.n_batch;
|
|
||||||
|
|
||||||
for (int i = 0; i < (int) N; i += n_batch) {
|
|
||||||
int n_eval = (int) N - i;
|
|
||||||
if (n_eval > n_batch) {
|
|
||||||
n_eval = n_batch;
|
|
||||||
}
|
|
||||||
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
|
||||||
if (llama_decode(ctx, batch)) {
|
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
n_past += n_eval;
|
|
||||||
}
|
|
||||||
mymodel->n_past = n_past;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool eval_tokens(void * model, std::vector<llama_token> tokens) {
|
|
||||||
MyModel * mymodel = (MyModel* )model;
|
|
||||||
llama_context * ctx;
|
|
||||||
ctx = mymodel->ctx;
|
|
||||||
gpt_params params = mymodel->params;
|
|
||||||
int n_past = mymodel->n_past;
|
|
||||||
for (int i = 0; i < (int) tokens.size(); i += params.n_batch) {
|
|
||||||
int n_eval = (int) tokens.size() - i;
|
|
||||||
if (n_eval > params.n_batch) {
|
|
||||||
n_eval = params.n_batch;
|
|
||||||
}
|
|
||||||
if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0))) {
|
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
n_past += n_eval;
|
|
||||||
}
|
|
||||||
mymodel->n_past = n_past;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool eval_id(struct MyModel* mymodel, int id) {
|
|
||||||
std::vector<llama_token> tokens;
|
|
||||||
tokens.push_back(id);
|
|
||||||
return eval_tokens(mymodel, tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool eval_string(struct MyModel * mymodel,const char* str){
|
|
||||||
llama_context * ctx = mymodel->ctx;
|
|
||||||
std::string str2 = str;
|
|
||||||
std::vector<llama_token> embd_inp = ::llama_tokenize(ctx, str2, true);
|
|
||||||
eval_tokens(mymodel, embd_inp);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token sampling_id(struct MyModel* mymodel) {
|
|
||||||
llama_context* ctx = mymodel->ctx;
|
|
||||||
gpt_params params = mymodel->params;
|
|
||||||
llama_sampling_params & sparams = params.sampling_params;
|
|
||||||
// int n_ctx = llama_n_ctx(ctx);
|
|
||||||
|
|
||||||
// out of user input, sample next token
|
|
||||||
const float temp = sparams.temp;
|
|
||||||
const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : sparams.top_k;
|
|
||||||
const float top_p = sparams.top_p;
|
|
||||||
const float tfs_z = sparams.tfs_z;
|
|
||||||
const float typical_p = sparams.typical_p;
|
|
||||||
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
|
||||||
// const float repeat_penalty = params.repeat_penalty;
|
|
||||||
// const float alpha_presence = params.presence_penalty;
|
|
||||||
// const float alpha_frequency = params.frequency_penalty;
|
|
||||||
const int mirostat = sparams.mirostat;
|
|
||||||
const float mirostat_tau = sparams.mirostat_tau;
|
|
||||||
const float mirostat_eta = sparams.mirostat_eta;
|
|
||||||
// const bool penalize_nl = params.penalize_nl;
|
|
||||||
|
|
||||||
llama_token id = 0;
|
|
||||||
{
|
|
||||||
auto logits = llama_get_logits(ctx);
|
|
||||||
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
|
|
||||||
|
|
||||||
// Apply params.logit_bias map
|
|
||||||
for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
|
|
||||||
logits[it->first] += it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
|
||||||
candidates.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
||||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
|
||||||
|
|
||||||
// TODO: Apply penalties
|
|
||||||
// float nl_logit = logits[llama_token_nl(ctx)];
|
|
||||||
// auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
|
|
||||||
// llama_sample_repetition_penalty(ctx, &candidates_p,
|
|
||||||
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
||||||
// last_n_repeat, repeat_penalty);
|
|
||||||
// llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
|
|
||||||
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
||||||
// last_n_repeat, alpha_frequency, alpha_presence);
|
|
||||||
// if (!penalize_nl) {
|
|
||||||
// logits[llama_token_nl(ctx)] = nl_logit;
|
|
||||||
// }
|
|
||||||
|
|
||||||
if (temp <= 0) {
|
|
||||||
// Greedy sampling
|
|
||||||
id = llama_sample_token_greedy(ctx, &candidates_p);
|
|
||||||
} else {
|
|
||||||
if (mirostat == 1) {
|
|
||||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
|
||||||
const int mirostat_m = 100;
|
|
||||||
llama_sample_temp(ctx, &candidates_p, temp);
|
|
||||||
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
|
||||||
} else if (mirostat == 2) {
|
|
||||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
|
||||||
llama_sample_temp(ctx, &candidates_p, temp);
|
|
||||||
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
|
||||||
} else {
|
|
||||||
// Temperature sampling
|
|
||||||
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
|
||||||
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
|
|
||||||
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
|
|
||||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
|
||||||
llama_sample_temp(ctx, &candidates_p, temp);
|
|
||||||
id = llama_sample_token(ctx, &candidates_p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return id;
|
|
||||||
}
|
|
||||||
|
|
||||||
const char * sampling(struct MyModel * mymodel) {
|
|
||||||
llama_context * ctx = mymodel->ctx;
|
|
||||||
int id = sampling_id(mymodel);
|
|
||||||
static std::string ret;
|
|
||||||
if (id == llama_token_eos(ctx)) {
|
|
||||||
ret = "</s>";
|
|
||||||
} else {
|
|
||||||
ret = llama_token_to_piece(ctx, id);
|
|
||||||
}
|
|
||||||
eval_id(mymodel, id);
|
|
||||||
return ret.c_str();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,35 +0,0 @@
|
|||||||
#include "embd-input.h"
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <random>
|
|
||||||
#include <string.h>
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
|
|
||||||
auto mymodel = create_mymodel(argc, argv);
|
|
||||||
int N = 10;
|
|
||||||
int max_tgt_len = 500;
|
|
||||||
int n_embd = llama_n_embd(llama_get_model(mymodel->ctx));
|
|
||||||
|
|
||||||
// add random float embd to test evaluation
|
|
||||||
float * data = new float[N*n_embd];
|
|
||||||
std::default_random_engine e;
|
|
||||||
std::uniform_real_distribution<float> u(0,1);
|
|
||||||
for (int i=0;i<N*n_embd;i++) {
|
|
||||||
data[i] = u(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
eval_string(mymodel, "user: what is the color of the flag of UN?");
|
|
||||||
eval_float(mymodel, data, N);
|
|
||||||
eval_string(mymodel, "assistant:");
|
|
||||||
eval_string(mymodel, mymodel->params.prompt.c_str());
|
|
||||||
const char* tmp;
|
|
||||||
for (int i=0; i<max_tgt_len; i++) {
|
|
||||||
tmp = sampling(mymodel);
|
|
||||||
if (strcmp(tmp, "</s>")==0) break;
|
|
||||||
printf("%s", tmp);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
printf("\n");
|
|
||||||
free_mymodel(mymodel);
|
|
||||||
return 0;
|
|
||||||
}
|
|
@ -1,27 +0,0 @@
|
|||||||
#ifndef _EMBD_INPUT_H_
|
|
||||||
#define _EMBD_INPUT_H_ 1
|
|
||||||
|
|
||||||
#include "common.h"
|
|
||||||
#include "llama.h"
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
|
|
||||||
typedef struct MyModel {
|
|
||||||
llama_context* ctx;
|
|
||||||
gpt_params params;
|
|
||||||
int n_past = 0;
|
|
||||||
} MyModel;
|
|
||||||
|
|
||||||
struct MyModel* create_mymodel(int argc, char ** argv);
|
|
||||||
|
|
||||||
bool eval_float(void* model, float* input, int N);
|
|
||||||
bool eval_tokens(void* model, std::vector<llama_token> tokens);
|
|
||||||
bool eval_id(struct MyModel* mymodel, int id);
|
|
||||||
bool eval_string(struct MyModel* mymodel, const char* str);
|
|
||||||
const char * sampling(struct MyModel* mymodel);
|
|
||||||
llama_token sampling_id(struct MyModel* mymodel);
|
|
||||||
void free_mymodel(struct MyModel* mymodel);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
|
@ -1,72 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import ctypes
|
|
||||||
from ctypes import cdll, c_char_p, c_void_p, POINTER, c_float, c_int
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
|
|
||||||
libc = cdll.LoadLibrary("./libembdinput.so")
|
|
||||||
libc.sampling.restype=c_char_p
|
|
||||||
libc.create_mymodel.restype=c_void_p
|
|
||||||
libc.eval_string.argtypes=[c_void_p, c_char_p]
|
|
||||||
libc.sampling.argtypes=[c_void_p]
|
|
||||||
libc.eval_float.argtypes=[c_void_p, POINTER(c_float), c_int]
|
|
||||||
|
|
||||||
|
|
||||||
class MyModel:
|
|
||||||
def __init__(self, args):
|
|
||||||
argc = len(args)
|
|
||||||
c_str = [c_char_p(i.encode()) for i in args]
|
|
||||||
args_c = (c_char_p * argc)(*c_str)
|
|
||||||
self.model = c_void_p(libc.create_mymodel(argc, args_c))
|
|
||||||
self.max_tgt_len = 512
|
|
||||||
self.print_string_eval = True
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
libc.free_mymodel(self.model)
|
|
||||||
|
|
||||||
def eval_float(self, x):
|
|
||||||
libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[1])
|
|
||||||
|
|
||||||
def eval_string(self, x):
|
|
||||||
libc.eval_string(self.model, x.encode()) # c_char_p(x.encode()))
|
|
||||||
if self.print_string_eval:
|
|
||||||
print(x)
|
|
||||||
|
|
||||||
def eval_token(self, x):
|
|
||||||
libc.eval_id(self.model, x)
|
|
||||||
|
|
||||||
def sampling(self):
|
|
||||||
s = libc.sampling(self.model)
|
|
||||||
return s
|
|
||||||
|
|
||||||
def stream_generate(self, end="</s>"):
|
|
||||||
ret = b""
|
|
||||||
end = end.encode()
|
|
||||||
for _ in range(self.max_tgt_len):
|
|
||||||
tmp = self.sampling()
|
|
||||||
ret += tmp
|
|
||||||
yield tmp
|
|
||||||
if ret.endswith(end):
|
|
||||||
break
|
|
||||||
|
|
||||||
def generate_with_print(self, end="</s>"):
|
|
||||||
ret = b""
|
|
||||||
for i in self.stream_generate(end=end):
|
|
||||||
ret += i
|
|
||||||
print(i.decode(errors="replace"), end="", flush=True)
|
|
||||||
print("")
|
|
||||||
return ret.decode(errors="replace")
|
|
||||||
|
|
||||||
|
|
||||||
def generate(self, end="</s>"):
|
|
||||||
text = b"".join(self.stream_generate(end=end))
|
|
||||||
return text.decode(errors="replace")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
|
|
||||||
model.eval_string("""user: what is the color of the flag of UN?""")
|
|
||||||
x = np.random.random((5120,10))# , dtype=np.float32)
|
|
||||||
model.eval_float(x)
|
|
||||||
model.eval_string("""assistant:""")
|
|
||||||
for i in model.generate():
|
|
||||||
print(i.decode(errors="replace"), end="", flush=True)
|
|
@ -1,71 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.insert(0, os.path.dirname(__file__))
|
|
||||||
from embd_input import MyModel
|
|
||||||
import numpy as np
|
|
||||||
from torch import nn
|
|
||||||
import torch
|
|
||||||
from transformers import CLIPVisionModel, CLIPImageProcessor
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1'
|
|
||||||
vision_tower = "openai/clip-vit-large-patch14"
|
|
||||||
select_hidden_state_layer = -2
|
|
||||||
# (vision_config.image_size // vision_config.patch_size) ** 2
|
|
||||||
image_token_len = (224//14)**2
|
|
||||||
|
|
||||||
class Llava:
|
|
||||||
def __init__(self, args):
|
|
||||||
self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
|
|
||||||
self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
|
|
||||||
self.mm_projector = nn.Linear(1024, 5120)
|
|
||||||
self.model = MyModel(["main", *args])
|
|
||||||
|
|
||||||
def load_projection(self, path):
|
|
||||||
state = torch.load(path)
|
|
||||||
self.mm_projector.load_state_dict({
|
|
||||||
"weight": state["model.mm_projector.weight"],
|
|
||||||
"bias": state["model.mm_projector.bias"]})
|
|
||||||
|
|
||||||
def chat(self, question):
|
|
||||||
self.model.eval_string("user: ")
|
|
||||||
self.model.eval_string(question)
|
|
||||||
self.model.eval_string("\nassistant: ")
|
|
||||||
return self.model.generate_with_print()
|
|
||||||
|
|
||||||
def chat_with_image(self, image, question):
|
|
||||||
with torch.no_grad():
|
|
||||||
embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
|
||||||
image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True)
|
|
||||||
select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
|
|
||||||
image_feature = select_hidden_state[:, 1:]
|
|
||||||
embd_image = self.mm_projector(image_feature)
|
|
||||||
embd_image = embd_image.cpu().numpy()[0]
|
|
||||||
self.model.eval_string("user: ")
|
|
||||||
self.model.eval_token(32003-2) # im_start
|
|
||||||
self.model.eval_float(embd_image.T)
|
|
||||||
for i in range(image_token_len-embd_image.shape[0]):
|
|
||||||
self.model.eval_token(32003-3) # im_patch
|
|
||||||
self.model.eval_token(32003-1) # im_end
|
|
||||||
self.model.eval_string(question)
|
|
||||||
self.model.eval_string("\nassistant: ")
|
|
||||||
return self.model.generate_with_print()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__=="__main__":
|
|
||||||
# model form liuhaotian/LLaVA-13b-delta-v1-1
|
|
||||||
a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"])
|
|
||||||
# Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin.
|
|
||||||
# Also here can use pytorch_model-00003-of-00003.bin directly.
|
|
||||||
a.load_projection(os.path.join(
|
|
||||||
os.path.dirname(__file__) ,
|
|
||||||
"llava_projection.pth"))
|
|
||||||
respose = a.chat_with_image(
|
|
||||||
Image.open("./media/llama1-logo.png").convert('RGB'),
|
|
||||||
"what is the text in the picture?")
|
|
||||||
respose
|
|
||||||
a.chat("what is the color of it?")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,129 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.insert(0, os.path.dirname(__file__))
|
|
||||||
from embd_input import MyModel
|
|
||||||
import numpy as np
|
|
||||||
from torch import nn
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4")
|
|
||||||
sys.path.insert(0, minigpt4_path)
|
|
||||||
from minigpt4.models.blip2 import Blip2Base
|
|
||||||
from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor
|
|
||||||
|
|
||||||
|
|
||||||
class MiniGPT4(Blip2Base):
|
|
||||||
"""
|
|
||||||
MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4
|
|
||||||
"""
|
|
||||||
def __init__(self,
|
|
||||||
args,
|
|
||||||
vit_model="eva_clip_g",
|
|
||||||
q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
|
|
||||||
img_size=224,
|
|
||||||
drop_path_rate=0,
|
|
||||||
use_grad_checkpoint=False,
|
|
||||||
vit_precision="fp32",
|
|
||||||
freeze_vit=True,
|
|
||||||
freeze_qformer=True,
|
|
||||||
num_query_token=32,
|
|
||||||
llama_model="",
|
|
||||||
prompt_path="",
|
|
||||||
prompt_template="",
|
|
||||||
max_txt_len=32,
|
|
||||||
end_sym='\n',
|
|
||||||
low_resource=False, # use 8 bit and put vit in cpu
|
|
||||||
device_8bit=0
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.img_size = img_size
|
|
||||||
self.low_resource = low_resource
|
|
||||||
self.preprocessor = Blip2ImageEvalProcessor(img_size)
|
|
||||||
|
|
||||||
print('Loading VIT')
|
|
||||||
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
|
|
||||||
vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
|
|
||||||
)
|
|
||||||
print('Loading VIT Done')
|
|
||||||
print('Loading Q-Former')
|
|
||||||
self.Qformer, self.query_tokens = self.init_Qformer(
|
|
||||||
num_query_token, self.visual_encoder.num_features
|
|
||||||
)
|
|
||||||
self.Qformer.cls = None
|
|
||||||
self.Qformer.bert.embeddings.word_embeddings = None
|
|
||||||
self.Qformer.bert.embeddings.position_embeddings = None
|
|
||||||
for layer in self.Qformer.bert.encoder.layer:
|
|
||||||
layer.output = None
|
|
||||||
layer.intermediate = None
|
|
||||||
self.load_from_pretrained(url_or_filename=q_former_model)
|
|
||||||
print('Loading Q-Former Done')
|
|
||||||
self.llama_proj = nn.Linear(
|
|
||||||
self.Qformer.config.hidden_size, 5120 # self.llama_model.config.hidden_size
|
|
||||||
)
|
|
||||||
self.max_txt_len = max_txt_len
|
|
||||||
self.end_sym = end_sym
|
|
||||||
self.model = MyModel(["main", *args])
|
|
||||||
# system prompt
|
|
||||||
self.model.eval_string("Give the following image: <Img>ImageContent</Img>. "
|
|
||||||
"You will be able to see the image once I provide it to you. Please answer my questions."
|
|
||||||
"###")
|
|
||||||
|
|
||||||
def encode_img(self, image):
|
|
||||||
image = self.preprocessor(image)
|
|
||||||
image = image.unsqueeze(0)
|
|
||||||
device = image.device
|
|
||||||
if self.low_resource:
|
|
||||||
self.vit_to_cpu()
|
|
||||||
image = image.to("cpu")
|
|
||||||
|
|
||||||
with self.maybe_autocast():
|
|
||||||
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
|
|
||||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
|
|
||||||
|
|
||||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
|
||||||
query_output = self.Qformer.bert(
|
|
||||||
query_embeds=query_tokens,
|
|
||||||
encoder_hidden_states=image_embeds,
|
|
||||||
encoder_attention_mask=image_atts,
|
|
||||||
return_dict=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
inputs_llama = self.llama_proj(query_output.last_hidden_state)
|
|
||||||
# atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
|
|
||||||
return inputs_llama
|
|
||||||
|
|
||||||
def load_projection(self, path):
|
|
||||||
state = torch.load(path)["model"]
|
|
||||||
self.llama_proj.load_state_dict({
|
|
||||||
"weight": state["llama_proj.weight"],
|
|
||||||
"bias": state["llama_proj.bias"]})
|
|
||||||
|
|
||||||
def chat(self, question):
|
|
||||||
self.model.eval_string("Human: ")
|
|
||||||
self.model.eval_string(question)
|
|
||||||
self.model.eval_string("\n### Assistant:")
|
|
||||||
return self.model.generate_with_print(end="###")
|
|
||||||
|
|
||||||
def chat_with_image(self, image, question):
|
|
||||||
with torch.no_grad():
|
|
||||||
embd_image = self.encode_img(image)
|
|
||||||
embd_image = embd_image.cpu().numpy()[0]
|
|
||||||
self.model.eval_string("Human: <Img>")
|
|
||||||
self.model.eval_float(embd_image.T)
|
|
||||||
self.model.eval_string("</Img> ")
|
|
||||||
self.model.eval_string(question)
|
|
||||||
self.model.eval_string("\n### Assistant:")
|
|
||||||
return self.model.generate_with_print(end="###")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__=="__main__":
|
|
||||||
a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"])
|
|
||||||
a.load_projection(os.path.join(
|
|
||||||
os.path.dirname(__file__) ,
|
|
||||||
"pretrained_minigpt4.pth"))
|
|
||||||
respose = a.chat_with_image(
|
|
||||||
Image.open("./media/llama1-logo.png").convert('RGB'),
|
|
||||||
"what is the text in the picture?")
|
|
||||||
a.chat("what is the color of it?")
|
|
@ -1,99 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.insert(0, os.path.dirname(__file__))
|
|
||||||
from embd_input import MyModel
|
|
||||||
import numpy as np
|
|
||||||
from torch import nn
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# use PandaGPT path
|
|
||||||
panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT")
|
|
||||||
imagebind_ckpt_path = "./models/panda_gpt/"
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.join(panda_gpt_path,"code","model"))
|
|
||||||
from ImageBind.models import imagebind_model
|
|
||||||
from ImageBind import data
|
|
||||||
|
|
||||||
ModalityType = imagebind_model.ModalityType
|
|
||||||
max_tgt_len = 400
|
|
||||||
|
|
||||||
class PandaGPT:
|
|
||||||
def __init__(self, args):
|
|
||||||
self.visual_encoder,_ = imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)
|
|
||||||
self.visual_encoder.eval()
|
|
||||||
self.llama_proj = nn.Linear(1024, 5120) # self.visual_hidden_size, 5120)
|
|
||||||
self.max_tgt_len = max_tgt_len
|
|
||||||
self.model = MyModel(["main", *args])
|
|
||||||
self.generated_text = ""
|
|
||||||
self.device = "cpu"
|
|
||||||
|
|
||||||
def load_projection(self, path):
|
|
||||||
state = torch.load(path, map_location="cpu")
|
|
||||||
self.llama_proj.load_state_dict({
|
|
||||||
"weight": state["llama_proj.weight"],
|
|
||||||
"bias": state["llama_proj.bias"]})
|
|
||||||
|
|
||||||
def eval_inputs(self, inputs):
|
|
||||||
self.model.eval_string("<Img>")
|
|
||||||
embds = self.extract_multimoal_feature(inputs)
|
|
||||||
for i in embds:
|
|
||||||
self.model.eval_float(i.T)
|
|
||||||
self.model.eval_string("</Img> ")
|
|
||||||
|
|
||||||
def chat(self, question):
|
|
||||||
return self.chat_with_image(None, question)
|
|
||||||
|
|
||||||
def chat_with_image(self, inputs, question):
|
|
||||||
if self.generated_text == "":
|
|
||||||
self.model.eval_string("###")
|
|
||||||
self.model.eval_string(" Human: ")
|
|
||||||
if inputs:
|
|
||||||
self.eval_inputs(inputs)
|
|
||||||
self.model.eval_string(question)
|
|
||||||
self.model.eval_string("\n### Assistant:")
|
|
||||||
ret = self.model.generate_with_print(end="###")
|
|
||||||
self.generated_text += ret
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def extract_multimoal_feature(self, inputs):
|
|
||||||
features = []
|
|
||||||
for key in ["image", "audio", "video", "thermal"]:
|
|
||||||
if key + "_paths" in inputs:
|
|
||||||
embeds = self.encode_data(key, inputs[key+"_paths"])
|
|
||||||
features.append(embeds)
|
|
||||||
return features
|
|
||||||
|
|
||||||
def encode_data(self, data_type, data_paths):
|
|
||||||
|
|
||||||
type_map = {
|
|
||||||
"image": ModalityType.VISION,
|
|
||||||
"audio": ModalityType.AUDIO,
|
|
||||||
"video": ModalityType.VISION,
|
|
||||||
"thermal": ModalityType.THERMAL,
|
|
||||||
}
|
|
||||||
load_map = {
|
|
||||||
"image": data.load_and_transform_vision_data,
|
|
||||||
"audio": data.load_and_transform_audio_data,
|
|
||||||
"video": data.load_and_transform_video_data,
|
|
||||||
"thermal": data.load_and_transform_thermal_data
|
|
||||||
}
|
|
||||||
|
|
||||||
load_function = load_map[data_type]
|
|
||||||
key = type_map[data_type]
|
|
||||||
|
|
||||||
inputs = {key: load_function(data_paths, self.device)}
|
|
||||||
with torch.no_grad():
|
|
||||||
embeddings = self.visual_encoder(inputs)
|
|
||||||
embeds = embeddings[key]
|
|
||||||
embeds = self.llama_proj(embeds).cpu().numpy()
|
|
||||||
return embeds
|
|
||||||
|
|
||||||
|
|
||||||
if __name__=="__main__":
|
|
||||||
a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"])
|
|
||||||
a.load_projection("./models/panda_gpt/adapter_model.bin")
|
|
||||||
a.chat_with_image(
|
|
||||||
{"image_paths": ["./media/llama1-logo.png"]},
|
|
||||||
"what is the text in the picture? 'llama' or 'lambda'?")
|
|
||||||
a.chat("what is the color of it?")
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -39,8 +39,8 @@ static gpt_params * g_params;
|
|||||||
static std::vector<llama_token> * g_input_tokens;
|
static std::vector<llama_token> * g_input_tokens;
|
||||||
static std::ostringstream * g_output_ss;
|
static std::ostringstream * g_output_ss;
|
||||||
static std::vector<llama_token> * g_output_tokens;
|
static std::vector<llama_token> * g_output_tokens;
|
||||||
static bool is_interacting = false;
|
|
||||||
|
|
||||||
|
static bool is_interacting = false;
|
||||||
|
|
||||||
static void write_logfile(
|
static void write_logfile(
|
||||||
const llama_context * ctx, const gpt_params & params, const llama_model * model,
|
const llama_context * ctx, const gpt_params & params, const llama_model * model,
|
||||||
@ -104,7 +104,7 @@ static void sigint_handler(int signo) {
|
|||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
llama_sampling_params & sparams = params.sampling_params;
|
llama_sampling_params & sparams = params.sparams;
|
||||||
g_params = ¶ms;
|
g_params = ¶ms;
|
||||||
|
|
||||||
if (!gpt_params_parse(argc, argv, params)) {
|
if (!gpt_params_parse(argc, argv, params)) {
|
||||||
@ -358,36 +358,10 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
|
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
|
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
|
||||||
sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
|
|
||||||
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||||
LOG_TEE("\n\n");
|
LOG_TEE("\n\n");
|
||||||
|
|
||||||
struct llama_grammar * grammar = NULL;
|
|
||||||
grammar_parser::parse_state parsed_grammar;
|
|
||||||
|
|
||||||
if (!params.grammar.empty()) {
|
|
||||||
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
|
||||||
// will be empty (default) if there are parse errors
|
|
||||||
if (parsed_grammar.rules.empty()) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
LOG_TEE("%s: grammar:\n", __func__);
|
|
||||||
grammar_parser::print_grammar(stderr, parsed_grammar);
|
|
||||||
LOG_TEE("\n");
|
|
||||||
|
|
||||||
{
|
|
||||||
auto it = sparams.logit_bias.find(llama_token_eos(ctx));
|
|
||||||
if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
|
|
||||||
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
|
||||||
grammar = llama_grammar_init(
|
|
||||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG_TEE("\n##### Infill mode #####\n\n");
|
LOG_TEE("\n##### Infill mode #####\n\n");
|
||||||
if (params.infill) {
|
if (params.infill) {
|
||||||
printf("\n************\n");
|
printf("\n************\n");
|
||||||
@ -430,7 +404,7 @@ int main(int argc, char ** argv) {
|
|||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
std::vector<llama_token> embd_guidance;
|
std::vector<llama_token> embd_guidance;
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||||
|
|
||||||
while (n_remain != 0 || params.interactive) {
|
while (n_remain != 0 || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
@ -549,7 +523,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
||||||
|
|
||||||
@ -567,8 +541,11 @@ int main(int argc, char ** argv) {
|
|||||||
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
|
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
|
||||||
while ((int) embd_inp.size() > n_consumed) {
|
while ((int) embd_inp.size() > n_consumed) {
|
||||||
embd.push_back(embd_inp[n_consumed]);
|
embd.push_back(embd_inp[n_consumed]);
|
||||||
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
|
|
||||||
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
|
// push the prompt in the sampling context in order to apply repetition penalties later
|
||||||
|
// for the prompt, we don't apply grammar rules
|
||||||
|
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
|
||||||
|
|
||||||
++n_consumed;
|
++n_consumed;
|
||||||
if ((int) embd.size() >= params.n_batch) {
|
if ((int) embd.size() >= params.n_batch) {
|
||||||
break;
|
break;
|
||||||
@ -600,7 +577,7 @@ int main(int argc, char ** argv) {
|
|||||||
if ((int) embd_inp.size() <= n_consumed) {
|
if ((int) embd_inp.size() <= n_consumed) {
|
||||||
|
|
||||||
// deal with eot token in infill mode
|
// deal with eot token in infill mode
|
||||||
if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
|
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
|
||||||
if(is_interacting && !params.interactive_first) {
|
if(is_interacting && !params.interactive_first) {
|
||||||
// print an eot token
|
// print an eot token
|
||||||
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
|
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
|
||||||
@ -640,7 +617,7 @@ int main(int argc, char ** argv) {
|
|||||||
process_escapes(params.input_suffix);
|
process_escapes(params.input_suffix);
|
||||||
}
|
}
|
||||||
suff_rm_leading_spc = params.escape;
|
suff_rm_leading_spc = params.escape;
|
||||||
if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
|
if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
|
||||||
params.input_suffix.erase(0, 1);
|
params.input_suffix.erase(0, 1);
|
||||||
suff_rm_leading_spc = false;
|
suff_rm_leading_spc = false;
|
||||||
}
|
}
|
||||||
@ -667,7 +644,7 @@ int main(int argc, char ** argv) {
|
|||||||
is_interacting = false;
|
is_interacting = false;
|
||||||
}
|
}
|
||||||
// deal with end of text token in interactive mode
|
// deal with end of text token in interactive mode
|
||||||
else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
|
else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
|
||||||
LOG("found EOS token\n");
|
LOG("found EOS token\n");
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
@ -740,15 +717,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
if (n_past > 0) {
|
if (n_past > 0) {
|
||||||
if (is_interacting) {
|
if (is_interacting) {
|
||||||
// reset grammar state if we're restarting generation
|
llama_sampling_reset(ctx_sampling);
|
||||||
if (grammar != NULL) {
|
|
||||||
llama_grammar_free(grammar);
|
|
||||||
|
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
|
||||||
grammar = llama_grammar_init(
|
|
||||||
grammar_rules.data(), grammar_rules.size(),
|
|
||||||
parsed_grammar.symbol_ids.at("root"));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
is_interacting = false;
|
is_interacting = false;
|
||||||
}
|
}
|
||||||
@ -778,9 +747,7 @@ int main(int argc, char ** argv) {
|
|||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
|
||||||
if (grammar != NULL) {
|
llama_sampling_free(ctx_sampling);
|
||||||
llama_grammar_free(grammar);
|
|
||||||
}
|
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
|
@ -58,20 +58,22 @@ inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n
|
|||||||
|
|
||||||
// TODO: use common/sampling.h
|
// TODO: use common/sampling.h
|
||||||
inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
|
inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
|
||||||
|
auto & sparams = params.sparams;
|
||||||
|
|
||||||
// out of user input, sample next token
|
// out of user input, sample next token
|
||||||
const float temp = params.sampling_params.temp;
|
const float temp = sparams.temp;
|
||||||
const int32_t top_k = params.sampling_params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.sampling_params.top_k;
|
const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : sparams.top_k;
|
||||||
const float top_p = params.sampling_params.top_p;
|
const float top_p = sparams.top_p;
|
||||||
const float tfs_z = params.sampling_params.tfs_z;
|
const float tfs_z = sparams.tfs_z;
|
||||||
const float typical_p = params.sampling_params.typical_p;
|
const float typical_p = sparams.typical_p;
|
||||||
// const int32_t repeat_last_n = params.sampling_params.repeat_last_n < 0 ? n_ctx : params.sampling_params.repeat_last_n;
|
// const int32_t repeat_last_n = sparams.repeat_last_n < 0 ? n_ctx : sparams.repeat_last_n;
|
||||||
// const float repeat_penalty = params.sampling_params.repeat_penalty;
|
// const float repeat_penalty = sparams.repeat_penalty;
|
||||||
// const float alpha_presence = params.sampling_params.presence_penalty;
|
// const float alpha_presence = sparams.presence_penalty;
|
||||||
// const float alpha_frequency = params.sampling_params.frequency_penalty;
|
// const float alpha_frequency = sparams.frequency_penalty;
|
||||||
const int mirostat = params.sampling_params.mirostat;
|
const int mirostat = sparams.mirostat;
|
||||||
const float mirostat_tau = params.sampling_params.mirostat_tau;
|
const float mirostat_tau = sparams.mirostat_tau;
|
||||||
const float mirostat_eta = params.sampling_params.mirostat_eta;
|
const float mirostat_eta = sparams.mirostat_eta;
|
||||||
// const bool penalize_nl = params.sampling_params.penalize_nl;
|
// const bool penalize_nl = sparams.penalize_nl;
|
||||||
|
|
||||||
llama_token id = 0;
|
llama_token id = 0;
|
||||||
{
|
{
|
||||||
@ -79,7 +81,7 @@ inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) {
|
|||||||
auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
|
auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama));
|
||||||
|
|
||||||
// Apply params.logit_bias map
|
// Apply params.logit_bias map
|
||||||
for (auto it = params.sampling_params.logit_bias.begin(); it != params.sampling_params.logit_bias.end(); it++) {
|
for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
|
||||||
logits[it->first] += it->second;
|
logits[it->first] += it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ int main(int argc, char ** argv) {
|
|||||||
if (!gpt_params_parse(argc, argv, params)) {
|
if (!gpt_params_parse(argc, argv, params)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
llama_sampling_params & sparams = params.sampling_params;
|
llama_sampling_params & sparams = params.sparams;
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("main", "log"));
|
log_set_target(log_filename_generator("main", "log"));
|
||||||
@ -415,8 +415,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
|
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
|
||||||
sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
|
|
||||||
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||||
LOG_TEE("\n\n");
|
LOG_TEE("\n\n");
|
||||||
|
|
||||||
@ -459,7 +458,7 @@ int main(int argc, char ** argv) {
|
|||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
std::vector<llama_token> embd_guidance;
|
std::vector<llama_token> embd_guidance;
|
||||||
|
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||||
|
|
||||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
@ -612,7 +611,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, id);
|
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
||||||
|
|
||||||
@ -631,12 +630,9 @@ int main(int argc, char ** argv) {
|
|||||||
while ((int) embd_inp.size() > n_consumed) {
|
while ((int) embd_inp.size() > n_consumed) {
|
||||||
embd.push_back(embd_inp[n_consumed]);
|
embd.push_back(embd_inp[n_consumed]);
|
||||||
|
|
||||||
// GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
|
// push the prompt in the sampling context in order to apply repetition penalties later
|
||||||
// Most likely will remove this in the future to avoid exposing "prev"
|
// for the prompt, we don't apply grammar rules
|
||||||
// Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
|
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
|
||||||
// penalty will be applied only based on the tokens generated by the model.
|
|
||||||
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
|
|
||||||
ctx_sampling->prev.push_back(embd_inp[n_consumed]);
|
|
||||||
|
|
||||||
++n_consumed;
|
++n_consumed;
|
||||||
if ((int) embd.size() >= params.n_batch) {
|
if ((int) embd.size() >= params.n_batch) {
|
||||||
@ -667,12 +663,10 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// if not currently processing queued inputs;
|
// if not currently processing queued inputs;
|
||||||
if ((int) embd_inp.size() <= n_consumed) {
|
if ((int) embd_inp.size() <= n_consumed) {
|
||||||
// check for reverse prompt
|
// check for reverse prompt in the last n_prev tokens
|
||||||
if (!params.antiprompt.empty()) {
|
if (!params.antiprompt.empty()) {
|
||||||
std::string last_output;
|
const int n_prev = 32;
|
||||||
for (auto id : ctx_sampling->prev) {
|
const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
|
||||||
last_output += llama_token_to_piece(ctx, id);
|
|
||||||
}
|
|
||||||
|
|
||||||
is_antiprompt = false;
|
is_antiprompt = false;
|
||||||
// Check if each of the reverse prompts appears at the end of the output.
|
// Check if each of the reverse prompts appears at the end of the output.
|
||||||
@ -699,7 +693,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// deal with end of text token in interactive mode
|
// deal with end of text token in interactive mode
|
||||||
if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
|
if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
|
||||||
LOG("found EOS token\n");
|
LOG("found EOS token\n");
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
|
@ -157,7 +157,7 @@ int main(int argc, char ** argv) {
|
|||||||
for (size_t i = 0; i < clients.size(); ++i) {
|
for (size_t i = 0; i < clients.size(); ++i) {
|
||||||
auto & client = clients[i];
|
auto & client = clients[i];
|
||||||
client.id = i;
|
client.id = i;
|
||||||
client.ctx_sampling = llama_sampling_init(params);
|
client.ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> tokens_system;
|
std::vector<llama_token> tokens_system;
|
||||||
@ -330,7 +330,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
||||||
|
|
||||||
llama_sampling_accept(client.ctx_sampling, ctx, id);
|
llama_sampling_accept(client.ctx_sampling, ctx, id, true);
|
||||||
|
|
||||||
if (client.n_decoded == 1) {
|
if (client.n_decoded == 1) {
|
||||||
// start measuring generation time after the first token to make sure all concurrent clients
|
// start measuring generation time after the first token to make sure all concurrent clients
|
||||||
|
@ -195,10 +195,12 @@ struct llama_server_context
|
|||||||
json prompt;
|
json prompt;
|
||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
|
|
||||||
|
gpt_params params;
|
||||||
|
|
||||||
llama_model *model = nullptr;
|
llama_model *model = nullptr;
|
||||||
llama_context *ctx = nullptr;
|
llama_context *ctx = nullptr;
|
||||||
gpt_params params;
|
|
||||||
llama_sampling_context *ctx_sampling = nullptr;
|
llama_sampling_context *ctx_sampling = nullptr;
|
||||||
|
|
||||||
int n_ctx;
|
int n_ctx;
|
||||||
|
|
||||||
bool truncated = false;
|
bool truncated = false;
|
||||||
@ -232,7 +234,7 @@ struct llama_server_context
|
|||||||
void rewind()
|
void rewind()
|
||||||
{
|
{
|
||||||
params.antiprompt.clear();
|
params.antiprompt.clear();
|
||||||
params.grammar.clear();
|
params.sparams.grammar.clear();
|
||||||
num_prompt_tokens = 0;
|
num_prompt_tokens = 0;
|
||||||
num_tokens_predicted = 0;
|
num_tokens_predicted = 0;
|
||||||
generated_text = "";
|
generated_text = "";
|
||||||
@ -246,11 +248,14 @@ struct llama_server_context
|
|||||||
multibyte_pending = 0;
|
multibyte_pending = 0;
|
||||||
n_remain = 0;
|
n_remain = 0;
|
||||||
n_past = 0;
|
n_past = 0;
|
||||||
|
params.sparams.n_prev = n_ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
void initSampling() {
|
||||||
if (ctx_sampling != nullptr) {
|
if (ctx_sampling != nullptr) {
|
||||||
llama_sampling_free(ctx_sampling);
|
llama_sampling_free(ctx_sampling);
|
||||||
}
|
}
|
||||||
ctx_sampling = llama_sampling_init(params);
|
ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool loadModel(const gpt_params ¶ms_)
|
bool loadModel(const gpt_params ¶ms_)
|
||||||
@ -311,16 +316,32 @@ struct llama_server_context
|
|||||||
return prompt_tokens;
|
return prompt_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool loadGrammar()
|
void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
|
||||||
{
|
const int n_left = n_ctx - params.n_keep;
|
||||||
ctx_sampling = llama_sampling_init(params);
|
const int n_block_size = n_left / 2;
|
||||||
return true;
|
const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size;
|
||||||
|
|
||||||
|
// Keep n_keep tokens at start of prompt (at most n_ctx - 4)
|
||||||
|
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
||||||
|
|
||||||
|
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
|
||||||
|
|
||||||
|
LOG_VERBOSE("input truncated", {
|
||||||
|
{"n_ctx", n_ctx},
|
||||||
|
{"n_keep", params.n_keep},
|
||||||
|
{"n_left", n_left},
|
||||||
|
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||||
|
{"num_prompt_tokens", new_tokens.size()}
|
||||||
|
});
|
||||||
|
|
||||||
|
truncated = true;
|
||||||
|
prompt_tokens = new_tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
void loadInfill()
|
void loadInfill()
|
||||||
{
|
{
|
||||||
bool suff_rm_leading_spc = true;
|
bool suff_rm_leading_spc = true;
|
||||||
if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
|
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
|
||||||
params.input_suffix.erase(0, 1);
|
params.input_suffix.erase(0, 1);
|
||||||
suff_rm_leading_spc = false;
|
suff_rm_leading_spc = false;
|
||||||
}
|
}
|
||||||
@ -336,6 +357,7 @@ struct llama_server_context
|
|||||||
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
|
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
|
||||||
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
|
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
|
||||||
prefix_tokens.push_back(llama_token_middle(ctx));
|
prefix_tokens.push_back(llama_token_middle(ctx));
|
||||||
|
|
||||||
auto prompt_tokens = prefix_tokens;
|
auto prompt_tokens = prefix_tokens;
|
||||||
|
|
||||||
num_prompt_tokens = prompt_tokens.size();
|
num_prompt_tokens = prompt_tokens.size();
|
||||||
@ -347,31 +369,18 @@ struct llama_server_context
|
|||||||
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
|
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
|
||||||
|
|
||||||
// if input prompt is too big, truncate like normal
|
// if input prompt is too big, truncate like normal
|
||||||
if (num_prompt_tokens >= (size_t)params.n_ctx)
|
if (num_prompt_tokens >= (size_t) n_ctx)
|
||||||
{
|
{
|
||||||
printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
|
truncatePrompt(prompt_tokens);
|
||||||
// todo we probably want to cut from both sides
|
num_prompt_tokens = prompt_tokens.size();
|
||||||
const int n_left = (params.n_ctx - params.n_keep) / 2;
|
|
||||||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
|
||||||
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
|
||||||
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
|
|
||||||
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
|
|
||||||
|
|
||||||
LOG_VERBOSE("input truncated", {
|
GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
|
||||||
{"n_ctx", params.n_ctx},
|
|
||||||
{"n_keep", params.n_keep},
|
|
||||||
{"n_left", n_left},
|
|
||||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
|
||||||
});
|
|
||||||
|
|
||||||
truncated = true;
|
|
||||||
prompt_tokens = new_tokens;
|
|
||||||
}
|
}
|
||||||
else
|
|
||||||
|
// push the prompt into the sampling context (do not apply grammar)
|
||||||
|
for (auto & token : prompt_tokens)
|
||||||
{
|
{
|
||||||
const size_t ps = num_prompt_tokens;
|
llama_sampling_accept(ctx_sampling, ctx, token, false);
|
||||||
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
|
|
||||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// compare the evaluated prompt with the new prompt
|
// compare the evaluated prompt with the new prompt
|
||||||
@ -411,27 +420,16 @@ struct llama_server_context
|
|||||||
// if input prompt is too big, truncate like normal
|
// if input prompt is too big, truncate like normal
|
||||||
if (num_prompt_tokens >= (size_t) n_ctx)
|
if (num_prompt_tokens >= (size_t) n_ctx)
|
||||||
{
|
{
|
||||||
const int n_left = (n_ctx - params.n_keep) / 2;
|
truncatePrompt(prompt_tokens);
|
||||||
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
|
num_prompt_tokens = prompt_tokens.size();
|
||||||
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
|
|
||||||
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
|
|
||||||
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
|
|
||||||
|
|
||||||
LOG_VERBOSE("input truncated", {
|
GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
|
||||||
{"n_ctx", n_ctx},
|
|
||||||
{"n_keep", params.n_keep},
|
|
||||||
{"n_left", n_left},
|
|
||||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
|
||||||
});
|
|
||||||
|
|
||||||
truncated = true;
|
|
||||||
prompt_tokens = new_tokens;
|
|
||||||
}
|
}
|
||||||
else
|
|
||||||
|
// push the prompt into the sampling context (do not apply grammar)
|
||||||
|
for (auto & token : prompt_tokens)
|
||||||
{
|
{
|
||||||
const size_t ps = num_prompt_tokens;
|
llama_sampling_accept(ctx_sampling, ctx, token, false);
|
||||||
std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
|
|
||||||
std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// compare the evaluated prompt with the new prompt
|
// compare the evaluated prompt with the new prompt
|
||||||
@ -530,8 +528,8 @@ struct llama_server_context
|
|||||||
|
|
||||||
llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
|
llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
|
||||||
|
|
||||||
const int32_t n_probs = params.sampling_params.n_probs;
|
const int32_t n_probs = params.sparams.n_probs;
|
||||||
if (params.sampling_params.temp <= 0 && n_probs > 0)
|
if (params.sparams.temp <= 0 && n_probs > 0)
|
||||||
{
|
{
|
||||||
// For llama_sample_token_greedy we need to sort candidates
|
// For llama_sample_token_greedy we need to sort candidates
|
||||||
llama_sample_softmax(ctx, &cur_p);
|
llama_sample_softmax(ctx, &cur_p);
|
||||||
@ -542,7 +540,7 @@ struct llama_server_context
|
|||||||
result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
|
result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx, result.tok);
|
llama_sampling_accept(ctx_sampling, ctx, result.tok, true);
|
||||||
|
|
||||||
if (tg) {
|
if (tg) {
|
||||||
num_tokens_predicted++;
|
num_tokens_predicted++;
|
||||||
@ -606,7 +604,7 @@ struct llama_server_context
|
|||||||
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
|
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
|
||||||
generated_text += token_text;
|
generated_text += token_text;
|
||||||
|
|
||||||
if (params.sampling_params.n_probs > 0)
|
if (params.sparams.n_probs > 0)
|
||||||
{
|
{
|
||||||
generated_token_probs.push_back(token_with_probs);
|
generated_token_probs.push_back(token_with_probs);
|
||||||
}
|
}
|
||||||
@ -1004,7 +1002,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|||||||
|
|
||||||
static json format_generation_settings(llama_server_context &llama)
|
static json format_generation_settings(llama_server_context &llama)
|
||||||
{
|
{
|
||||||
const auto & sparams = llama.params.sampling_params;
|
const auto & sparams = llama.params.sparams;
|
||||||
const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
|
const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
|
||||||
const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
|
const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
|
||||||
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
||||||
@ -1018,10 +1016,10 @@ static json format_generation_settings(llama_server_context &llama)
|
|||||||
{"top_p", sparams.top_p},
|
{"top_p", sparams.top_p},
|
||||||
{"tfs_z", sparams.tfs_z},
|
{"tfs_z", sparams.tfs_z},
|
||||||
{"typical_p", sparams.typical_p},
|
{"typical_p", sparams.typical_p},
|
||||||
{"repeat_last_n", sparams.repeat_last_n},
|
{"repeat_last_n", sparams.penalty_last_n},
|
||||||
{"repeat_penalty", sparams.repeat_penalty},
|
{"repeat_penalty", sparams.penalty_repeat},
|
||||||
{"presence_penalty", sparams.presence_penalty},
|
{"frequency_penalty", sparams.penalty_freq},
|
||||||
{"frequency_penalty", sparams.frequency_penalty},
|
{"presence_penalty", sparams.penalty_present},
|
||||||
{"mirostat", sparams.mirostat},
|
{"mirostat", sparams.mirostat},
|
||||||
{"mirostat_tau", sparams.mirostat_tau},
|
{"mirostat_tau", sparams.mirostat_tau},
|
||||||
{"mirostat_eta", sparams.mirostat_eta},
|
{"mirostat_eta", sparams.mirostat_eta},
|
||||||
@ -1033,7 +1031,7 @@ static json format_generation_settings(llama_server_context &llama)
|
|||||||
{"stream", llama.stream},
|
{"stream", llama.stream},
|
||||||
{"logit_bias", sparams.logit_bias},
|
{"logit_bias", sparams.logit_bias},
|
||||||
{"n_probs", sparams.n_probs},
|
{"n_probs", sparams.n_probs},
|
||||||
{"grammar", llama.params.grammar},
|
{"grammar", llama.params.sparams.grammar},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1081,7 +1079,7 @@ static json format_final_response(llama_server_context &llama, const std::string
|
|||||||
{"timings", format_timings(llama)},
|
{"timings", format_timings(llama)},
|
||||||
};
|
};
|
||||||
|
|
||||||
if (llama.params.sampling_params.n_probs > 0)
|
if (llama.params.sparams.n_probs > 0)
|
||||||
{
|
{
|
||||||
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
|
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
|
||||||
}
|
}
|
||||||
@ -1097,7 +1095,7 @@ static json format_partial_response(
|
|||||||
{"stop", false},
|
{"stop", false},
|
||||||
};
|
};
|
||||||
|
|
||||||
if (llama.params.sampling_params.n_probs > 0)
|
if (llama.params.sparams.n_probs > 0)
|
||||||
{
|
{
|
||||||
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
|
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
|
||||||
}
|
}
|
||||||
@ -1129,27 +1127,29 @@ static T json_value(const json &body, const std::string &key, const T &default_v
|
|||||||
static void parse_options_completion(const json &body, llama_server_context &llama)
|
static void parse_options_completion(const json &body, llama_server_context &llama)
|
||||||
{
|
{
|
||||||
gpt_params default_params;
|
gpt_params default_params;
|
||||||
const auto & default_sparams = default_params.sampling_params;
|
const auto & default_sparams = default_params.sparams;
|
||||||
auto & sparams = llama.params.sampling_params;
|
|
||||||
|
auto & params = llama.params;
|
||||||
|
auto & sparams = llama.params.sparams;
|
||||||
|
|
||||||
llama.stream = json_value(body, "stream", false);
|
llama.stream = json_value(body, "stream", false);
|
||||||
llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
|
params.n_predict = json_value(body, "n_predict", default_params.n_predict);
|
||||||
sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
|
sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
|
||||||
sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
|
sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
|
||||||
sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
|
sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
|
||||||
sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
|
sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
|
||||||
sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
|
|
||||||
sparams.temp = json_value(body, "temperature", default_sparams.temp);
|
sparams.temp = json_value(body, "temperature", default_sparams.temp);
|
||||||
sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
|
sparams.penalty_last_n = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
|
||||||
sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
|
sparams.penalty_repeat = json_value(body, "repeat_penalty", default_sparams.penalty_repeat);
|
||||||
sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
|
sparams.penalty_freq = json_value(body, "frequency_penalty", default_sparams.penalty_freq);
|
||||||
|
sparams.penalty_present = json_value(body, "presence_penalty", default_sparams.penalty_present);
|
||||||
sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
|
sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
|
||||||
sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
|
sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
|
||||||
sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
|
sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
|
||||||
sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
|
sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
|
||||||
llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
|
params.n_keep = json_value(body, "n_keep", default_params.n_keep);
|
||||||
llama.params.seed = json_value(body, "seed", default_params.seed);
|
params.seed = json_value(body, "seed", default_params.seed);
|
||||||
llama.params.grammar = json_value(body, "grammar", default_params.grammar);
|
sparams.grammar = json_value(body, "grammar", default_sparams.grammar);
|
||||||
sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
|
sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
|
||||||
|
|
||||||
if (body.count("prompt") != 0)
|
if (body.count("prompt") != 0)
|
||||||
@ -1204,8 +1204,6 @@ static void parse_options_completion(const json &body, llama_server_context &lla
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama.ctx_sampling = llama_sampling_init(llama.params);
|
|
||||||
|
|
||||||
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
|
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1374,15 +1372,9 @@ int main(int argc, char **argv)
|
|||||||
llama.rewind();
|
llama.rewind();
|
||||||
|
|
||||||
llama_reset_timings(llama.ctx);
|
llama_reset_timings(llama.ctx);
|
||||||
|
|
||||||
parse_options_completion(json::parse(req.body), llama);
|
parse_options_completion(json::parse(req.body), llama);
|
||||||
|
|
||||||
if (!llama.loadGrammar())
|
llama.initSampling();
|
||||||
{
|
|
||||||
res.status = 400;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama.loadPrompt();
|
llama.loadPrompt();
|
||||||
llama.beginCompletion();
|
llama.beginCompletion();
|
||||||
|
|
||||||
@ -1414,7 +1406,7 @@ int main(int argc, char **argv)
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto probs = llama.generated_token_probs;
|
auto probs = llama.generated_token_probs;
|
||||||
if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) {
|
if (llama.params.sparams.n_probs > 0 && llama.stopped_word) {
|
||||||
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
|
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
|
||||||
probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
|
probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
|
||||||
}
|
}
|
||||||
@ -1466,7 +1458,7 @@ int main(int argc, char **argv)
|
|||||||
|
|
||||||
std::vector<completion_token_output> probs_output = {};
|
std::vector<completion_token_output> probs_output = {};
|
||||||
|
|
||||||
if (llama.params.sampling_params.n_probs > 0) {
|
if (llama.params.sparams.n_probs > 0) {
|
||||||
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
|
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
|
||||||
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
|
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
|
||||||
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
|
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
|
||||||
@ -1537,14 +1529,9 @@ int main(int argc, char **argv)
|
|||||||
llama.rewind();
|
llama.rewind();
|
||||||
|
|
||||||
llama_reset_timings(llama.ctx);
|
llama_reset_timings(llama.ctx);
|
||||||
|
|
||||||
parse_options_infill(json::parse(req.body), llama);
|
parse_options_infill(json::parse(req.body), llama);
|
||||||
|
|
||||||
if (!llama.loadGrammar())
|
llama.initSampling();
|
||||||
{
|
|
||||||
res.status = 400;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
llama.loadInfill();
|
llama.loadInfill();
|
||||||
llama.beginCompletion();
|
llama.beginCompletion();
|
||||||
const auto chunked_content_provider = [&](size_t, DataSink & sink) {
|
const auto chunked_content_provider = [&](size_t, DataSink & sink) {
|
||||||
@ -1587,7 +1574,7 @@ int main(int argc, char **argv)
|
|||||||
|
|
||||||
std::vector<completion_token_output> probs_output = {};
|
std::vector<completion_token_output> probs_output = {};
|
||||||
|
|
||||||
if (llama.params.sampling_params.n_probs > 0) {
|
if (llama.params.sparams.n_probs > 0) {
|
||||||
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
|
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
|
||||||
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
|
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
|
||||||
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
|
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
|
||||||
@ -1694,7 +1681,9 @@ int main(int argc, char **argv)
|
|||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
llama.rewind();
|
llama.rewind();
|
||||||
|
|
||||||
llama_reset_timings(llama.ctx);
|
llama_reset_timings(llama.ctx);
|
||||||
|
|
||||||
if (body.count("content") != 0)
|
if (body.count("content") != 0)
|
||||||
{
|
{
|
||||||
llama.prompt = body["content"];
|
llama.prompt = body["content"];
|
||||||
@ -1704,6 +1693,8 @@ int main(int argc, char **argv)
|
|||||||
llama.prompt = "";
|
llama.prompt = "";
|
||||||
}
|
}
|
||||||
llama.params.n_predict = 0;
|
llama.params.n_predict = 0;
|
||||||
|
|
||||||
|
llama.initSampling();
|
||||||
llama.loadPrompt();
|
llama.loadPrompt();
|
||||||
llama.beginCompletion();
|
llama.beginCompletion();
|
||||||
llama.doCompletion();
|
llama.doCompletion();
|
||||||
|
@ -112,16 +112,16 @@ int main(int argc, char ** argv) {
|
|||||||
bool has_eos = false;
|
bool has_eos = false;
|
||||||
|
|
||||||
// target model sampling context
|
// target model sampling context
|
||||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
|
|
||||||
// draft sequence data
|
// draft sequence data
|
||||||
std::vector<seq_draft> drafts(n_seq_dft);
|
std::vector<seq_draft> drafts(n_seq_dft);
|
||||||
|
|
||||||
params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
||||||
params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp);
|
params.sparams.temp = std::max(0.01f, params.sparams.temp);
|
||||||
|
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
drafts[s].ctx_sampling = llama_sampling_init(params);
|
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
||||||
@ -154,7 +154,7 @@ int main(int argc, char ** argv) {
|
|||||||
// sample from the target model
|
// sample from the target model
|
||||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||||
|
|
||||||
llama_sampling_accept(ctx_sampling, ctx_tgt, id);
|
llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
|
||||||
|
|
||||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
||||||
|
|
||||||
@ -328,7 +328,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const int s = sa[is];
|
const int s = sa[is];
|
||||||
|
|
||||||
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
|
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
||||||
|
|
||||||
drafts[s].tokens.push_back(id);
|
drafts[s].tokens.push_back(id);
|
||||||
|
|
||||||
|
68
llama.cpp
68
llama.cpp
@ -1018,8 +1018,8 @@ enum e_model {
|
|||||||
};
|
};
|
||||||
|
|
||||||
static const size_t kB = 1024;
|
static const size_t kB = 1024;
|
||||||
static const size_t MB = kB*kB;
|
static const size_t MB = 1024*kB;
|
||||||
static const size_t GB = kB*kB*kB;
|
static const size_t GB = 1024*MB;
|
||||||
|
|
||||||
struct llama_hparams {
|
struct llama_hparams {
|
||||||
bool vocab_only;
|
bool vocab_only;
|
||||||
@ -1359,10 +1359,7 @@ static bool llama_kv_cache_init(
|
|||||||
cache.cells.clear();
|
cache.cells.clear();
|
||||||
cache.cells.resize(n_ctx);
|
cache.cells.resize(n_ctx);
|
||||||
|
|
||||||
// TODO: this should be:
|
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
|
||||||
// cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
|
|
||||||
// change it and test that it works
|
|
||||||
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
|
|
||||||
memset(cache.buf.data, 0, cache.buf.size);
|
memset(cache.buf.data, 0, cache.buf.size);
|
||||||
|
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
@ -7417,37 +7414,15 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
|
|||||||
llama_sample_temp(ctx, candidates_p, temp);
|
llama_sample_temp(ctx, candidates_p, temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) {
|
void llama_sample_repetition_penalties(
|
||||||
if (last_tokens_size == 0 || penalty == 1.0f) {
|
struct llama_context * ctx,
|
||||||
return;
|
llama_token_data_array * candidates,
|
||||||
}
|
const llama_token * last_tokens,
|
||||||
|
size_t penalty_last_n,
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
float penalty_repeat,
|
||||||
|
float penalty_freq,
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
float penalty_present) {
|
||||||
const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id);
|
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
|
||||||
if (token_iter == last_tokens + last_tokens_size) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
|
||||||
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
|
||||||
if (candidates->data[i].logit <= 0) {
|
|
||||||
candidates->data[i].logit *= penalty;
|
|
||||||
} else {
|
|
||||||
candidates->data[i].logit /= penalty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
candidates->sorted = false;
|
|
||||||
|
|
||||||
if (ctx) {
|
|
||||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
|
|
||||||
if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -7455,19 +7430,28 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
|
|||||||
|
|
||||||
// Create a frequency map to count occurrences of each token in last_tokens
|
// Create a frequency map to count occurrences of each token in last_tokens
|
||||||
std::unordered_map<llama_token, int> token_count;
|
std::unordered_map<llama_token, int> token_count;
|
||||||
for (size_t i = 0; i < last_tokens_size; ++i) {
|
for (size_t i = 0; i < penalty_last_n; ++i) {
|
||||||
token_count[last_tokens_p[i]]++;
|
token_count[last_tokens[i]]++;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply frequency and presence penalties to the candidates
|
// Apply frequency and presence penalties to the candidates
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
auto token_iter = token_count.find(candidates->data[i].id);
|
const auto token_iter = token_count.find(candidates->data[i].id);
|
||||||
if (token_iter == token_count.end()) {
|
if (token_iter == token_count.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int count = token_iter->second;
|
const int count = token_iter->second;
|
||||||
candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
|
|
||||||
|
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
||||||
|
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
||||||
|
if (candidates->data[i].logit <= 0) {
|
||||||
|
candidates->data[i].logit *= penalty_repeat;
|
||||||
|
} else {
|
||||||
|
candidates->data[i].logit /= penalty_repeat;
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
|
||||||
}
|
}
|
||||||
|
|
||||||
candidates->sorted = false;
|
candidates->sorted = false;
|
||||||
|
16
llama.h
16
llama.h
@ -560,21 +560,15 @@ extern "C" {
|
|||||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||||
|
|
||||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||||
LLAMA_API void llama_sample_repetition_penalty(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates,
|
|
||||||
const llama_token * last_tokens,
|
|
||||||
size_t last_tokens_size,
|
|
||||||
float penalty);
|
|
||||||
|
|
||||||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(
|
LLAMA_API void llama_sample_repetition_penalties(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
const llama_token * last_tokens,
|
const llama_token * last_tokens,
|
||||||
size_t last_tokens_size,
|
size_t penalty_last_n,
|
||||||
float alpha_frequency,
|
float penalty_repeat,
|
||||||
float alpha_presence);
|
float penalty_freq,
|
||||||
|
float penalty_present);
|
||||||
|
|
||||||
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
|
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
|
||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
|
||||||
|
@ -8,11 +8,9 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <iostream>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
|
||||||
static void dump(const llama_token_data_array * candidates) {
|
static void dump(const llama_token_data_array * candidates) {
|
||||||
for (size_t i = 0; i < candidates->size; i++) {
|
for (size_t i = 0; i < candidates->size; i++) {
|
||||||
printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
|
printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
|
||||||
@ -21,7 +19,6 @@ static void dump(const llama_token_data_array * candidates) {
|
|||||||
|
|
||||||
#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
|
#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
|
||||||
|
|
||||||
|
|
||||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -37,13 +34,12 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
|
|||||||
llama_sample_top_k(nullptr, &candidates_p, k, 1);
|
llama_sample_top_k(nullptr, &candidates_p, k, 1);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -59,13 +55,12 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
|||||||
llama_sample_top_p(nullptr, &candidates_p, p, 1);
|
llama_sample_top_p(nullptr, &candidates_p, p, 1);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -80,13 +75,12 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
|
|||||||
llama_sample_tail_free(nullptr, &candidates_p, z, 1);
|
llama_sample_tail_free(nullptr, &candidates_p, z, 1);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -101,18 +95,17 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
|
|||||||
llama_sample_typical(nullptr, &candidates_p, p, 1);
|
llama_sample_typical(nullptr, &candidates_p, p, 1);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void test_repetition_penalties(
|
||||||
static void test_repetition_penalty(
|
|
||||||
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
||||||
const std::vector<float> & expected_probs, float penalty
|
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
|
||||||
) {
|
) {
|
||||||
assert(probs.size() == expected_probs.size());
|
GGML_ASSERT(probs.size() == expected_probs.size());
|
||||||
|
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -125,41 +118,13 @@ static void test_repetition_penalty(
|
|||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
llama_sample_softmax(nullptr, &candidates_p);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
llama_sample_repetition_penalty(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), penalty);
|
llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
llama_sample_softmax(nullptr, &candidates_p);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
static void test_frequency_presence_penalty(
|
|
||||||
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
|
||||||
const std::vector<float> & expected_probs, float alpha_frequency, float alpha_presence
|
|
||||||
) {
|
|
||||||
assert(probs.size() == expected_probs.size());
|
|
||||||
|
|
||||||
size_t n_vocab = probs.size();
|
|
||||||
std::vector<llama_token_data> candidates;
|
|
||||||
candidates.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
||||||
float logit = log(probs[token_id]);
|
|
||||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
|
||||||
// DUMP(&candidates_p);
|
|
||||||
llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
|
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
|
||||||
// DUMP(&candidates_p);
|
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,13 +146,13 @@ int main(void) {
|
|||||||
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
||||||
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
||||||
|
|
||||||
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
|
||||||
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
||||||
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
||||||
|
|
||||||
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 5.0f, 5.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
|
||||||
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
|
||||||
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
|
||||||
|
|
||||||
printf("OK\n");
|
printf("OK\n");
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user