mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
llama : move vocab, grammar and sampling into separate files (#8508)
* llama : move sampling code into llama-sampling ggml-ci * llama : move grammar code into llama-grammar ggml-ci * cont ggml-ci * cont : pre-fetch rules * cont ggml-ci * llama : deprecate llama_sample_grammar * llama : move tokenizers into llama-vocab ggml-ci * make : update llama.cpp deps [no ci] * llama : redirect external API to internal APIs ggml-ci * llama : suffix the internal APIs with "_impl" ggml-ci * llama : clean-up
This commit is contained in:
parent
751fcfc6c3
commit
938943cdbf
32
Makefile
32
Makefile
@ -876,6 +876,9 @@ OBJ_GGML += \
|
|||||||
|
|
||||||
OBJ_LLAMA = \
|
OBJ_LLAMA = \
|
||||||
src/llama.o \
|
src/llama.o \
|
||||||
|
src/llama-vocab.o \
|
||||||
|
src/llama-grammar.o \
|
||||||
|
src/llama-sampling.o \
|
||||||
src/unicode.o \
|
src/unicode.o \
|
||||||
src/unicode-data.o
|
src/unicode-data.o
|
||||||
|
|
||||||
@ -1055,6 +1058,10 @@ src/unicode-data.o: \
|
|||||||
|
|
||||||
src/llama.o: \
|
src/llama.o: \
|
||||||
src/llama.cpp \
|
src/llama.cpp \
|
||||||
|
src/llama-impl.h \
|
||||||
|
src/llama-vocab.h \
|
||||||
|
src/llama-grammar.h \
|
||||||
|
src/llama-sampling.h \
|
||||||
src/unicode.h \
|
src/unicode.h \
|
||||||
include/llama.h \
|
include/llama.h \
|
||||||
ggml/include/ggml-cuda.h \
|
ggml/include/ggml-cuda.h \
|
||||||
@ -1064,6 +1071,29 @@ src/llama.o: \
|
|||||||
ggml/include/ggml-backend.h
|
ggml/include/ggml-backend.h
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
src/llama-vocab.o: \
|
||||||
|
src/llama-vocab.cpp \
|
||||||
|
src/llama-vocab.h \
|
||||||
|
src/llama-impl.h \
|
||||||
|
include/llama.h
|
||||||
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
src/llama-grammar.o: \
|
||||||
|
src/llama-grammar.cpp \
|
||||||
|
src/llama-grammar.h \
|
||||||
|
src/llama-impl.h \
|
||||||
|
src/llama-vocab.h \
|
||||||
|
src/llama-sampling.h \
|
||||||
|
include/llama.h
|
||||||
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
src/llama-sampling.o: \
|
||||||
|
src/llama-sampling.cpp \
|
||||||
|
src/llama-sampling.h \
|
||||||
|
src/llama-impl.h \
|
||||||
|
include/llama.h
|
||||||
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
|
|
||||||
$(LIB_LLAMA): \
|
$(LIB_LLAMA): \
|
||||||
$(OBJ_LLAMA) \
|
$(OBJ_LLAMA) \
|
||||||
$(LIB_GGML)
|
$(LIB_GGML)
|
||||||
@ -1439,7 +1469,7 @@ run-benchmark-matmult: llama-benchmark-matmult
|
|||||||
.PHONY: run-benchmark-matmult swift
|
.PHONY: run-benchmark-matmult swift
|
||||||
|
|
||||||
tests/test-llama-grammar: tests/test-llama-grammar.cpp \
|
tests/test-llama-grammar: tests/test-llama-grammar.cpp \
|
||||||
$(OBJ_GGML) $(OBJ_COMMON) src/unicode.o src/unicode-data.o
|
$(OBJ_ALL)
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
|
@ -4,6 +4,9 @@ import PackageDescription
|
|||||||
|
|
||||||
var sources = [
|
var sources = [
|
||||||
"src/llama.cpp",
|
"src/llama.cpp",
|
||||||
|
"src/llama-vocab.cpp",
|
||||||
|
"src/llama-grammar.cpp",
|
||||||
|
"src/llama-sampling.cpp",
|
||||||
"src/unicode.cpp",
|
"src/unicode.cpp",
|
||||||
"src/unicode-data.cpp",
|
"src/unicode-data.cpp",
|
||||||
"ggml/src/ggml.c",
|
"ggml/src/ggml.c",
|
||||||
|
@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl(
|
|||||||
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
|
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
|
||||||
|
|
||||||
// Apply grammar constraints to the single token
|
// Apply grammar constraints to the single token
|
||||||
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
|
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
|
||||||
|
|
||||||
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
|
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
|
||||||
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
||||||
@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
|
|||||||
|
|
||||||
// apply grammar checks before sampling logic
|
// apply grammar checks before sampling logic
|
||||||
if (apply_grammar && ctx_sampling->grammar != NULL) {
|
if (apply_grammar && ctx_sampling->grammar != NULL) {
|
||||||
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
|
||||||
}
|
}
|
||||||
|
|
||||||
return cur_p;
|
return cur_p;
|
||||||
@ -455,6 +455,6 @@ void llama_sampling_accept(
|
|||||||
ctx_sampling->prev.push_back(id);
|
ctx_sampling->prev.push_back(id);
|
||||||
|
|
||||||
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
if (ctx_sampling->grammar != NULL && apply_grammar) {
|
||||||
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
|
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,20 +16,25 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
|
|||||||
auto decoded = decode_utf8(input_str, {});
|
auto decoded = decode_utf8(input_str, {});
|
||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
|
||||||
|
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||||
|
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
auto prev_stacks = grammar->stacks;
|
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
|
||||||
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
|
||||||
if (grammar->stacks.empty()) {
|
llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
|
||||||
|
|
||||||
|
if (cur_stacks.empty()) {
|
||||||
error_pos = pos;
|
error_pos = pos;
|
||||||
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
|
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
|
||||||
grammar->stacks = prev_stacks;
|
cur_stacks = prev_stacks;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
++pos;
|
++pos;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto & stack : grammar->stacks) {
|
for (const auto & stack : cur_stacks) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -965,6 +965,10 @@ extern "C" {
|
|||||||
bool remove_special,
|
bool remove_special,
|
||||||
bool unparse_special);
|
bool unparse_special);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Chat templates
|
||||||
|
//
|
||||||
|
|
||||||
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
||||||
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
||||||
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
|
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
|
||||||
@ -1003,6 +1007,23 @@ extern "C" {
|
|||||||
|
|
||||||
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
|
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
|
||||||
|
|
||||||
|
/// @details Apply constraints from grammar
|
||||||
|
LLAMA_API void llama_grammar_sample(
|
||||||
|
const struct llama_grammar * grammar,
|
||||||
|
const struct llama_context * ctx,
|
||||||
|
llama_token_data_array * candidates);
|
||||||
|
LLAMA_API DEPRECATED(void llama_sample_grammar(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
llama_token_data_array * candidates,
|
||||||
|
const struct llama_grammar * grammar),
|
||||||
|
"use llama_grammar_sample instead");
|
||||||
|
|
||||||
|
/// @details Accepts the sampled token into the grammar
|
||||||
|
LLAMA_API void llama_grammar_accept_token(
|
||||||
|
struct llama_grammar * grammar,
|
||||||
|
struct llama_context * ctx,
|
||||||
|
llama_token token);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Sampling functions
|
// Sampling functions
|
||||||
//
|
//
|
||||||
@ -1084,12 +1105,6 @@ extern "C" {
|
|||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
float temp);
|
float temp);
|
||||||
|
|
||||||
/// @details Apply constraints from grammar
|
|
||||||
LLAMA_API void llama_sample_grammar(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates,
|
|
||||||
const struct llama_grammar * grammar);
|
|
||||||
|
|
||||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||||
@ -1127,12 +1142,6 @@ extern "C" {
|
|||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_token_data_array * candidates);
|
llama_token_data_array * candidates);
|
||||||
|
|
||||||
/// @details Accepts the sampled token into the grammar
|
|
||||||
LLAMA_API void llama_grammar_accept_token(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
struct llama_grammar * grammar,
|
|
||||||
llama_token token);
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Model split
|
// Model split
|
||||||
//
|
//
|
||||||
@ -1175,34 +1184,41 @@ extern "C" {
|
|||||||
|
|
||||||
struct ggml_tensor;
|
struct ggml_tensor;
|
||||||
|
|
||||||
|
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
||||||
|
struct llama_context * ctx
|
||||||
|
);
|
||||||
|
|
||||||
struct llama_partial_utf8 {
|
struct llama_partial_utf8 {
|
||||||
uint32_t value; // bit value so far (unshifted)
|
uint32_t value; // bit value so far (unshifted)
|
||||||
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_grammar {
|
|
||||||
const std::vector<std::vector<llama_grammar_element>> rules;
|
|
||||||
std::vector<std::vector<const llama_grammar_element *>> stacks;
|
|
||||||
|
|
||||||
// buffer for partially generated UTF-8 sequence from accepted tokens
|
|
||||||
llama_partial_utf8 partial_utf8;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct llama_grammar_candidate {
|
struct llama_grammar_candidate {
|
||||||
size_t index;
|
size_t index;
|
||||||
const uint32_t * code_points;
|
const uint32_t * code_points;
|
||||||
llama_partial_utf8 partial_utf8;
|
llama_partial_utf8 partial_utf8;
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
|
using llama_grammar_rule = std::vector< llama_grammar_element>;
|
||||||
struct llama_context * ctx
|
using llama_grammar_stack = std::vector<const llama_grammar_element *>;
|
||||||
);
|
|
||||||
|
using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
||||||
|
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
||||||
|
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
||||||
|
|
||||||
|
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
||||||
|
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
||||||
|
|
||||||
void llama_grammar_accept(
|
void llama_grammar_accept(
|
||||||
const std::vector<std::vector<llama_grammar_element>> & rules,
|
const llama_grammar_rules & rules,
|
||||||
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
|
const llama_grammar_stacks & stacks,
|
||||||
const uint32_t chr,
|
const uint32_t chr,
|
||||||
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
|
llama_grammar_stacks & new_stacks);
|
||||||
|
|
||||||
|
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stack & stack,
|
||||||
|
const llama_grammar_candidates & candidates);
|
||||||
|
|
||||||
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
||||||
const std::string & src,
|
const std::string & src,
|
||||||
|
@ -14,6 +14,9 @@ endif()
|
|||||||
add_library(llama
|
add_library(llama
|
||||||
../include/llama.h
|
../include/llama.h
|
||||||
llama.cpp
|
llama.cpp
|
||||||
|
llama-vocab.cpp
|
||||||
|
llama-grammar.cpp
|
||||||
|
llama-sampling.cpp
|
||||||
unicode.h
|
unicode.h
|
||||||
unicode.cpp
|
unicode.cpp
|
||||||
unicode-data.cpp
|
unicode-data.cpp
|
||||||
|
539
src/llama-grammar.cpp
Normal file
539
src/llama-grammar.cpp
Normal file
@ -0,0 +1,539 @@
|
|||||||
|
#include "llama-grammar.h"
|
||||||
|
|
||||||
|
#include "llama-vocab.h"
|
||||||
|
#include "llama-sampling.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
||||||
|
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
|
||||||
|
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
||||||
|
const std::string & src,
|
||||||
|
llama_partial_utf8 partial_start) {
|
||||||
|
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
||||||
|
const char * pos = src.c_str();
|
||||||
|
std::vector<uint32_t> code_points;
|
||||||
|
|
||||||
|
// common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
|
||||||
|
code_points.reserve(src.size() + 1);
|
||||||
|
uint32_t value = partial_start.value;
|
||||||
|
int n_remain = partial_start.n_remain;
|
||||||
|
|
||||||
|
// continue previous decode, if applicable
|
||||||
|
while (*pos != 0 && n_remain > 0) {
|
||||||
|
uint8_t next_byte = static_cast<uint8_t>(*pos);
|
||||||
|
if ((next_byte >> 6) != 2) {
|
||||||
|
// invalid sequence, abort
|
||||||
|
code_points.push_back(0);
|
||||||
|
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
|
||||||
|
}
|
||||||
|
value = (value << 6) + (next_byte & 0x3F);
|
||||||
|
++pos;
|
||||||
|
--n_remain;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (partial_start.n_remain > 0 && n_remain == 0) {
|
||||||
|
code_points.push_back(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode any subsequent utf-8 sequences, which may end in an incomplete one
|
||||||
|
while (*pos != 0) {
|
||||||
|
uint8_t first_byte = static_cast<uint8_t>(*pos);
|
||||||
|
uint8_t highbits = first_byte >> 4;
|
||||||
|
n_remain = lookup[highbits] - 1;
|
||||||
|
|
||||||
|
if (n_remain < 0) {
|
||||||
|
// invalid sequence, abort
|
||||||
|
code_points.clear();
|
||||||
|
code_points.push_back(0);
|
||||||
|
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t mask = (1 << (7 - n_remain)) - 1;
|
||||||
|
value = first_byte & mask;
|
||||||
|
|
||||||
|
++pos;
|
||||||
|
while (*pos != 0 && n_remain > 0) {
|
||||||
|
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||||
|
++pos;
|
||||||
|
--n_remain;
|
||||||
|
}
|
||||||
|
if (n_remain == 0) {
|
||||||
|
code_points.push_back(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
code_points.push_back(0);
|
||||||
|
|
||||||
|
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
|
||||||
|
return grammar->rules;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
|
||||||
|
return grammar->stacks;
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns true iff pos points to the end of one of the definitions of a rule
|
||||||
|
static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
|
||||||
|
switch (pos->type) {
|
||||||
|
case LLAMA_GRETYPE_END: return true; // NOLINT
|
||||||
|
case LLAMA_GRETYPE_ALT: return true; // NOLINT
|
||||||
|
default: return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns true iff chr satisfies the char range at pos (regular or inverse range)
|
||||||
|
// asserts that pos is pointing to a char range element
|
||||||
|
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
|
||||||
|
const llama_grammar_element * pos,
|
||||||
|
const uint32_t chr) {
|
||||||
|
|
||||||
|
bool found = false;
|
||||||
|
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
|
||||||
|
|
||||||
|
GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
|
||||||
|
|
||||||
|
do {
|
||||||
|
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
|
||||||
|
// inclusive range, e.g. [a-z]
|
||||||
|
found = found || (pos->value <= chr && chr <= pos[1].value);
|
||||||
|
pos += 2;
|
||||||
|
} else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
|
||||||
|
// Any character matches "."
|
||||||
|
found = true;
|
||||||
|
pos += 1;
|
||||||
|
} else {
|
||||||
|
// exact char match, e.g. [a] or "a"
|
||||||
|
found = found || pos->value == chr;
|
||||||
|
pos += 1;
|
||||||
|
}
|
||||||
|
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
|
||||||
|
|
||||||
|
return std::make_pair(found == is_positive_char, pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
|
||||||
|
// range at pos (regular or inverse range)
|
||||||
|
// asserts that pos is pointing to a char range element
|
||||||
|
static bool llama_grammar_match_partial_char(
|
||||||
|
const llama_grammar_element * pos,
|
||||||
|
const llama_partial_utf8 partial_utf8) {
|
||||||
|
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
|
||||||
|
GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
|
||||||
|
|
||||||
|
uint32_t partial_value = partial_utf8.value;
|
||||||
|
int n_remain = partial_utf8.n_remain;
|
||||||
|
|
||||||
|
// invalid sequence or 7-bit char split across 2 bytes (overlong)
|
||||||
|
if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// range of possible code points this partial UTF-8 sequence could complete to
|
||||||
|
uint32_t low = partial_value << (n_remain * 6);
|
||||||
|
uint32_t high = low | ((1 << (n_remain * 6)) - 1);
|
||||||
|
|
||||||
|
if (low == 0) {
|
||||||
|
if (n_remain == 2) {
|
||||||
|
low = 1 << 11;
|
||||||
|
} else if (n_remain == 3) {
|
||||||
|
low = 1 << 16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
do {
|
||||||
|
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
|
||||||
|
// inclusive range, e.g. [a-z]
|
||||||
|
if (pos->value <= high && low <= pos[1].value) {
|
||||||
|
return is_positive_char;
|
||||||
|
}
|
||||||
|
pos += 2;
|
||||||
|
} else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
|
||||||
|
// Any character matches "."
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// exact char match, e.g. [a] or "a"
|
||||||
|
if (low <= pos->value && pos->value <= high) {
|
||||||
|
return is_positive_char;
|
||||||
|
}
|
||||||
|
pos += 1;
|
||||||
|
}
|
||||||
|
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
|
||||||
|
|
||||||
|
return !is_positive_char;
|
||||||
|
}
|
||||||
|
|
||||||
|
// transforms a grammar pushdown stack into N possible stacks, all ending
|
||||||
|
// at a character range (terminal element)
|
||||||
|
static void llama_grammar_advance_stack(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stack & stack,
|
||||||
|
llama_grammar_stacks & new_stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||||
|
new_stacks.emplace_back(stack);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_grammar_element * pos = stack.back();
|
||||||
|
|
||||||
|
switch (pos->type) {
|
||||||
|
case LLAMA_GRETYPE_RULE_REF: {
|
||||||
|
const size_t rule_id = static_cast<size_t>(pos->value);
|
||||||
|
const llama_grammar_element * subpos = rules[rule_id].data();
|
||||||
|
do {
|
||||||
|
// init new stack without the top (pos)
|
||||||
|
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||||
|
if (!llama_grammar_is_end_of_sequence(pos + 1)) {
|
||||||
|
// if this rule ref is followed by another element, add that to stack
|
||||||
|
new_stack.push_back(pos + 1);
|
||||||
|
}
|
||||||
|
if (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||||
|
// if alternate is nonempty, add to stack
|
||||||
|
new_stack.push_back(subpos);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
||||||
|
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||||
|
// scan to end of alternate def
|
||||||
|
subpos++;
|
||||||
|
}
|
||||||
|
if (subpos->type == LLAMA_GRETYPE_ALT) {
|
||||||
|
// there's another alternate def of this rule to process
|
||||||
|
subpos++;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (true);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LLAMA_GRETYPE_CHAR:
|
||||||
|
case LLAMA_GRETYPE_CHAR_NOT:
|
||||||
|
case LLAMA_GRETYPE_CHAR_ANY:
|
||||||
|
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||||
|
// only add the stack if it's not a duplicate of one we already have
|
||||||
|
new_stacks.emplace_back(stack);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
// end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
|
||||||
|
// (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
|
||||||
|
// those
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// takes a set of possible pushdown stacks on a grammar, which are required to
|
||||||
|
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||||
|
// produces the N possible stacks if the given char is accepted at those
|
||||||
|
// positions
|
||||||
|
void llama_grammar_accept(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stacks & stacks,
|
||||||
|
const uint32_t chr,
|
||||||
|
llama_grammar_stacks & new_stacks) {
|
||||||
|
new_stacks.clear();
|
||||||
|
|
||||||
|
for (const auto & stack : stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto match = llama_grammar_match_char(stack.back(), chr);
|
||||||
|
if (match.first) {
|
||||||
|
const llama_grammar_element * pos = match.second;
|
||||||
|
|
||||||
|
// update top of stack to next element, if any
|
||||||
|
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
|
||||||
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
|
new_stack.push_back(pos);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(rules, new_stack, new_stacks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static llama_grammar_candidates llama_grammar_reject_candidates(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stacks & stacks,
|
||||||
|
const llama_grammar_candidates & candidates) {
|
||||||
|
GGML_ASSERT(!stacks.empty()); // REVIEW
|
||||||
|
|
||||||
|
if (candidates.empty()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
|
||||||
|
|
||||||
|
for (size_t i = 1, size = stacks.size(); i < size; ++i) {
|
||||||
|
rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
|
||||||
|
}
|
||||||
|
return rejects;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
const llama_grammar_stack & stack,
|
||||||
|
const llama_grammar_candidates & candidates) {
|
||||||
|
|
||||||
|
llama_grammar_candidates rejects;
|
||||||
|
rejects.reserve(candidates.size());
|
||||||
|
|
||||||
|
if (stack.empty()) {
|
||||||
|
for (const auto & tok : candidates) {
|
||||||
|
if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
|
||||||
|
rejects.push_back(tok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rejects;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_grammar_element * stack_pos = stack.back();
|
||||||
|
|
||||||
|
llama_grammar_candidates next_candidates;
|
||||||
|
next_candidates.reserve(candidates.size());
|
||||||
|
|
||||||
|
for (const auto & tok : candidates) {
|
||||||
|
if (*tok.code_points == 0) {
|
||||||
|
// reached end of full codepoints in token, reject iff it ended in a partial sequence
|
||||||
|
// that cannot satisfy this position in grammar
|
||||||
|
if (tok.partial_utf8.n_remain != 0 &&
|
||||||
|
!llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
|
||||||
|
rejects.push_back(tok);
|
||||||
|
}
|
||||||
|
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
|
||||||
|
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
|
||||||
|
} else {
|
||||||
|
rejects.push_back(tok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
|
||||||
|
|
||||||
|
// update top of stack to next element, if any
|
||||||
|
llama_grammar_stack stack_after(stack.begin(), stack.end() - 1);
|
||||||
|
if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
|
||||||
|
stack_after.push_back(stack_pos_after);
|
||||||
|
}
|
||||||
|
llama_grammar_stacks next_stacks;
|
||||||
|
llama_grammar_advance_stack(rules, stack_after, next_stacks);
|
||||||
|
|
||||||
|
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
||||||
|
for (const auto & tok : next_rejects) {
|
||||||
|
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
|
||||||
|
}
|
||||||
|
|
||||||
|
return rejects;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool llama_grammar_detect_left_recursion(
|
||||||
|
const llama_grammar_rules & rules,
|
||||||
|
size_t rule_index,
|
||||||
|
std::vector<bool> * rules_visited,
|
||||||
|
std::vector<bool> * rules_in_progress,
|
||||||
|
std::vector<bool> * rules_may_be_empty) {
|
||||||
|
if ((*rules_in_progress)[rule_index]) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
(*rules_in_progress)[rule_index] = true;
|
||||||
|
|
||||||
|
const llama_grammar_rule & rule = rules[rule_index];
|
||||||
|
|
||||||
|
// First check if the rule might produce the empty string. This could be done combined with the second
|
||||||
|
// step but it's more readable as two steps.
|
||||||
|
bool at_rule_start = true;
|
||||||
|
for (size_t i = 0; i < rule.size(); i++) {
|
||||||
|
if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
||||||
|
if (at_rule_start) {
|
||||||
|
(*rules_may_be_empty)[rule_index] = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
at_rule_start = true;
|
||||||
|
} else {
|
||||||
|
at_rule_start = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
|
||||||
|
// be empty)
|
||||||
|
bool recurse_into_nonterminal = true;
|
||||||
|
for (size_t i = 0; i < rule.size(); i++) {
|
||||||
|
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
|
||||||
|
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
|
||||||
|
recurse_into_nonterminal = false;
|
||||||
|
}
|
||||||
|
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
|
||||||
|
recurse_into_nonterminal = true;
|
||||||
|
} else {
|
||||||
|
recurse_into_nonterminal = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(*rules_in_progress)[rule_index] = false;
|
||||||
|
(*rules_visited)[rule_index] = true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// grammar - external
|
||||||
|
//
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const llama_grammar_element ** rules,
|
||||||
|
size_t n_rules,
|
||||||
|
size_t start_rule_index) {
|
||||||
|
const llama_grammar_element * pos;
|
||||||
|
|
||||||
|
// copy rule definitions into vectors
|
||||||
|
llama_grammar_rules vec_rules(n_rules);
|
||||||
|
for (size_t i = 0; i < n_rules; i++) {
|
||||||
|
for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
|
||||||
|
vec_rules[i].push_back(*pos);
|
||||||
|
}
|
||||||
|
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for left recursion
|
||||||
|
std::vector<bool> rules_visited(n_rules);
|
||||||
|
std::vector<bool> rules_in_progress(n_rules);
|
||||||
|
std::vector<bool> rules_may_be_empty(n_rules);
|
||||||
|
for (size_t i = 0; i < n_rules; i++) {
|
||||||
|
if (rules_visited[i]) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
|
||||||
|
LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop over alternates of start rule to build initial stacks
|
||||||
|
llama_grammar_stacks stacks;
|
||||||
|
pos = vec_rules[start_rule_index].data();
|
||||||
|
do {
|
||||||
|
llama_grammar_stack stack;
|
||||||
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
|
// if alternate is nonempty, add to stack
|
||||||
|
stack.push_back(pos);
|
||||||
|
}
|
||||||
|
llama_grammar_advance_stack(vec_rules, stack, stacks);
|
||||||
|
while (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
|
// scan to end of alternate def
|
||||||
|
pos++;
|
||||||
|
}
|
||||||
|
if (pos->type == LLAMA_GRETYPE_ALT) {
|
||||||
|
// there's another alternate def of this rule to process
|
||||||
|
pos++;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (true);
|
||||||
|
|
||||||
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
|
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||||
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
|
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
||||||
|
delete grammar;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
|
||||||
|
llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
|
||||||
|
|
||||||
|
// redirect elements in stacks to point to new rules
|
||||||
|
for (size_t is = 0; is < result->stacks.size(); is++) {
|
||||||
|
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
|
||||||
|
for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
|
||||||
|
for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
|
||||||
|
if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
|
||||||
|
result->stacks[is][ie] = &result->rules[ir0][ir1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||||
|
GGML_ASSERT(grammar);
|
||||||
|
GGML_ASSERT(vocab);
|
||||||
|
|
||||||
|
int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
bool allow_eog = false;
|
||||||
|
for (const auto & stack : grammar->stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
allow_eog = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
|
||||||
|
candidates_decoded.reserve(candidates->size);
|
||||||
|
|
||||||
|
llama_grammar_candidates candidates_grammar;
|
||||||
|
candidates_grammar.reserve(candidates->size);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
const llama_token id = candidates->data[i].id;
|
||||||
|
const std::string & piece = vocab->cache_token_to_piece.at(id);
|
||||||
|
|
||||||
|
if (llama_token_is_eog_impl(*vocab, id)) {
|
||||||
|
if (!allow_eog) {
|
||||||
|
candidates->data[i].logit = -INFINITY;
|
||||||
|
}
|
||||||
|
} else if (piece.empty() || piece[0] == 0) {
|
||||||
|
candidates->data[i].logit = -INFINITY;
|
||||||
|
} else {
|
||||||
|
candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
|
||||||
|
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
|
||||||
|
for (const auto & reject : rejects) {
|
||||||
|
candidates->data[reject.index].logit = -INFINITY;
|
||||||
|
}
|
||||||
|
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
if (llama_token_is_eog_impl(*vocab, token)) {
|
||||||
|
for (const auto & stack : grammar->stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string & piece = vocab->cache_token_to_piece.at(token);
|
||||||
|
|
||||||
|
// Note terminating 0 in decoded string
|
||||||
|
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
|
||||||
|
const auto & code_points = decoded.first;
|
||||||
|
|
||||||
|
llama_grammar_stacks tmp_new_stacks;
|
||||||
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
|
llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
|
||||||
|
grammar->stacks = tmp_new_stacks;
|
||||||
|
}
|
||||||
|
|
||||||
|
grammar->partial_utf8 = decoded.second;
|
||||||
|
GGML_ASSERT(!grammar->stacks.empty());
|
||||||
|
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
41
src/llama-grammar.h
Normal file
41
src/llama-grammar.h
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama-impl.h"
|
||||||
|
|
||||||
|
struct llama_vocab;
|
||||||
|
struct llama_sampling;
|
||||||
|
|
||||||
|
struct llama_grammar {
|
||||||
|
const llama_grammar_rules rules;
|
||||||
|
llama_grammar_stacks stacks;
|
||||||
|
|
||||||
|
// buffer for partially generated UTF-8 sequence from accepted tokens
|
||||||
|
llama_partial_utf8 partial_utf8;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_grammar * llama_get_grammar(struct llama_context * ctx);
|
||||||
|
|
||||||
|
//
|
||||||
|
// internal API
|
||||||
|
//
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_init_impl(
|
||||||
|
const llama_grammar_element ** rules,
|
||||||
|
size_t n_rules,
|
||||||
|
size_t start_rule_index);
|
||||||
|
|
||||||
|
void llama_grammar_free_impl(struct llama_grammar * grammar);
|
||||||
|
|
||||||
|
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
|
||||||
|
|
||||||
|
void llama_grammar_sample_impl(
|
||||||
|
const struct llama_grammar * grammar,
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
|
const struct llama_sampling * smpl,
|
||||||
|
llama_token_data_array * candidates);
|
||||||
|
|
||||||
|
void llama_grammar_accept_token_impl(
|
||||||
|
struct llama_grammar * grammar,
|
||||||
|
const struct llama_vocab * vocab,
|
||||||
|
const struct llama_sampling * smpl,
|
||||||
|
llama_token token);
|
26
src/llama-impl.h
Normal file
26
src/llama-impl.h
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#define LLAMA_API_INTERNAL
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#ifdef __GNUC__
|
||||||
|
#ifdef __MINGW32__
|
||||||
|
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
||||||
|
#else
|
||||||
|
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
#define LLAMA_ATTRIBUTE_FORMAT(...)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
//
|
||||||
|
// logging
|
||||||
|
//
|
||||||
|
|
||||||
|
LLAMA_ATTRIBUTE_FORMAT(2, 3)
|
||||||
|
void llama_log_internal (ggml_log_level level, const char * format, ...);
|
||||||
|
void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
|
||||||
|
|
||||||
|
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
|
||||||
|
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
|
||||||
|
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
635
src/llama-sampling.cpp
Normal file
635
src/llama-sampling.cpp
Normal file
@ -0,0 +1,635 @@
|
|||||||
|
#include "llama-sampling.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstring>
|
||||||
|
#include <ctime>
|
||||||
|
#include <cfloat>
|
||||||
|
#include <numeric>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
static void llama_log_softmax(float * array, size_t size) {
|
||||||
|
float max_l = *std::max_element(array, array + size);
|
||||||
|
float sum = 0.f;
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
float p = expf(array[i] - max_l);
|
||||||
|
sum += p;
|
||||||
|
array[i] = p;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
array[i] = logf(array[i] / sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
|
||||||
|
if (seed == LLAMA_DEFAULT_SEED) {
|
||||||
|
seed = time(NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
smpl->rng.seed(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||||
|
GGML_ASSERT(candidates->size > 0);
|
||||||
|
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
// Sort the logits in descending order
|
||||||
|
if (!candidates->sorted) {
|
||||||
|
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||||
|
return a.logit > b.logit;
|
||||||
|
});
|
||||||
|
candidates->sorted = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
float max_l = candidates->data[0].logit;
|
||||||
|
float cum_sum = 0.0f;
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
float p = expf(candidates->data[i].logit - max_l);
|
||||||
|
candidates->data[i].p = p;
|
||||||
|
cum_sum += p;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
candidates->data[i].p /= cum_sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||||
|
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
||||||
|
// if (k >= (int32_t)candidates->size) {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
if (k <= 0) {
|
||||||
|
k = candidates->size;
|
||||||
|
}
|
||||||
|
|
||||||
|
k = std::max(k, (int) min_keep);
|
||||||
|
k = std::min(k, (int) candidates->size);
|
||||||
|
|
||||||
|
// Sort scores in descending order
|
||||||
|
if (!candidates->sorted) {
|
||||||
|
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
|
||||||
|
return a.logit > b.logit;
|
||||||
|
};
|
||||||
|
if (k <= 128) {
|
||||||
|
std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
|
||||||
|
} else {
|
||||||
|
constexpr int nbuckets = 128;
|
||||||
|
constexpr float bucket_low = -10.0f;
|
||||||
|
constexpr float bucket_high = 10.0f;
|
||||||
|
constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
|
||||||
|
constexpr float bucker_inter = -bucket_low * bucket_scale;
|
||||||
|
|
||||||
|
std::vector<int> bucket_idx(candidates->size);
|
||||||
|
std::vector<int> histo(nbuckets, 0);
|
||||||
|
|
||||||
|
for (int i = 0; i < (int)candidates->size; ++i) {
|
||||||
|
const float val = candidates->data[i].logit;
|
||||||
|
int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
|
||||||
|
ib = std::max(0, std::min(nbuckets-1, ib));
|
||||||
|
bucket_idx[i] = ib;
|
||||||
|
++histo[ib];
|
||||||
|
}
|
||||||
|
int nhave = 0;
|
||||||
|
int ib = nbuckets - 1;
|
||||||
|
for ( ; ib >= 0; --ib) {
|
||||||
|
nhave += histo[ib];
|
||||||
|
if (nhave >= k) break;
|
||||||
|
}
|
||||||
|
std::vector<llama_token_data> tmp_tokens(nhave);
|
||||||
|
auto ptr = tmp_tokens.data();
|
||||||
|
std::vector<llama_token_data*> bucket_ptrs;
|
||||||
|
bucket_ptrs.reserve(nbuckets - ib);
|
||||||
|
for (int j = nbuckets - 1; j >= ib; --j) {
|
||||||
|
bucket_ptrs.push_back(ptr);
|
||||||
|
ptr += histo[j];
|
||||||
|
}
|
||||||
|
for (int i = 0; i < (int)candidates->size; ++i) {
|
||||||
|
int j = bucket_idx[i];
|
||||||
|
if (j >= ib) {
|
||||||
|
*bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr = tmp_tokens.data();
|
||||||
|
int ndone = 0;
|
||||||
|
for (int j = nbuckets-1; j > ib; --j) {
|
||||||
|
std::sort(ptr, ptr + histo[j], comp);
|
||||||
|
ptr += histo[j];
|
||||||
|
ndone += histo[j];
|
||||||
|
}
|
||||||
|
std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
|
||||||
|
|
||||||
|
std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
|
||||||
|
|
||||||
|
}
|
||||||
|
candidates->sorted = true;
|
||||||
|
}
|
||||||
|
candidates->size = k;
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||||
|
if (p >= 1.0f) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sample_softmax_impl(smpl, candidates);
|
||||||
|
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
// Compute the cumulative probabilities
|
||||||
|
float cum_sum = 0.0f;
|
||||||
|
size_t last_idx = candidates->size;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
cum_sum += candidates->data[i].p;
|
||||||
|
|
||||||
|
// Check if the running sum is at least p or if we have kept at least min_keep tokens
|
||||||
|
// we set the last index to i+1 to indicate that the current iterate should be included in the set
|
||||||
|
if (cum_sum >= p && i + 1 >= min_keep) {
|
||||||
|
last_idx = i + 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize the output vector to keep only the top-p tokens
|
||||||
|
candidates->size = last_idx;
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||||
|
if (p <= 0.0f || !candidates->size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
bool min_p_applied = false;
|
||||||
|
|
||||||
|
// if the candidates aren't sorted, try the unsorted implementation first
|
||||||
|
if (!candidates->sorted) {
|
||||||
|
std::vector<llama_token_data> filtered_tokens;
|
||||||
|
|
||||||
|
float max_logit = -FLT_MAX;
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
max_logit = std::max(max_logit, candidates->data[i].logit);
|
||||||
|
}
|
||||||
|
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
|
||||||
|
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
if (candidates->data[i].logit >= min_logit) {
|
||||||
|
filtered_tokens.push_back(candidates->data[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we have enough values the operation was a success
|
||||||
|
if (filtered_tokens.size() >= min_keep) {
|
||||||
|
memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
||||||
|
candidates->size = filtered_tokens.size();
|
||||||
|
min_p_applied = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the candidates are sorted or the unsorted implementation failed, use this implementation
|
||||||
|
if (!min_p_applied) {
|
||||||
|
// Sort the logits in descending order
|
||||||
|
if (!candidates->sorted) {
|
||||||
|
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||||
|
return a.logit > b.logit;
|
||||||
|
});
|
||||||
|
candidates->sorted = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
|
||||||
|
size_t i = 1; // first token always matches
|
||||||
|
|
||||||
|
for (; i < candidates->size; ++i) {
|
||||||
|
if (candidates->data[i].logit < min_logit && i >= min_keep) {
|
||||||
|
break; // prob too small
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize the output vector to keep only the matching tokens
|
||||||
|
candidates->size = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
|
||||||
|
if (z >= 1.0f || candidates->size <= 2) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
// Compute the first and second derivatives
|
||||||
|
std::vector<float> first_derivatives(candidates->size - 1);
|
||||||
|
std::vector<float> second_derivatives(candidates->size - 2);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < first_derivatives.size(); ++i) {
|
||||||
|
first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||||
|
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate absolute value of second derivatives
|
||||||
|
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||||
|
second_derivatives[i] = std::abs(second_derivatives[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize the second derivatives
|
||||||
|
{
|
||||||
|
const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
|
||||||
|
|
||||||
|
if (second_derivatives_sum > 1e-6f) {
|
||||||
|
for (float & value : second_derivatives) {
|
||||||
|
value /= second_derivatives_sum;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (float & value : second_derivatives) {
|
||||||
|
value = 1.0f / second_derivatives.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float cum_sum = 0.0f;
|
||||||
|
size_t last_idx = candidates->size;
|
||||||
|
for (size_t i = 0; i < second_derivatives.size(); ++i) {
|
||||||
|
cum_sum += second_derivatives[i];
|
||||||
|
|
||||||
|
// Check if the running sum is greater than z or if we have kept at least min_keep tokens
|
||||||
|
if (cum_sum > z && i >= min_keep) {
|
||||||
|
last_idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize the output vector to keep only the tokens above the tail location
|
||||||
|
candidates->size = last_idx;
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
|
||||||
|
// Reference implementation:
|
||||||
|
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
|
||||||
|
if (p >= 1.0f) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the softmax of logits and calculate entropy
|
||||||
|
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||||
|
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
float entropy = 0.0f;
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
entropy += -candidates->data[i].p * logf(candidates->data[i].p);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the absolute difference between negative log probability and entropy for each candidate
|
||||||
|
std::vector<float> shifted_scores;
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
|
||||||
|
shifted_scores.push_back(shifted_score);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort tokens based on the shifted_scores and their corresponding indices
|
||||||
|
std::vector<size_t> indices(candidates->size);
|
||||||
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
|
||||||
|
std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
|
||||||
|
return shifted_scores[a] < shifted_scores[b];
|
||||||
|
});
|
||||||
|
|
||||||
|
// Compute the cumulative probabilities
|
||||||
|
float cum_sum = 0.0f;
|
||||||
|
size_t last_idx = indices.size();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < indices.size(); ++i) {
|
||||||
|
size_t idx = indices[i];
|
||||||
|
cum_sum += candidates->data[idx].p;
|
||||||
|
|
||||||
|
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
|
||||||
|
if (cum_sum > p && i >= min_keep - 1) {
|
||||||
|
last_idx = i + 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize the output vector to keep only the locally typical tokens
|
||||||
|
std::vector<llama_token_data> new_candidates;
|
||||||
|
for (size_t i = 0; i < last_idx; ++i) {
|
||||||
|
size_t idx = indices[i];
|
||||||
|
new_candidates.push_back(candidates->data[idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace the data in candidates with the new_candidates data
|
||||||
|
std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
|
||||||
|
candidates->size = new_candidates.size();
|
||||||
|
candidates->sorted = false;
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
// no need to do anything if there is only one (or zero) candidates
|
||||||
|
if(candidates->size <= 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate maximum possible entropy
|
||||||
|
float max_entropy = -logf(1.0f / candidates->size);
|
||||||
|
|
||||||
|
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||||
|
|
||||||
|
// Calculate entropy of the softmax probabilities
|
||||||
|
float entropy = 0.0f;
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
float prob = candidates->data[i].p;
|
||||||
|
if (prob > 0.0f) { // Ensure no log(0)
|
||||||
|
entropy -= prob * logf(prob);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
|
||||||
|
float normalized_entropy = entropy / max_entropy;
|
||||||
|
|
||||||
|
// Map the normalized entropy to the desired temperature range using the power function
|
||||||
|
float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
|
||||||
|
|
||||||
|
#ifdef DEBUG
|
||||||
|
LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
|
||||||
|
LLAMA_LOG_INFO("Entropy: %f\n", entropy);
|
||||||
|
LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
|
||||||
|
LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
|
||||||
|
LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
|
||||||
|
LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Apply the dynamically calculated temperature scaling
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
candidates->data[i].logit /= dyn_temp;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-compute softmax probabilities after scaling logits with dynamic temperature
|
||||||
|
double max_l_double = candidates->data[0].logit;
|
||||||
|
double cum_sum_double = 0.0;
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
double p = exp(candidates->data[i].logit - max_l_double);
|
||||||
|
candidates->data[i].p = p; // Store the scaled probability
|
||||||
|
cum_sum_double += p;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef DEBUG
|
||||||
|
// Print the updated top 25 probabilities after temperature scaling
|
||||||
|
LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
|
||||||
|
for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
|
||||||
|
LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
candidates->data[i].logit /= temp;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_repetition_penalties_impl(
|
||||||
|
struct llama_sampling * smpl,
|
||||||
|
llama_token_data_array * candidates,
|
||||||
|
const llama_token * last_tokens,
|
||||||
|
size_t penalty_last_n,
|
||||||
|
float penalty_repeat,
|
||||||
|
float penalty_freq,
|
||||||
|
float penalty_present) {
|
||||||
|
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
// Create a frequency map to count occurrences of each token in last_tokens
|
||||||
|
std::unordered_map<llama_token, int> token_count;
|
||||||
|
for (size_t i = 0; i < penalty_last_n; ++i) {
|
||||||
|
token_count[last_tokens[i]]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply frequency and presence penalties to the candidates
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
const auto token_iter = token_count.find(candidates->data[i].id);
|
||||||
|
if (token_iter == token_count.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int count = token_iter->second;
|
||||||
|
|
||||||
|
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
||||||
|
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
||||||
|
if (candidates->data[i].logit <= 0) {
|
||||||
|
candidates->data[i].logit *= penalty_repeat;
|
||||||
|
} else {
|
||||||
|
candidates->data[i].logit /= penalty_repeat;
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates->sorted = false;
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_sample_apply_guidance_impl(
|
||||||
|
struct llama_sampling * smpl,
|
||||||
|
float * logits,
|
||||||
|
float * logits_guidance,
|
||||||
|
float scale) {
|
||||||
|
GGML_ASSERT(smpl);
|
||||||
|
|
||||||
|
const auto t_start_sample_us = ggml_time_us();
|
||||||
|
const auto n_vocab = smpl->n_vocab;
|
||||||
|
|
||||||
|
llama_log_softmax(logits, n_vocab);
|
||||||
|
llama_log_softmax(logits_guidance, n_vocab);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_vocab; ++i) {
|
||||||
|
auto & l = logits[i];
|
||||||
|
const auto & g = logits_guidance[i];
|
||||||
|
|
||||||
|
l = scale * (l - g) + g;
|
||||||
|
}
|
||||||
|
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
|
||||||
|
GGML_ASSERT(smpl);
|
||||||
|
|
||||||
|
const int32_t n_vocab = float(smpl->n_vocab);
|
||||||
|
|
||||||
|
int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||||
|
|
||||||
|
// Estimate s_hat using the most probable m tokens
|
||||||
|
float s_hat = 0.0;
|
||||||
|
float sum_ti_bi = 0.0;
|
||||||
|
float sum_ti_sq = 0.0;
|
||||||
|
for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
|
||||||
|
float t_i = logf(float(i + 2) / float(i + 1));
|
||||||
|
float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
|
||||||
|
sum_ti_bi += t_i * b_i;
|
||||||
|
sum_ti_sq += t_i * t_i;
|
||||||
|
}
|
||||||
|
s_hat = sum_ti_bi / sum_ti_sq;
|
||||||
|
|
||||||
|
// Compute k from the estimated s_hat and target surprise value
|
||||||
|
float epsilon_hat = s_hat - 1;
|
||||||
|
float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
|
||||||
|
|
||||||
|
// Sample the next word X using top-k sampling
|
||||||
|
llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
llama_token X = llama_sample_token_impl(smpl, candidates);
|
||||||
|
t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
// Compute error as the difference between observed surprise and target surprise value
|
||||||
|
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||||
|
return candidate.id == X;
|
||||||
|
}));
|
||||||
|
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||||
|
float e = observed_surprise - tau;
|
||||||
|
|
||||||
|
// Update mu using the learning rate and error
|
||||||
|
*mu = *mu - eta * e;
|
||||||
|
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
|
||||||
|
int64_t t_start_sample_us;
|
||||||
|
t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
llama_sample_softmax_impl(smpl, candidates);
|
||||||
|
|
||||||
|
// Truncate the words with surprise values greater than mu
|
||||||
|
candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||||
|
return -log2f(candidate.p) > *mu;
|
||||||
|
}));
|
||||||
|
|
||||||
|
if (candidates->size == 0) {
|
||||||
|
candidates->size = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize the probabilities of the remaining words
|
||||||
|
llama_sample_softmax_impl(smpl, candidates);
|
||||||
|
|
||||||
|
// Sample the next word X from the remaining words
|
||||||
|
llama_token X = llama_sample_token_impl(smpl, candidates);
|
||||||
|
t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
// Compute error as the difference between observed surprise and target surprise value
|
||||||
|
size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
|
||||||
|
return candidate.id == X;
|
||||||
|
}));
|
||||||
|
float observed_surprise = -log2f(candidates->data[X_idx].p);
|
||||||
|
float e = observed_surprise - tau;
|
||||||
|
|
||||||
|
// Update mu using the learning rate and error
|
||||||
|
*mu = *mu - eta * e;
|
||||||
|
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
return X;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
// Find max element
|
||||||
|
auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||||
|
return a.logit < b.logit;
|
||||||
|
});
|
||||||
|
|
||||||
|
llama_token result = max_iter->id;
|
||||||
|
if (smpl) {
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
smpl->n_sample++;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
|
||||||
|
GGML_ASSERT(smpl);
|
||||||
|
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
|
||||||
|
|
||||||
|
std::vector<float> probs;
|
||||||
|
probs.reserve(candidates->size);
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
probs.push_back(candidates->data[i].p);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||||
|
int idx = dist(rng);
|
||||||
|
|
||||||
|
llama_token result = candidates->data[idx].id;
|
||||||
|
|
||||||
|
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
smpl->n_sample++;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
|
||||||
|
return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
|
||||||
|
}
|
56
src/llama-sampling.h
Normal file
56
src/llama-sampling.h
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama-impl.h"
|
||||||
|
|
||||||
|
struct llama_sampling {
|
||||||
|
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
|
||||||
|
|
||||||
|
std::mt19937 rng;
|
||||||
|
|
||||||
|
int32_t n_vocab = 0;
|
||||||
|
|
||||||
|
mutable int64_t t_sample_us = 0;
|
||||||
|
mutable int32_t n_sample = 0;
|
||||||
|
|
||||||
|
void reset_timings() const {
|
||||||
|
t_sample_us = 0;
|
||||||
|
n_sample = 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// internal API
|
||||||
|
//
|
||||||
|
|
||||||
|
void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
|
||||||
|
|
||||||
|
void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
||||||
|
void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
|
||||||
|
void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||||
|
void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||||
|
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
|
||||||
|
void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||||
|
void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
|
||||||
|
void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
|
||||||
|
|
||||||
|
void llama_sample_repetition_penalties_impl(
|
||||||
|
struct llama_sampling * smpl,
|
||||||
|
llama_token_data_array * candidates,
|
||||||
|
const llama_token * last_tokens,
|
||||||
|
size_t penalty_last_n,
|
||||||
|
float penalty_repeat,
|
||||||
|
float penalty_freq,
|
||||||
|
float penalty_present);
|
||||||
|
|
||||||
|
void llama_sample_apply_guidance_impl(
|
||||||
|
struct llama_sampling * smpl,
|
||||||
|
float * logits,
|
||||||
|
float * logits_guidance,
|
||||||
|
float scale);
|
||||||
|
|
||||||
|
llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
|
||||||
|
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||||
|
llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
||||||
|
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
|
||||||
|
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
|
||||||
|
|
1721
src/llama-vocab.cpp
Normal file
1721
src/llama-vocab.cpp
Normal file
File diff suppressed because it is too large
Load Diff
130
src/llama-vocab.h
Normal file
130
src/llama-vocab.h
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama-impl.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
struct llama_vocab {
|
||||||
|
using id = llama_token;
|
||||||
|
using token = std::string;
|
||||||
|
using tattr = llama_token_attr;
|
||||||
|
|
||||||
|
struct token_data {
|
||||||
|
token text;
|
||||||
|
float score;
|
||||||
|
tattr attr;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
||||||
|
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
|
|
||||||
|
int max_token_len = 0; // used for optimizing longest token search
|
||||||
|
|
||||||
|
std::unordered_map<token, id> token_to_id;
|
||||||
|
std::vector<token_data> id_to_token;
|
||||||
|
|
||||||
|
std::vector<id> cache_special_tokens;
|
||||||
|
std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
|
||||||
|
|
||||||
|
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
|
||||||
|
|
||||||
|
// default LLaMA special tokens
|
||||||
|
id special_bos_id = 1;
|
||||||
|
id special_eos_id = 2;
|
||||||
|
id special_unk_id = 0;
|
||||||
|
id special_sep_id = -1;
|
||||||
|
id special_pad_id = -1;
|
||||||
|
id special_cls_id = -1;
|
||||||
|
id special_mask_id = -1;
|
||||||
|
|
||||||
|
id linefeed_id = 13;
|
||||||
|
id special_prefix_id = -1;
|
||||||
|
id special_suffix_id = -1;
|
||||||
|
id special_middle_id = -1;
|
||||||
|
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
|
||||||
|
|
||||||
|
// tokenizer flags
|
||||||
|
bool tokenizer_add_space_prefix = false;
|
||||||
|
bool tokenizer_add_bos = false;
|
||||||
|
bool tokenizer_add_eos = false;
|
||||||
|
bool tokenizer_ignore_merges = false;
|
||||||
|
bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
|
||||||
|
bool tokenizer_remove_extra_whitespaces = false;
|
||||||
|
bool tokenizer_escape_whitespaces = true;
|
||||||
|
bool tokenizer_treat_whitespace_as_suffix = false;
|
||||||
|
|
||||||
|
std::vector<char> precompiled_charsmap;
|
||||||
|
|
||||||
|
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
|
||||||
|
|
||||||
|
//
|
||||||
|
// internal API
|
||||||
|
//
|
||||||
|
|
||||||
|
// TODO: rename to llama_tokenize_impl
|
||||||
|
// TODO: This should probably be in llama.h
|
||||||
|
std::vector<llama_vocab::id> llama_tokenize_internal(
|
||||||
|
const llama_vocab & vocab,
|
||||||
|
std::string raw_text,
|
||||||
|
bool add_special,
|
||||||
|
bool parse_special = false);
|
||||||
|
|
||||||
|
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
|
||||||
|
|
||||||
|
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
|
||||||
|
|
||||||
|
float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
|
||||||
|
|
||||||
|
llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
|
||||||
|
|
||||||
|
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
|
||||||
|
|
||||||
|
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
|
||||||
|
|
||||||
|
llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
|
||||||
|
llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
|
||||||
|
llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
|
||||||
|
llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
|
||||||
|
llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
|
||||||
|
llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
|
||||||
|
|
||||||
|
int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab);
|
||||||
|
int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab);
|
||||||
|
|
||||||
|
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
|
||||||
|
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
|
||||||
|
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
|
||||||
|
llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
|
||||||
|
|
||||||
|
int32_t llama_tokenize_impl(
|
||||||
|
const struct llama_vocab & vocab,
|
||||||
|
const char * text,
|
||||||
|
int32_t text_len,
|
||||||
|
llama_token * tokens,
|
||||||
|
int32_t n_tokens_max,
|
||||||
|
bool add_special,
|
||||||
|
bool parse_special);
|
||||||
|
|
||||||
|
// does not write null-terminator to buf
|
||||||
|
int32_t llama_token_to_piece_impl(
|
||||||
|
const struct llama_vocab & vocab,
|
||||||
|
llama_token token,
|
||||||
|
char * buf,
|
||||||
|
int32_t length,
|
||||||
|
int32_t lstrip,
|
||||||
|
bool special);
|
||||||
|
|
||||||
|
int32_t llama_detokenize_impl(
|
||||||
|
const struct llama_vocab & vocab,
|
||||||
|
const llama_token * tokens,
|
||||||
|
int32_t n_tokens,
|
||||||
|
char * text,
|
||||||
|
int32_t text_len_max,
|
||||||
|
bool remove_special,
|
||||||
|
bool unparse_special);
|
3136
src/llama.cpp
3136
src/llama.cpp
File diff suppressed because it is too large
Load Diff
@ -19,6 +19,12 @@
|
|||||||
#include <locale>
|
#include <locale>
|
||||||
#include <codecvt>
|
#include <codecvt>
|
||||||
|
|
||||||
|
size_t unicode_len_utf8(char src) {
|
||||||
|
const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
||||||
|
uint8_t highbits = static_cast<uint8_t>(src) >> 4;
|
||||||
|
return lookup[highbits];
|
||||||
|
}
|
||||||
|
|
||||||
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
|
static std::string unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
|
||||||
std::string result;
|
std::string result;
|
||||||
for (size_t i = 0; i < cps.size(); ++i) {
|
for (size_t i = 0; i < cps.size(); ++i) {
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
// TODO: prefix all symbols with "llama_"
|
||||||
|
|
||||||
struct codepoint_flags {
|
struct codepoint_flags {
|
||||||
enum {
|
enum {
|
||||||
UNDEFINED = 0x0001,
|
UNDEFINED = 0x0001,
|
||||||
@ -46,6 +48,7 @@ struct codepoint_flags {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
size_t unicode_len_utf8(char src);
|
||||||
|
|
||||||
std::string unicode_cpt_to_utf8(uint32_t cp);
|
std::string unicode_cpt_to_utf8(uint32_t cp);
|
||||||
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
|
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
|
||||||
|
@ -44,21 +44,26 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
|
|||||||
return grammar_fails;
|
return grammar_fails;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool match_string(const std::string & input, llama_grammar* grammar) {
|
static bool match_string(const std::string & input, llama_grammar * grammar) {
|
||||||
auto decoded = decode_utf8(input, {});
|
auto decoded = decode_utf8(input, {});
|
||||||
|
|
||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
|
||||||
|
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||||
|
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
auto prev_stacks = grammar->stacks;
|
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
|
||||||
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
|
||||||
if (grammar->stacks.empty()) {
|
llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
|
||||||
|
|
||||||
|
if (cur_stacks.empty()) {
|
||||||
// no stacks means that the grammar failed to match at this point
|
// no stacks means that the grammar failed to match at this point
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto & stack : grammar->stacks) {
|
for (const auto & stack : cur_stacks) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
// An empty stack means that the grammar has been completed
|
// An empty stack means that the grammar has been completed
|
||||||
return true;
|
return true;
|
||||||
@ -75,7 +80,9 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
|||||||
auto grammar = build_grammar(grammar_str);
|
auto grammar = build_grammar(grammar_str);
|
||||||
|
|
||||||
// Save the original grammar stacks so that we can reset after every new string we want to test
|
// Save the original grammar stacks so that we can reset after every new string we want to test
|
||||||
auto original_stacks = grammar->stacks;
|
const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
|
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
fprintf(stderr, " 🔵 Valid strings:\n");
|
fprintf(stderr, " 🔵 Valid strings:\n");
|
||||||
|
|
||||||
@ -112,7 +119,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
|||||||
assert(matched);
|
assert(matched);
|
||||||
|
|
||||||
// Reset the grammar stacks
|
// Reset the grammar stacks
|
||||||
grammar->stacks = original_stacks;
|
cur_stacks = original_stacks;
|
||||||
}
|
}
|
||||||
|
|
||||||
fprintf(stderr, " 🟠 Invalid strings:\n");
|
fprintf(stderr, " 🟠 Invalid strings:\n");
|
||||||
@ -132,7 +139,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
|
|||||||
assert(!matched);
|
assert(!matched);
|
||||||
|
|
||||||
// Reset the grammar stacks
|
// Reset the grammar stacks
|
||||||
grammar->stacks = original_stacks;
|
cur_stacks = original_stacks;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up allocated memory
|
// Clean up allocated memory
|
||||||
|
@ -2,10 +2,12 @@
|
|||||||
#undef NDEBUG
|
#undef NDEBUG
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "llama.cpp" // TODO: not great
|
#define LLAMA_API_INTERNAL
|
||||||
|
#include "llama.h"
|
||||||
#include "grammar-parser.h"
|
#include "grammar-parser.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
@ -112,10 +114,10 @@ int main()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_grammar *grammar = NULL;
|
llama_grammar * grammar = NULL;
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||||
grammar = llama_grammar_init(
|
|
||||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||||
if (grammar == nullptr)
|
if (grammar == nullptr)
|
||||||
{
|
{
|
||||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||||
@ -172,7 +174,7 @@ int main()
|
|||||||
}};
|
}};
|
||||||
|
|
||||||
auto index = 0;
|
auto index = 0;
|
||||||
for (auto stack : grammar->stacks)
|
for (auto stack : llama_grammar_get_stacks(grammar))
|
||||||
{
|
{
|
||||||
// compare stack to expected_stack
|
// compare stack to expected_stack
|
||||||
for (uint32_t i = 0; i < stack.size(); i++)
|
for (uint32_t i = 0; i < stack.size(); i++)
|
||||||
@ -374,13 +376,13 @@ int main()
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
|
std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates);
|
||||||
|
|
||||||
std::vector<std::vector<llama_grammar_candidate>> all_rejects;
|
std::vector<std::vector<llama_grammar_candidate>> all_rejects;
|
||||||
|
|
||||||
for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
|
for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
|
||||||
{
|
{
|
||||||
rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
|
rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates);
|
||||||
all_rejects.push_back(rejects);
|
all_rejects.push_back(rejects);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -401,6 +403,6 @@ int main()
|
|||||||
delete[] candidate.code_points;
|
delete[] candidate.code_points;
|
||||||
candidate.code_points = nullptr;
|
candidate.code_points = nullptr;
|
||||||
}
|
}
|
||||||
delete grammar;
|
llama_grammar_free(grammar);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user