main : Add ChatML functionality to main example (#4046)

Co-authored-by: Sebastian Cramond <sebby37@users.noreply.github.com>
This commit is contained in:
Seb C 2023-11-21 00:26:59 +10:30 committed by GitHub
parent f23c0359a3
commit 881800d1f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 5 deletions

View File

@ -491,6 +491,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.interactive_first = true; params.interactive_first = true;
} else if (arg == "-ins" || arg == "--instruct") { } else if (arg == "-ins" || arg == "--instruct") {
params.instruct = true; params.instruct = true;
} else if (arg == "-cml" || arg == "--chatml") {
params.chatml = true;
} else if (arg == "--infill") { } else if (arg == "--infill") {
params.infill = true; params.infill = true;
} else if (arg == "--multiline-input") { } else if (arg == "--multiline-input") {
@ -730,6 +732,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -i, --interactive run in interactive mode\n"); printf(" -i, --interactive run in interactive mode\n");
printf(" --interactive-first run in interactive mode and wait for input right away\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n");
printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n");
printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n");
printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
printf(" -r PROMPT, --reverse-prompt PROMPT\n"); printf(" -r PROMPT, --reverse-prompt PROMPT\n");
printf(" halt generation at PROMPT, return control in interactive mode\n"); printf(" halt generation at PROMPT, return control in interactive mode\n");

View File

@ -102,6 +102,7 @@ struct gpt_params {
bool random_prompt = false; // do not randomize prompt if none provided bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode bool interactive = false; // interactive mode
bool chatml = false; // chatml mode (used for models trained on chatml syntax)
bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it

View File

@ -146,6 +146,13 @@ int main(int argc, char ** argv) {
return 0; return 0;
} }
if (params.chatml) {
printf("\n************\n");
printf("%s: please use the 'main' tool for chatml mode\n", __func__);
printf("************\n\n");
return 0;
}
if (!params.antiprompt.empty()) { if (!params.antiprompt.empty()) {
printf("\n************\n"); printf("\n************\n");
printf("%s: please use the 'main' tool for antiprompt mode\n", __func__); printf("%s: please use the 'main' tool for antiprompt mode\n", __func__);

View File

@ -234,8 +234,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;
if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
LOG("tokenize the prompt\n"); LOG("tokenize the prompt\n");
if (params.chatml) {
params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
}
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
} else { } else {
LOG("use session tokens\n"); LOG("use session tokens\n");
@ -313,7 +316,7 @@ int main(int argc, char ** argv) {
} }
// number of tokens to keep when resetting context // number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) { if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
params.n_keep = (int)embd_inp.size(); params.n_keep = (int)embd_inp.size();
} }
@ -324,11 +327,23 @@ int main(int argc, char ** argv) {
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str()); LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str()); LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
// chatml prefix & suffix
const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true);
const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true);
LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str());
LOG("cml_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_sfx).c_str());
// in instruct mode, we inject a prefix and a suffix to each input by the user // in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) { if (params.instruct) {
params.interactive_first = true; params.interactive_first = true;
params.antiprompt.push_back("### Instruction:\n\n"); params.antiprompt.push_back("### Instruction:\n\n");
} }
// similar for chatml mode
else if (params.chatml) {
params.interactive_first = true;
params.antiprompt.push_back("<|im_start|>user\n");
}
// enable interactive mode if interactive start is specified // enable interactive mode if interactive start is specified
if (params.interactive_first) { if (params.interactive_first) {
@ -705,7 +720,7 @@ int main(int argc, char ** argv) {
is_interacting = true; is_interacting = true;
printf("\n"); printf("\n");
} else if (params.instruct) { } else if (params.instruct || params.chatml) {
is_interacting = true; is_interacting = true;
} }
} }
@ -713,7 +728,7 @@ int main(int argc, char ** argv) {
if (n_past > 0 && is_interacting) { if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n"); LOG("waiting for user input\n");
if (params.instruct) { if (params.instruct || params.chatml) {
printf("\n> "); printf("\n> ");
} }
@ -760,6 +775,12 @@ int main(int argc, char ** argv) {
n_consumed = embd_inp.size(); n_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
} }
// chatml mode: insert user chat prefix
if (params.chatml && !is_antiprompt) {
LOG("inserting chatml prefix\n");
n_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end());
}
if (params.escape) { if (params.escape) {
process_escapes(buffer); process_escapes(buffer);
} }
@ -778,6 +799,11 @@ int main(int argc, char ** argv) {
LOG("inserting instruction suffix\n"); LOG("inserting instruction suffix\n");
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
} }
// chatml mode: insert assistant chat suffix
if (params.chatml) {
LOG("inserting chatml suffix\n");
embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
}
for (size_t i = original_size; i < embd_inp.size(); ++i) { for (size_t i = original_size; i < embd_inp.size(); ++i) {
const llama_token token = embd_inp[i]; const llama_token token = embd_inp[i];
@ -803,7 +829,7 @@ int main(int argc, char ** argv) {
} }
// end of text token // end of text token
if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) { if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
LOG_TEE(" [end of text]\n"); LOG_TEE(" [end of text]\n");
break; break;
} }