Fix memory bug in grammar parser (#7194)

The llama.cpp grammar parser had a bug where a string missing its closing
quotation mark would send the parser reading past the end of the input and
crash. Anyone running a server on a public endpoint is advised to upgrade.
To reproduce this bug:

    ./llamafile -m foo.gguf -p bar --grammar 'root::="'
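
The crash mechanics: the parser's scan loops tested for a closing delimiter
but never for the NUL terminator, so an unterminated string sent them reading
past the end of the buffer. Below is a minimal standalone sketch of the
pattern; the scan_string_* helpers are illustrative, not the real parser
functions:

    #include <cstdio>
    #include <stdexcept>

    // Buggy shape: with the grammar root ::= " the byte after the opening
    // quote is the NUL terminator, so this loop never finds a closing quote
    // and walks off the end of the buffer.
    static const char * scan_string_buggy(const char * pos) {
        while (*pos != '"') {
            pos++;  // out-of-bounds read once we pass the terminator
        }
        return pos + 1;
    }

    // Fixed shape, mirroring the guards this commit adds: treat NUL as
    // "unexpected end of input" before consuming another byte.
    static const char * scan_string_fixed(const char * pos) {
        while (*pos != '"') {
            if (!*pos) {
                throw std::runtime_error("unexpected end of input");
            }
            pos++;
        }
        return pos + 1;
    }

    int main() {
        (void)scan_string_buggy;  // not called: running it would crash
        try {
            scan_string_fixed("");  // what follows the quote in root ::= "
        } catch (const std::runtime_error & e) {
            printf("parse error: %s\n", e.what());  // clean failure, no crash
        }
    }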

Credit for discovering and reporting this issue goes to Eclypsium
Security Researcher Richard Johnson <Richard.johnson@eclypsium.com>.

Author:  Justine Tunney
Date:    2024-05-10 07:01:08 -04:00 (committed by GitHub)
Parent:  f89fe2732c
Commit:  4e3880978f
4 changed files with 21 additions and 5 deletions

@@ -1371,15 +1371,13 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
-
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
-    }
-
-    if (invalid_param) {
-        throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+        if (invalid_param) {
+            throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+        }
     }

     if (params.prompt_cache_all &&
         (params.interactive || params.interactive_first ||
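
This hunk is adjacent hardening rather than the overflow fix itself:
invalid_param is now checked inside the parse loop, so a malformed value
aborts immediately and the message names the argument that actually failed;
the old post-loop check could only blame whichever argument was parsed last.
A rough sketch of the difference, with a hypothetical parse_one standing in
for gpt_params_find_arg:

    #include <cstdio>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for gpt_params_find_arg: consumes one flag,
    // setting `invalid` when its value is malformed.
    static bool parse_one(const std::string & arg, bool & invalid) {
        if (arg == "--temp") {
            invalid = true;  // pretend this flag got a bad value
            return true;
        }
        return arg.rfind("--", 0) == 0;  // recognized iff it looks like a flag
    }

    static void parse_args(const std::vector<std::string> & args) {
        bool invalid = false;
        for (const std::string & arg : args) {
            if (!parse_one(arg, invalid)) {
                throw std::invalid_argument("error: unknown argument: " + arg);
            }
            // Checking per iteration, as the fix does, names the argument
            // that actually failed; a single check after the loop would
            // report whichever argument happened to be parsed last.
            if (invalid) {
                throw std::invalid_argument("error: invalid parameter for argument: " + arg);
            }
        }
    }

    int main() {
        try {
            parse_args({"--ctx-size", "--temp", "--seed"});
        } catch (const std::invalid_argument & e) {
            printf("%s\n", e.what());  // reports --temp, not --seed
        }
    }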

@@ -142,6 +142,9 @@ namespace grammar_parser {
             pos++;
             last_sym_start = out_elements.size();
             while (*pos != '"') {
+                if (!*pos) {
+                    throw std::runtime_error("unexpected end of input");
+                }
                 auto char_pair = parse_char(pos);
                 pos = char_pair.second;
                 out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
@@ -156,6 +159,9 @@
             }
             last_sym_start = out_elements.size();
             while (*pos != ']') {
+                if (!*pos) {
+                    throw std::runtime_error("unexpected end of input");
+                }
                 auto char_pair = parse_char(pos);
                 pos = char_pair.second;
                 enum llama_gretype type = last_sym_start < out_elements.size()
@@ -164,6 +170,9 @@
                 out_elements.push_back({type, char_pair.first});
                 if (pos[0] == '-' && pos[1] != ']') {
+                    if (!pos[1]) {
+                        throw std::runtime_error("unexpected end of input");
+                    }
                     auto endchar_pair = parse_char(pos + 1);
                     pos = endchar_pair.second;
                     out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
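
The first two guards stop the string and character-class scan loops at the
terminator. The third is subtly different: after seeing - inside a character
class, the parser looks one byte ahead for the range's upper bound, and NUL
satisfies pos[1] != ']', so a grammar truncated at [a- entered the range
branch and kept parsing past the end of input. A simplified illustration,
with parse_one_char standing in for the real parse_char (which also decodes
escape sequences):

    #include <cstdio>
    #include <stdexcept>

    // Stand-in for the real parse_char, which also decodes escapes.
    static char parse_one_char(const char * pos) {
        if (!*pos) {
            throw std::runtime_error("unexpected end of input");
        }
        return *pos;
    }

    // Sketch of the character-class range scan: after [a- the next byte
    // should be the range's upper bound.
    static void scan_range(const char * pos) {
        char lo = parse_one_char(pos);
        pos++;
        if (pos[0] == '-' && pos[1] != ']') {
            // NUL satisfies pos[1] != ']', so a grammar truncated at [a-
            // reaches this branch; without the guard, the scan consumed
            // the terminator and ran past the end of the buffer.
            if (!pos[1]) {
                throw std::runtime_error("unexpected end of input");
            }
            char hi = parse_one_char(pos + 1);
            printf("range %c-%c\n", lo, hi);
        }
    }

    int main() {
        try {
            scan_range("a-");  // grammar truncated inside a character class
        } catch (const std::runtime_error & e) {
            printf("parse error: %s\n", e.what());
        }
    }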

@@ -189,6 +189,11 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     LOG_TEE("\n");

     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    if (!ctx_sampling) {
+        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
+        exit(1);
+    }
+
     std::string response = "";
     for (int i = 0; i < max_tgt_len; i++) {
         const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);

@@ -523,6 +523,10 @@ int main(int argc, char ** argv) {
     }

     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    if (!ctx_sampling) {
+        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
+        exit(1);
+    }

     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
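
The last two hunks harden the callers: llama_sampling_init reports failure by
returning a null pointer (a grammar that fails to parse is one such path),
and neither llava-cli nor main checked for that before using the context. A
minimal sketch of the guard pattern, with a hypothetical sampling_init
standing in for the real initializer:

    #include <cstdio>
    #include <cstdlib>

    // Hypothetical stand-in for llama_sampling_init: failure is reported
    // with a null return rather than an exception.
    struct sampling_context { int n_sampled; };

    static sampling_context * sampling_init(bool grammar_ok) {
        if (!grammar_ok) {
            return nullptr;
        }
        return new sampling_context();
    }

    int main() {
        sampling_context * ctx = sampling_init(/*grammar_ok=*/false);
        // The pattern added in both hunks: treat a null context as fatal
        // before anything dereferences it.
        if (!ctx) {
            fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
            exit(1);
        }
        delete ctx;
        return 0;
    }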