mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
Add embedding mode with arg flag. Currently working (#282)
* working but ugly * add arg flag, not working on embedding mode * typo * Working! Thanks to @nullhook * make params argument instead of hardcoded boolean. remove useless time check * start doing the instructions but not finished. This probably doesnt compile * Embeddings extraction support --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
b6b268d441
commit
8d4a855c24
56
llama.cpp
56
llama.cpp
@ -102,6 +102,9 @@ struct llama_context {
|
|||||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
bool logits_all = false;
|
bool logits_all = false;
|
||||||
|
|
||||||
|
// input embedding (1-dimensional array: [n_embd])
|
||||||
|
std::vector<float> embedding;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_context_params llama_context_default_params() {
|
struct llama_context_params llama_context_default_params() {
|
||||||
@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
|
|||||||
/*.f16_kv =*/ false,
|
/*.f16_kv =*/ false,
|
||||||
/*.logits_all =*/ false,
|
/*.logits_all =*/ false,
|
||||||
/*.vocab_only =*/ false,
|
/*.vocab_only =*/ false,
|
||||||
|
/*.embedding =*/ false,
|
||||||
};
|
};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
@ -592,8 +596,6 @@ static bool llama_model_load(
|
|||||||
fin.close();
|
fin.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
lctx.logits.reserve(lctx.model.hparams.n_ctx);
|
|
||||||
|
|
||||||
lctx.t_load_us = ggml_time_us() - t_start_us;
|
lctx.t_load_us = ggml_time_us() - t_start_us;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -791,6 +793,9 @@ static bool llama_eval_internal(
|
|||||||
inpL = cur;
|
inpL = cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// used at the end to optionally extract the embeddings
|
||||||
|
struct ggml_tensor * embeddings = NULL;
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
inpL = ggml_rms_norm(ctx0, inpL);
|
inpL = ggml_rms_norm(ctx0, inpL);
|
||||||
@ -799,6 +804,8 @@ static bool llama_eval_internal(
|
|||||||
inpL = ggml_mul(ctx0,
|
inpL = ggml_mul(ctx0,
|
||||||
ggml_repeat(ctx0, model.norm, inpL),
|
ggml_repeat(ctx0, model.norm, inpL),
|
||||||
inpL);
|
inpL);
|
||||||
|
|
||||||
|
embeddings = inpL;
|
||||||
}
|
}
|
||||||
|
|
||||||
// lm_head
|
// lm_head
|
||||||
@ -821,15 +828,26 @@ static bool llama_eval_internal(
|
|||||||
//embd_w.resize(n_vocab*N);
|
//embd_w.resize(n_vocab*N);
|
||||||
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||||
|
|
||||||
auto & logits_out = lctx.logits;
|
// extract logits
|
||||||
|
{
|
||||||
|
auto & logits_out = lctx.logits;
|
||||||
|
|
||||||
if (lctx.logits_all) {
|
if (lctx.logits_all) {
|
||||||
logits_out.resize(n_vocab * N);
|
logits_out.resize(n_vocab * N);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
|
||||||
} else {
|
} else {
|
||||||
// return result for just the last token
|
// return result for just the last token
|
||||||
logits_out.resize(n_vocab);
|
logits_out.resize(n_vocab);
|
||||||
memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract embeddings
|
||||||
|
if (lctx.embedding.size()) {
|
||||||
|
auto & embedding_out = lctx.embedding;
|
||||||
|
|
||||||
|
embedding_out.resize(n_embd);
|
||||||
|
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mem_per_token == 0) {
|
if (mem_per_token == 0) {
|
||||||
@ -1416,6 +1434,20 @@ struct llama_context * llama_init_from_file(
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reserve memory for context buffers
|
||||||
|
{
|
||||||
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
if (params.logits_all) {
|
||||||
|
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
|
||||||
|
} else {
|
||||||
|
ctx->logits.reserve(hparams.n_ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.embedding){
|
||||||
|
ctx->embedding.reserve(hparams.n_embd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1484,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
|
|||||||
return ctx->logits.data();
|
return ctx->logits.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float * llama_get_embeddings(struct llama_context * ctx) {
|
||||||
|
return ctx->embedding.data();
|
||||||
|
}
|
||||||
|
|
||||||
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
|
const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
|
||||||
if (token >= llama_n_vocab(ctx)) {
|
if (token >= llama_n_vocab(ctx)) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
5
llama.h
5
llama.h
@ -53,6 +53,7 @@ extern "C" {
|
|||||||
bool f16_kv; // use fp16 for KV cache
|
bool f16_kv; // use fp16 for KV cache
|
||||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||||
bool vocab_only; // only load the vocabulary, no weights
|
bool vocab_only; // only load the vocabulary, no weights
|
||||||
|
bool embedding; // embedding mode only
|
||||||
};
|
};
|
||||||
|
|
||||||
LLAMA_API struct llama_context_params llama_context_default_params();
|
LLAMA_API struct llama_context_params llama_context_default_params();
|
||||||
@ -108,6 +109,10 @@ extern "C" {
|
|||||||
// Cols: n_vocab
|
// Cols: n_vocab
|
||||||
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
|
||||||
|
|
||||||
|
// Get the embeddings for the input
|
||||||
|
// shape: [n_embd] (1-dimensional)
|
||||||
|
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||||
|
|
||||||
// Token Id -> String. Uses the vocabulary in the provided context
|
// Token Id -> String. Uses the vocabulary in the provided context
|
||||||
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
|
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
|
||||||
|
|
||||||
|
23
main.cpp
23
main.cpp
@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
|
|||||||
lparams.seed = params.seed;
|
lparams.seed = params.seed;
|
||||||
lparams.f16_kv = params.memory_f16;
|
lparams.f16_kv = params.memory_f16;
|
||||||
lparams.logits_all = params.perplexity;
|
lparams.logits_all = params.perplexity;
|
||||||
|
lparams.embedding = params.embedding;
|
||||||
|
|
||||||
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
ctx = llama_init_from_file(params.model.c_str(), lparams);
|
||||||
|
|
||||||
@ -292,6 +293,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
|
|
||||||
|
|
||||||
int last_n_size = params.repeat_last_n;
|
int last_n_size = params.repeat_last_n;
|
||||||
std::vector<llama_token> last_n_tokens(last_n_size);
|
std::vector<llama_token> last_n_tokens(last_n_size);
|
||||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||||
@ -324,6 +326,27 @@ int main(int argc, char ** argv) {
|
|||||||
// the first thing we will do is to output the prompt, so set color accordingly
|
// the first thing we will do is to output the prompt, so set color accordingly
|
||||||
set_console_state(CONSOLE_STATE_PROMPT);
|
set_console_state(CONSOLE_STATE_PROMPT);
|
||||||
|
|
||||||
|
if (params.embedding){
|
||||||
|
embd = embd_inp;
|
||||||
|
|
||||||
|
if (embd.size() > 0) {
|
||||||
|
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto embeddings = llama_get_embeddings(ctx);
|
||||||
|
|
||||||
|
// TODO: print / use the embeddings
|
||||||
|
|
||||||
|
if (params.use_color) {
|
||||||
|
printf(ANSI_COLOR_RESET);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
while (remaining_tokens > 0 || params.interactive) {
|
while (remaining_tokens > 0 || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
if (embd.size() > 0) {
|
if (embd.size() > 0) {
|
||||||
|
@ -117,6 +117,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
params.model = argv[i];
|
params.model = argv[i];
|
||||||
} else if (arg == "-i" || arg == "--interactive") {
|
} else if (arg == "-i" || arg == "--interactive") {
|
||||||
params.interactive = true;
|
params.interactive = true;
|
||||||
|
} else if (arg == "--embedding") {
|
||||||
|
params.embedding = true;
|
||||||
|
} else if (arg == "--interactive-start") {
|
||||||
|
params.interactive = true;
|
||||||
} else if (arg == "--interactive-first") {
|
} else if (arg == "--interactive-first") {
|
||||||
params.interactive_start = true;
|
params.interactive_start = true;
|
||||||
} else if (arg == "-ins" || arg == "--instruct") {
|
} else if (arg == "-ins" || arg == "--instruct") {
|
||||||
|
4
utils.h
4
utils.h
@ -32,13 +32,17 @@ struct gpt_params {
|
|||||||
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
|
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
|
||||||
std::string prompt = "";
|
std::string prompt = "";
|
||||||
|
|
||||||
|
|
||||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||||
|
|
||||||
bool memory_f16 = false; // use f16 instead of f32 for memory kv
|
bool memory_f16 = false; // use f16 instead of f32 for memory kv
|
||||||
bool random_prompt = false; // do not randomize prompt if none provided
|
bool random_prompt = false; // do not randomize prompt if none provided
|
||||||
bool use_color = false; // use color to distinguish generations and inputs
|
bool use_color = false; // use color to distinguish generations and inputs
|
||||||
bool interactive = false; // interactive mode
|
bool interactive = false; // interactive mode
|
||||||
|
|
||||||
|
bool embedding = false; // get only sentence embedding
|
||||||
bool interactive_start = false; // wait for user input immediately
|
bool interactive_start = false; // wait for user input immediately
|
||||||
|
|
||||||
bool instruct = false; // instruction mode (used for Alpaca models)
|
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||||
bool ignore_eos = false; // do not stop generating after eos
|
bool ignore_eos = false; // do not stop generating after eos
|
||||||
bool perplexity = false; // compute perplexity over the prompt
|
bool perplexity = false; // compute perplexity over the prompt
|
||||||
|
Loading…
Reference in New Issue
Block a user