2023-10-11 19:35:46 +00:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "llama.h"
|
|
|
|
|
2024-09-09 21:36:09 +00:00
|
|
|
#include "common.h"
|
|
|
|
|
2023-10-11 19:35:46 +00:00
|
|
|
#include <string>
|
2024-04-24 09:08:36 +00:00
|
|
|
#include <vector>
|
2023-10-11 19:35:46 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
// common_sampler extends llama_sampler with additional functionality:
|
2024-09-07 12:16:19 +00:00
|
|
|
//
|
|
|
|
// - grammar support
|
|
|
|
// - custom sampler logic based on the parameters
|
|
|
|
// - history of the last accepted tokens
|
|
|
|
// - performance metrics
|
|
|
|
//
|
|
|
|
// The goal is to have a common implementation of the sampling logic shared across the examples.
|
|
|
|
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
|
|
|
|
// complex (top-k, top-p, etc.).
|
|
|
|
//
|
|
|
|
// Another example is related to the grammar. In general, the grammar constraints applied on the full
|
|
|
|
// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
|
|
|
|
// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
|
|
|
|
// grammar constraints are applied to the full vocabulary and the token is resampled.
|
|
|
|
//
|
2024-10-10 20:57:42 +00:00
|
|
|
// The common_sampler also maintains a container with the last accepted tokens. In the future, this can
|
2024-09-07 12:16:19 +00:00
|
|
|
// be moved into the core llama library.
|
|
|
|
//
|
2024-10-10 20:57:42 +00:00
|
|
|
// For convenience, the common_sampler also maintains a container with the current candidate tokens.
|
2024-09-07 12:16:19 +00:00
|
|
|
// This can be used to access the probabilities of the rest of the non-sampled tokens.
|
|
|
|
//
|
|
|
|
// TODO: measure grammar performance
|
|
|
|
//
|
2024-04-24 09:08:36 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
// Opaque sampler handle: it is only forward-declared in this header, so callers work exclusively
// through the `common_sampler *` pointers returned by common_sampler_init / common_sampler_clone.
struct common_sampler;
|
2023-10-11 19:35:46 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// counterparts of the llama_sampler API (init / free / accept / reset / clone) for common_sampler
|
2023-10-11 19:35:46 +00:00
|
|
|
|
2024-11-24 12:55:16 +00:00
|
|
|
// Creates a common_sampler for `model`, configured from the sampling parameters in `params`
// (sampler chain, grammar, ... — see the overview at the top of this header).
// The result is presumably owned by the caller and released with common_sampler_free.
// NOTE(review): whether this can return nullptr on failure is not visible from this header —
// check the definition before relying on a non-null result.
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
// Destroys a sampler obtained from common_sampler_init or common_sampler_clone.
// NOTE(review): nullptr-safety is not visible from this header — confirm before passing nullptr.
void common_sampler_free(struct common_sampler * gsmpl);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
2024-10-10 20:57:42 +00:00
|
|
|
// Marks `token` as accepted, which is what feeds the sampler's history of last accepted tokens
// (read back via common_sampler_last / common_sampler_prev_str).
// With accept_grammar == false, only the sampling chain sees the token; the grammar (if any) does not.
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
|
|
|
|
// Resets the sampler's internal state.
// NOTE(review): exactly which parts are reset (sampling chain, grammar, accepted-token history)
// is not visible from this header — confirm in the definition.
void common_sampler_reset (struct common_sampler * gsmpl);
|
|
|
|
// Returns a new sampler copied from `gsmpl`. The copy is a separate `common_sampler *`;
// presumably it must be released with common_sampler_free independently of the original — confirm.
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
2023-10-18 13:21:57 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// arguments can be nullptr to skip printing
|
2024-10-10 20:57:42 +00:00
|
|
|
// Prints the performance metrics of the context and/or the sampler; as noted above, passing
// nullptr for either argument skips printing for that argument.
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
2024-04-24 09:08:36 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// extended sampling implementation:
|
|
|
|
//
|
|
|
|
// - set logits
|
|
|
|
// - apply the configured sampler chain
|
|
|
|
// - check if the token fits the grammar (if any)
|
|
|
|
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
|
|
|
|
//
|
|
|
|
// if grammar_first is true, the grammar is applied before the samplers (slower)
|
|
|
|
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
|
|
|
|
//
|
2024-10-10 20:57:42 +00:00
|
|
|
// `ctx` supplies the logits and `idx` selects which of its outputs to sample from
// (presumably the index of the output within the last decoded batch — confirm).
// Returns the sampled token. The token is NOT accepted here: call common_sampler_accept afterwards
// (compare the equivalence example documented for common_sampler_sample_and_accept_n).
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
2023-10-11 19:35:46 +00:00
|
|
|
|
2024-11-17 16:55:27 +00:00
|
|
|
// generalized version of common_sampler_sample
|
|
|
|
//
|
2024-11-24 10:50:17 +00:00
|
|
|
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
|
|
|
|
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
|
2024-11-17 16:55:27 +00:00
|
|
|
//
|
2024-11-24 10:50:17 +00:00
|
|
|
// common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, llama_tokens{});
|
|
|
|
//
|
|
|
|
// is equivalent to
|
|
|
|
//
|
|
|
|
// llama_token token = common_sampler_sample(gsmpl, ctx, idx);
|
|
|
|
// common_sampler_accept(gsmpl, token, true);
|
2024-11-17 16:55:27 +00:00
|
|
|
//
|
|
|
|
// requires: idxs.size() == draft.size() + 1
|
|
|
|
//
|
|
|
|
// returns at least 1 token, up to idxs.size()
|
|
|
|
//
|
2024-11-24 10:50:17 +00:00
|
|
|
// `idxs[i]` is the output of `ctx` sampled at step i and `draft[i]` is the draft token compared
// against that step's sample; the final step (idxs.back()) has no draft token to compare against.
// NOTE: pass an empty draft as `llama_tokens{}`, not as a bare `{}`. Assuming llama_tokens is a
// std::vector of tokens, `{ idx }` can also initialize the `draft` of the draft-only overload below,
// and a bare `{}` converts to that overload's `bool grammar_first` with a better conversion (identity
// vs. user-defined), so overload resolution silently picks it (draft = { idx }) instead of this one.
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
2024-11-15 06:20:28 +00:00
|
|
|
|
2024-11-22 09:31:28 +00:00
|
|
|
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
2024-11-24 10:50:17 +00:00
|
|
|
// Convenience overload of the idxs version above, with the output indices fixed to 0..draft.size().
// NOTE: when `{ ... }` is a valid llama_tokens, a call written as (gsmpl, ctx, { ... }, {}) resolves
// to THIS overload (the bare `{}` becomes grammar_first = false), not to the idxs version.
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
2024-11-21 19:27:14 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
// Returns the seed of the sampler's random number generator.
// NOTE(review): whether this is the configured seed or a value drawn at init time is not visible
// from this header — confirm in the definition.
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
2024-09-10 16:04:25 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// helpers
|
2023-10-20 18:07:23 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// access the internal list of current candidate tokens
|
2024-10-10 20:57:42 +00:00
|
|
|
// The returned array is the candidate container maintained by `gsmpl` (see the overview); it is
// owned by the sampler, so do not free it.
// NOTE(review): presumably it is overwritten by the next sampling call — confirm before holding on to it.
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
|
2023-10-20 18:07:23 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// get the last accepted token
|
2024-10-10 20:57:42 +00:00
|
|
|
// NOTE(review): the value returned before any token has been accepted is not visible from this
// header — confirm before calling it on a fresh sampler.
llama_token common_sampler_last(const struct common_sampler * gsmpl);
|
2023-12-05 10:05:51 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// print the sampler chain into a string
|
2024-10-10 20:57:42 +00:00
|
|
|
// Reads the sampler through a const pointer and returns the description by value.
std::string common_sampler_print(const struct common_sampler * gsmpl);
|
2023-10-20 18:07:23 +00:00
|
|
|
|
2024-09-07 12:16:19 +00:00
|
|
|
// get a string representation of the last accepted tokens
|
2024-10-10 20:57:42 +00:00
|
|
|
// Returns a string representation of the most recent accepted tokens of `gsmpl`; `n` bounds how
// many are included and `ctx` is presumably used to turn the tokens into text — confirm.
std::string common_sampler_prev_str(struct common_sampler * gsmpl, struct llama_context * ctx, int n);
|
2024-05-22 17:04:20 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
// Single-character code of the sampler type `cnstr`; presumably the form parsed back by
// common_sampler_types_from_chars — confirm.
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
|
|
|
|
// Name of the sampler type `cnstr`; presumably one of the names accepted by
// common_sampler_types_from_names — confirm.
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
|
2024-05-22 17:04:20 +00:00
|
|
|
|
2024-10-10 20:57:42 +00:00
|
|
|
// Converts a list of sampler names into sampler types. `allow_alt_names` presumably also accepts
// alternative spellings besides the canonical names — confirm.
// NOTE(review): how unknown names are handled (skipped or reported) is not visible from this header.
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
|
|
|
// Converts a string of single-character sampler codes into sampler types.
// NOTE(review): how unknown characters are handled is not visible from this header.
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
|