From 5a3369d8e851891ab365816c9aa6cdbf0f874d7b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 21 Sep 2023 19:51:32 +0200 Subject: [PATCH] llama : llama.h formatting + comments --- llama.cpp | 4 + llama.h | 235 ++++++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 181 insertions(+), 58 deletions(-) diff --git a/llama.cpp b/llama.cpp index 1576c3b86..73a636cea 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7477,6 +7477,10 @@ float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } +float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + return ctx->logits.data() + i*ctx->model.hparams.n_vocab; +} + float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embedding.data(); } diff --git a/llama.h b/llama.h index 54eab8f08..590af79bb 100644 --- a/llama.h +++ b/llama.h @@ -66,26 +66,6 @@ extern "C" { typedef int32_t llama_token; typedef int32_t llama_seq_id; - // data used for batch inference - typedef struct llama_batch { - int32_t n_tokens; - - llama_token * token; - float * embd; - llama_pos * pos; - llama_seq_id * seq_id; - int8_t * logits; // if 0, do not extract logits for that token - - // NOTE: helpers for smooth API transition - can be deprecated in the future - // for future-proof code, use the above fields instead and ignore everything below - // - // pos[i] = all_pos_0 + i*all_pos_1 - // - llama_pos all_pos_0; // used if pos == NULL - llama_pos all_pos_1; // used if pos == NULL - llama_seq_id all_seq_id; // used if seq_id == NULL - } llama_batch; - enum llama_log_level { LLAMA_LOG_LEVEL_ERROR = 2, LLAMA_LOG_LEVEL_WARN = 3, @@ -146,6 +126,35 @@ extern "C" { typedef void (*llama_progress_callback)(float progress, void *ctx); + // Input data for llama_decode + // A llama_batch object can contain input about one or many sequences + // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens + // + // - token : the token ids of the input (used when embd is NULL) + // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + // - pos : the positions of the respective token in the sequence + // - seq_id : the sequence to which the respective token belongs + // - logits : if zero, the logits for the respective token will not be output + // + typedef struct llama_batch { + int32_t n_tokens; + + llama_token * token; + float * embd; + llama_pos * pos; + llama_seq_id * seq_id; + int8_t * logits; + + // NOTE: helpers for smooth API transition - can be deprecated in the future + // for future-proof code, use the above fields instead and ignore everything below + // + // pos[i] = all_pos_0 + i*all_pos_1 + // + llama_pos all_pos_0; // used if pos == NULL + llama_pos all_pos_1; // used if pos == NULL + llama_seq_id all_seq_id; // used if seq_id == NULL + } llama_batch; + struct llama_context_params { uint32_t seed; // RNG seed, -1 for random int32_t n_ctx; // text context @@ -239,6 +248,7 @@ extern "C" { int32_t n_eval; }; + // Helpers for getting default parameters LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); @@ -283,8 +293,10 @@ extern "C" { // Get a string describing the model type LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); + // Returns the total size of all the tensors in the model in bytes LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + // Returns the total number of parameters in the model LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); @@ -305,7 +317,7 @@ extern "C" { const char * path_lora, const char * path_base_model, int n_threads), - "please use llama_model_apply_lora_from_file instead"); + "use llama_model_apply_lora_from_file instead"); LLAMA_API int llama_model_apply_lora_from_file( const struct llama_model * model, @@ -322,20 +334,40 @@ extern "C" { "avoid using this, it will be removed in the future, instead - count the tokens in user code"); // Remove all tokens data of cells in [c0, c1) - LLAMA_API void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1); + LLAMA_API void llama_kv_cache_tokens_rm( + struct llama_context * ctx, + int32_t c0, + int32_t c1); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + LLAMA_API void llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1); // Copy all tokens that belong to the specified sequence to another sequence - LLAMA_API void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); + // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + LLAMA_API void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly - LLAMA_API void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); + LLAMA_API void llama_kv_cache_seq_shift( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta); // // State / sessions @@ -348,21 +380,35 @@ extern "C" { // Copies the state to the specified destination address. // Destination needs to have allocated enough memory. // Returns the number of bytes copied - LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst); + LLAMA_API size_t llama_copy_state_data( + struct llama_context * ctx, + uint8_t * dst); // Set the state reading from the specified address // Returns the number of bytes read - LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src); + LLAMA_API size_t llama_set_state_data( + struct llama_context * ctx, + uint8_t * src); // Save/load session file - LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); - LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count); + LLAMA_API bool llama_load_session_file( + struct llama_context * ctx, + const char * path_session, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + + LLAMA_API bool llama_save_session_file( + struct llama_context * ctx, + const char * path_session, + const llama_token * tokens, + size_t n_token_count); // // Decoding // - // Run the llama inference to obtain the logits and probabilities for the next token. + // Run the llama inference to obtain the logits and probabilities for the next token(s). // tokens + n_tokens is the provided batch of new tokens to process // n_past is the number of tokens to use from previous eval calls // Returns 0 on success @@ -373,7 +419,7 @@ extern "C" { int32_t n_tokens, int n_past, int n_threads), - "please use llama_decode() instead"); + "use llama_decode() instead"); // Same as llama_eval, but use float matrix input directly. // DEPRECATED: use llama_decode() instead @@ -383,7 +429,7 @@ extern "C" { int32_t n_tokens, int n_past, int n_threads), - "please use llama_decode() instead"); + "use llama_decode() instead"); // Return batch for single sequence of tokens starting at pos_0 // @@ -396,12 +442,14 @@ extern "C" { llama_seq_id seq_id); // Allocates a batch of tokens on the heap - // The batch needs to be freed with llama_batch_free() - // If embd > 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) + // The batch has to be freed with llama_batch_free() + // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token // The rest of the llama_batch members are allocated with size n_tokens // All members are left uninitialized - LLAMA_API struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd); + LLAMA_API struct llama_batch llama_batch_init( + int32_t n_tokens, + int32_t embd); // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); @@ -417,11 +465,15 @@ extern "C" { // Token logits obtained from the last call to llama_eval() // The logits for the last token are stored in the last row - // Can be mutated in order to change the probabilities of the next token - // Rows: n_tokens + // Logits for which llama_batch.logits[i] == 0 are undefined + // Rows: n_tokens provided with llama_batch // Cols: n_vocab LLAMA_API float * llama_get_logits(struct llama_context * ctx); + // Logits for the ith token. Equivalent to: + // llama_get_logits(ctx) + i*n_vocab + LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + // Get the embeddings for the input // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); @@ -502,10 +554,21 @@ extern "C" { LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); + LLAMA_API void llama_sample_repetition_penalty( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t last_tokens_size, + float penalty); /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence); + LLAMA_API void llama_sample_frequency_and_presence_penalties( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t last_tokens_size, + float alpha_frequency, + float alpha_presence); /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. @@ -518,26 +581,54 @@ extern "C" { float scale); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API void llama_sample_softmax( + struct llama_context * ctx, + llama_token_data_array * candidates); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep); + LLAMA_API void llama_sample_top_k( + struct llama_context * ctx, + llama_token_data_array * candidates, + int k, + size_t min_keep); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); + LLAMA_API void llama_sample_top_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep); + LLAMA_API void llama_sample_tail_free( + struct llama_context * ctx, + llama_token_data_array * candidates, + float z, + size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); - LLAMA_API void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + LLAMA_API void llama_sample_typical( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); - LLAMA_API DEPRECATED(void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp), - "Use llama_sample_temp instead"); + LLAMA_API void llama_sample_temp( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp); + + LLAMA_API DEPRECATED(void llama_sample_temperature( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp), + "use llama_sample_temp instead"); /// @details Apply constraints from grammar - LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); + LLAMA_API void llama_sample_grammar( + struct llama_context * ctx, + llama_token_data_array * candidates, + const struct llama_grammar * grammar); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. @@ -545,23 +636,41 @@ extern "C" { /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu); + LLAMA_API llama_token llama_sample_token_mirostat( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + int m, + float * mu); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu); + LLAMA_API llama_token llama_sample_token_mirostat_v2( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + float * mu); /// @details Selects the token with the highest probability. - LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_token_greedy( + struct llama_context * ctx, + llama_token_data_array * candidates); /// @details Randomly selects a token from the candidates based on their probabilities. - LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sample_token( + struct llama_context * ctx, + llama_token_data_array * candidates); /// @details Accepts the sampled token into the grammar - LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); + LLAMA_API void llama_grammar_accept_token( + struct llama_context * ctx, + struct llama_grammar * grammar, + llama_token token); // // Beam search @@ -569,9 +678,10 @@ extern "C" { struct llama_beam_view { const llama_token * tokens; + size_t n_tokens; - float p; // Cumulative beam probability (renormalized relative to all beams) - bool eob; // Callback should set this to true when a beam is at end-of-beam. + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Callback should set this to true when a beam is at end-of-beam. }; // Passed to beam_search_callback function. @@ -580,9 +690,10 @@ extern "C" { // These pointers are valid only during the synchronous callback, so should not be saved. struct llama_beams_state { struct llama_beam_view * beam_views; + size_t n_beams; // Number of elements in beam_views[]. size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. - bool last_call; // True iff this is the last callback invocation. + bool last_call; // True iff this is the last callback invocation. }; // Type of pointer to the beam_search_callback function. @@ -598,10 +709,18 @@ extern "C" { /// @param n_past Number of tokens already evaluated. /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. /// @param n_threads Number of threads as passed to llama_eval(). - LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads); + LLAMA_API void llama_beam_search( + struct llama_context * ctx, + llama_beam_search_callback_fn_t callback, + void * callback_data, + size_t n_beams, + int n_past, + int n_predict, + int n_threads); // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); + LLAMA_API void llama_print_timings(struct llama_context * ctx); LLAMA_API void llama_reset_timings(struct llama_context * ctx);