Opt class for positional argument handling (#10508)

Added support for positional arguments `model` and `prompt`. Added
functionality to download via strings like:

  llama-run llama3
  llama-run ollama://granite-code
  llama-run ollama://granite-code:8b
  llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf
  llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf
  llama-run https://example.com/some-file1.gguf
  llama-run some-file2.gguf
  llama-run file://some-file3.gguf

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
This commit is contained in:
Eric Curtin 2024-12-13 18:34:25 +00:00 committed by GitHub
parent 11e07fd63b
commit c27ac678dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 542 additions and 163 deletions

View File

@ -433,6 +433,20 @@ To learn more about model quantization, [read this documentation](examples/quant
</details> </details>
## [`llama-run`](examples/run)
#### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^3].
- <details>
<summary>Run a model with a specific prompt (by default it's pulled from Ollama registry)</summary>
```bash
llama-run granite-code
```
</details>
[^3]: [https://github.com/containers/ramalama](RamaLama)
## [`llama-simple`](examples/simple) ## [`llama-simple`](examples/simple)

View File

@ -81,7 +81,7 @@ set(LLAMA_COMMON_EXTRA_LIBS build_info)
# Use curl to download model url # Use curl to download model url
if (LLAMA_CURL) if (LLAMA_CURL)
find_package(CURL REQUIRED) find_package(CURL REQUIRED)
add_definitions(-DLLAMA_USE_CURL) target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
include_directories(${CURL_INCLUDE_DIRS}) include_directories(${CURL_INCLUDE_DIRS})
find_library(CURL_LIBRARY curl REQUIRED) find_library(CURL_LIBRARY curl REQUIRED)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY}) set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})

View File

@ -1076,12 +1076,6 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
#define CURL_MAX_RETRY 3 #define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2 #define CURL_RETRY_DELAY_SECONDS 2
static bool starts_with(const std::string & str, const std::string & prefix) {
// While we wait for C++20's std::string::starts_with...
return str.rfind(prefix, 0) == 0;
}
static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) { static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) {
int remaining_attempts = max_attempts; int remaining_attempts = max_attempts;

View File

@ -37,9 +37,9 @@ using llama_tokens = std::vector<llama_token>;
// build info // build info
extern int LLAMA_BUILD_NUMBER; extern int LLAMA_BUILD_NUMBER;
extern char const * LLAMA_COMMIT; extern const char * LLAMA_COMMIT;
extern char const * LLAMA_COMPILER; extern const char * LLAMA_COMPILER;
extern char const * LLAMA_BUILD_TARGET; extern const char * LLAMA_BUILD_TARGET;
struct common_control_vector_load_info; struct common_control_vector_load_info;
@ -437,6 +437,11 @@ std::vector<std::string> string_split<std::string>(const std::string & input, ch
return parts; return parts;
} }
static bool string_starts_with(const std::string & str,
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
return str.rfind(prefix, 0) == 0;
}
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides); bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input); void string_process_escapes(std::string & input);

View File

@ -1,5 +1,5 @@
set(TARGET llama-run) set(TARGET llama-run)
add_executable(${TARGET} run.cpp) add_executable(${TARGET} run.cpp)
install(TARGETS ${TARGET} RUNTIME) install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17) target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@ -3,5 +3,45 @@
The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models. The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models.
```bash ```bash
./llama-run Meta-Llama-3.1-8B-Instruct.gguf llama-run granite-code
...
```bash
llama-run -h
Description:
Runs a llm
Usage:
llama-run [options] model [prompt]
Options:
-c, --context-size <value>
Context size (default: 2048)
-n, --ngl <value>
Number of GPU layers (default: 0)
-h, --help
Show help message
Commands:
model
Model is a string with an optional prefix of
huggingface:// (hf://), ollama://, https:// or file://.
If no protocol is specified and a file exists in the specified
path, file:// is assumed, otherwise if a file does not exist in
the specified path, ollama:// is assumed. Models that are being
pulled are downloaded with .partial extension while being
downloaded and then renamed as the file without the .partial
extension when complete.
Examples:
llama-run llama3
llama-run ollama://granite-code
llama-run ollama://smollm:135m
llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf
llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf
llama-run https://example.com/some-file1.gguf
llama-run some-file2.gguf
llama-run file://some-file3.gguf
llama-run --ngl 99 some-file4.gguf
llama-run --ngl 99 some-file5.gguf Hello World
... ...

View File

@ -4,110 +4,330 @@
# include <unistd.h> # include <unistd.h>
#endif #endif
#include <climits> #if defined(LLAMA_USE_CURL)
# include <curl/curl.h>
#endif
#include <cstdarg>
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <filesystem>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "common.h"
#include "json.hpp"
#include "llama-cpp.h" #include "llama-cpp.h"
typedef std::unique_ptr<char[]> char_array_ptr; #define printe(...) \
do { \
fprintf(stderr, __VA_ARGS__); \
} while (0)
struct Argument { class Opt {
std::string flag;
std::string help_text;
};
struct Options {
std::string model_path, prompt_non_interactive;
int ngl = 99;
int n_ctx = 2048;
};
class ArgumentParser {
public: public:
ArgumentParser(const char * program_name) : program_name(program_name) {} int init(int argc, const char ** argv) {
construct_help_str_();
void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") { // Parse arguments
string_args[flag] = &var; if (parse(argc, argv)) {
arguments.push_back({flag, help_text}); printe("Error: Failed to parse arguments.\n");
help();
return 1;
} }
void add_argument(const std::string & flag, int & var, const std::string & help_text = "") { // If help is requested, show help and exit
int_args[flag] = &var; if (help_) {
arguments.push_back({flag, help_text}); help();
return 2;
}
return 0; // Success
}
std::string model_;
std::string user_;
int context_size_ = 2048, ngl_ = -1;
private:
std::string help_str_;
bool help_ = false;
void construct_help_str_() {
help_str_ =
"Description:\n"
" Runs a llm\n"
"\n"
"Usage:\n"
" llama-run [options] model [prompt]\n"
"\n"
"Options:\n"
" -c, --context-size <value>\n"
" Context size (default: " +
std::to_string(context_size_);
help_str_ +=
")\n"
" -n, --ngl <value>\n"
" Number of GPU layers (default: " +
std::to_string(ngl_);
help_str_ +=
")\n"
" -h, --help\n"
" Show help message\n"
"\n"
"Commands:\n"
" model\n"
" Model is a string with an optional prefix of \n"
" huggingface:// (hf://), ollama://, https:// or file://.\n"
" If no protocol is specified and a file exists in the specified\n"
" path, file:// is assumed, otherwise if a file does not exist in\n"
" the specified path, ollama:// is assumed. Models that are being\n"
" pulled are downloaded with .partial extension while being\n"
" downloaded and then renamed as the file without the .partial\n"
" extension when complete.\n"
"\n"
"Examples:\n"
" llama-run llama3\n"
" llama-run ollama://granite-code\n"
" llama-run ollama://smollm:135m\n"
" llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n"
" llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n"
" llama-run https://example.com/some-file1.gguf\n"
" llama-run some-file2.gguf\n"
" llama-run file://some-file3.gguf\n"
" llama-run --ngl 99 some-file4.gguf\n"
" llama-run --ngl 99 some-file5.gguf Hello World\n";
} }
int parse(int argc, const char ** argv) { int parse(int argc, const char ** argv) {
int positional_args_i = 0;
for (int i = 1; i < argc; ++i) { for (int i = 1; i < argc; ++i) {
std::string arg = argv[i]; if (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0) {
if (string_args.count(arg)) { if (i + 1 >= argc) {
if (i + 1 < argc) {
*string_args[arg] = argv[++i];
} else {
fprintf(stderr, "error: missing value for %s\n", arg.c_str());
print_usage();
return 1; return 1;
} }
} else if (int_args.count(arg)) {
if (i + 1 < argc) { context_size_ = std::atoi(argv[++i]);
if (parse_int_arg(argv[++i], *int_args[arg]) != 0) { } else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0) {
fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]); if (i + 1 >= argc) {
print_usage();
return 1; return 1;
} }
ngl_ = std::atoi(argv[++i]);
} else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
help_ = true;
return 0;
} else if (!positional_args_i) {
++positional_args_i;
model_ = argv[i];
} else if (positional_args_i == 1) {
++positional_args_i;
user_ = argv[i];
} else { } else {
fprintf(stderr, "error: missing value for %s\n", arg.c_str()); user_ += " " + std::string(argv[i]);
print_usage();
return 1;
}
} else {
fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str());
print_usage();
return 1;
} }
} }
if (string_args["-m"]->empty()) { return model_.empty(); // model_ is the only required value
fprintf(stderr, "error: -m is required\n"); }
print_usage();
void help() const { printf("%s", help_str_.c_str()); }
};
struct progress_data {
size_t file_size = 0;
std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now();
bool printed = false;
};
struct FileDeleter {
void operator()(FILE * file) const {
if (file) {
fclose(file);
}
}
};
typedef std::unique_ptr<FILE, FileDeleter> FILE_ptr;
#ifdef LLAMA_USE_CURL
class CurlWrapper {
public:
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
const bool progress, std::string * response_str = nullptr) {
std::string output_file_partial;
curl = curl_easy_init();
if (!curl) {
return 1; return 1;
} }
progress_data data;
FILE_ptr out;
if (!output_file.empty()) {
output_file_partial = output_file + ".partial";
out.reset(fopen(output_file_partial.c_str(), "ab"));
}
set_write_options(response_str, out);
data.file_size = set_resume_point(output_file_partial);
set_progress_options(progress, data);
set_headers(headers);
perform(url);
if (!output_file.empty()) {
std::filesystem::rename(output_file_partial, output_file);
}
return 0; return 0;
} }
~CurlWrapper() {
if (chunk) {
curl_slist_free_all(chunk);
}
if (curl) {
curl_easy_cleanup(curl);
}
}
private: private:
const char * program_name; CURL * curl = nullptr;
std::unordered_map<std::string, std::string *> string_args; struct curl_slist * chunk = nullptr;
std::unordered_map<std::string, int *> int_args;
std::vector<Argument> arguments;
int parse_int_arg(const char * arg, int & value) { void set_write_options(std::string * response_str, const FILE_ptr & out) {
char * end; if (response_str) {
const long val = std::strtol(arg, &end, 10); curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data);
if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) { curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str);
value = static_cast<int>(val); } else {
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.get());
}
}
size_t set_resume_point(const std::string & output_file) {
size_t file_size = 0;
if (std::filesystem::exists(output_file)) {
file_size = std::filesystem::file_size(output_file);
curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast<curl_off_t>(file_size));
}
return file_size;
}
void set_progress_options(bool progress, progress_data & data) {
if (progress) {
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data);
curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, progress_callback);
}
}
void set_headers(const std::vector<std::string> & headers) {
if (!headers.empty()) {
if (chunk) {
curl_slist_free_all(chunk);
chunk = 0;
}
for (const auto & header : headers) {
chunk = curl_slist_append(chunk, header.c_str());
}
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk);
}
}
void perform(const std::string & url) {
CURLcode res;
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
res = curl_easy_perform(curl);
if (res != CURLE_OK) {
printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
}
}
static std::string human_readable_time(double seconds) {
int hrs = static_cast<int>(seconds) / 3600;
int mins = (static_cast<int>(seconds) % 3600) / 60;
int secs = static_cast<int>(seconds) % 60;
std::ostringstream out;
if (hrs > 0) {
out << hrs << "h " << std::setw(2) << std::setfill('0') << mins << "m " << std::setw(2) << std::setfill('0')
<< secs << "s";
} else if (mins > 0) {
out << mins << "m " << std::setw(2) << std::setfill('0') << secs << "s";
} else {
out << secs << "s";
}
return out.str();
}
static std::string human_readable_size(curl_off_t size) {
static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" };
char length = sizeof(suffix) / sizeof(suffix[0]);
int i = 0;
double dbl_size = size;
if (size > 1024) {
for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) {
dbl_size = size / 1024.0;
}
}
std::ostringstream out;
out << std::fixed << std::setprecision(2) << dbl_size << " " << suffix[i];
return out.str();
}
static int progress_callback(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t,
curl_off_t) {
progress_data * data = static_cast<progress_data *>(ptr);
if (total_to_download <= 0) {
return 0; return 0;
} }
return 1;
total_to_download += data->file_size;
const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size;
const curl_off_t percentage = (now_downloaded_plus_file_size * 100) / total_to_download;
const curl_off_t pos = (percentage / 5);
std::string progress_bar;
for (int i = 0; i < 20; ++i) {
progress_bar.append((i < pos) ? "" : " ");
} }
void print_usage() const { // Calculate download speed and estimated time to completion
printf("\nUsage:\n"); const auto now = std::chrono::steady_clock::now();
printf(" %s [OPTIONS]\n\n", program_name); const std::chrono::duration<double> elapsed_seconds = now - data->start_time;
printf("Options:\n"); const double speed = now_downloaded / elapsed_seconds.count();
for (const auto & arg : arguments) { const double estimated_time = (total_to_download - now_downloaded) / speed;
printf(" %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str()); printe("\r%ld%% |%s| %s/%s %.2f MB/s %s ", percentage, progress_bar.c_str(),
human_readable_size(now_downloaded).c_str(), human_readable_size(total_to_download).c_str(),
speed / (1024 * 1024), human_readable_time(estimated_time).c_str());
fflush(stderr);
data->printed = true;
return 0;
} }
printf("\n"); // Function to write data to a file
static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
FILE * out = static_cast<FILE *>(stream);
return fwrite(ptr, size, nmemb, out);
}
// Function to capture data into a string
static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) {
std::string * str = static_cast<std::string *>(stream);
str->append(static_cast<char *>(ptr), size * nmemb);
return size * nmemb;
} }
}; };
#endif
class LlamaData { class LlamaData {
public: public:
@ -115,14 +335,16 @@ class LlamaData {
llama_sampler_ptr sampler; llama_sampler_ptr sampler;
llama_context_ptr context; llama_context_ptr context;
std::vector<llama_chat_message> messages; std::vector<llama_chat_message> messages;
std::vector<std::string> msg_strs;
std::vector<char> fmtted;
int init(const Options & opt) { int init(Opt & opt) {
model = initialize_model(opt.model_path, opt.ngl); model = initialize_model(opt);
if (!model) { if (!model) {
return 1; return 1;
} }
context = initialize_context(model, opt.n_ctx); context = initialize_context(model, opt.context_size_);
if (!context) { if (!context) {
return 1; return 1;
} }
@ -132,14 +354,122 @@ class LlamaData {
} }
private: private:
// Initializes the model and returns a unique pointer to it #ifdef LLAMA_USE_CURL
llama_model_ptr initialize_model(const std::string & model_path, const int ngl) { int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
llama_model_params model_params = llama_model_default_params(); const bool progress, std::string * response_str = nullptr) {
model_params.n_gpu_layers = ngl; CurlWrapper curl;
if (curl.init(url, headers, output_file, progress, response_str)) {
return 1;
}
llama_model_ptr model(llama_load_model_from_file(model_path.c_str(), model_params)); return 0;
}
#else
int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
std::string * = nullptr) {
printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return 1;
}
#endif
int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
// Find the second occurrence of '/' after protocol string
size_t pos = model.find('/');
pos = model.find('/', pos + 1);
if (pos == std::string::npos) {
return 1;
}
const std::string hfr = model.substr(0, pos);
const std::string hff = model.substr(pos + 1);
const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
return download(url, headers, bn, true);
}
int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
if (model.find('/') == std::string::npos) {
model = "library/" + model;
}
std::string model_tag = "latest";
size_t colon_pos = model.find(':');
if (colon_pos != std::string::npos) {
model_tag = model.substr(colon_pos + 1);
model = model.substr(0, colon_pos);
}
std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
std::string manifest_str;
const int ret = download(manifest_url, headers, "", false, &manifest_str);
if (ret) {
return ret;
}
nlohmann::json manifest = nlohmann::json::parse(manifest_str);
std::string layer;
for (const auto & l : manifest["layers"]) {
if (l["mediaType"] == "application/vnd.ollama.image.model") {
layer = l["digest"];
break;
}
}
std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
return download(blob_url, headers, bn, true);
}
std::string basename(const std::string & path) {
const size_t pos = path.find_last_of("/\\");
if (pos == std::string::npos) {
return path;
}
return path.substr(pos + 1);
}
int remove_proto(std::string & model_) {
const std::string::size_type pos = model_.find("://");
if (pos == std::string::npos) {
return 1;
}
model_ = model_.substr(pos + 3); // Skip past "://"
return 0;
}
int resolve_model(std::string & model_) {
const std::string bn = basename(model_);
const std::vector<std::string> headers = { "--header",
"Accept: application/vnd.docker.distribution.manifest.v2+json" };
int ret = 0;
if (string_starts_with(model_, "file://") || std::filesystem::exists(bn)) {
remove_proto(model_);
} else if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
remove_proto(model_);
ret = huggingface_dl(model_, headers, bn);
} else if (string_starts_with(model_, "ollama://")) {
remove_proto(model_);
ret = ollama_dl(model_, headers, bn);
} else if (string_starts_with(model_, "https://")) {
download(model_, headers, bn, true);
} else {
ret = ollama_dl(model_, headers, bn);
}
model_ = bn;
return ret;
}
// Initializes the model and returns a unique pointer to it
llama_model_ptr initialize_model(Opt & opt) {
ggml_backend_load_all();
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
resolve_model(opt.model_);
llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
if (!model) { if (!model) {
fprintf(stderr, "%s: error: unable to load model\n", __func__); printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
} }
return model; return model;
@ -150,10 +480,9 @@ class LlamaData {
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx; ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = n_ctx; ctx_params.n_batch = n_ctx;
llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params)); llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
if (!context) { if (!context) {
fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__); printe("%s: error: failed to create the llama_context\n", __func__);
} }
return context; return context;
@ -170,23 +499,22 @@ class LlamaData {
} }
}; };
// Add a message to `messages` and store its content in `owned_content` // Add a message to `messages` and store its content in `msg_strs`
static void add_message(const char * role, const std::string & text, LlamaData & llama_data, static void add_message(const char * role, const std::string & text, LlamaData & llama_data) {
std::vector<char_array_ptr> & owned_content) { llama_data.msg_strs.push_back(std::move(text));
char_array_ptr content(new char[text.size() + 1]); llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() });
std::strcpy(content.get(), text.c_str());
llama_data.messages.push_back({role, content.get()});
owned_content.push_back(std::move(content));
} }
// Function to apply the chat template and resize `formatted` if needed // Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const LlamaData & llama_data, std::vector<char> & formatted, const bool append) { static int apply_chat_template(LlamaData & llama_data, const bool append) {
int result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), int result = llama_chat_apply_template(
llama_data.messages.size(), append, formatted.data(), formatted.size()); llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append,
if (result > static_cast<int>(formatted.size())) { append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
formatted.resize(result); if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(),
llama_data.messages.size(), append, formatted.data(), formatted.size()); llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
} }
return result; return result;
@ -199,7 +527,8 @@ static int tokenize_prompt(const llama_model_ptr & model, const std::string & pr
prompt_tokens.resize(n_prompt_tokens); prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
true) < 0) { true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n"); printe("failed to tokenize the prompt\n");
return -1;
} }
return n_prompt_tokens; return n_prompt_tokens;
@ -211,7 +540,7 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch &
const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get()); const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get());
if (n_ctx_used + batch.n_tokens > n_ctx) { if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("\033[0m\n"); printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n"); printe("context size exceeded\n");
return 1; return 1;
} }
@ -223,7 +552,8 @@ static int convert_token_to_string(const llama_model_ptr & model, const llama_to
char buf[256]; char buf[256];
int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true);
if (n < 0) { if (n < 0) {
GGML_ABORT("failed to convert token to piece\n"); printe("failed to convert token to piece\n");
return 1;
} }
piece = std::string(buf, n); piece = std::string(buf, n);
@ -238,19 +568,19 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st
// helper function to evaluate a prompt and generate a response // helper function to evaluate a prompt and generate a response
static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
std::vector<llama_token> prompt_tokens; std::vector<llama_token> tokens;
const int n_prompt_tokens = tokenize_prompt(llama_data.model, prompt, prompt_tokens); if (tokenize_prompt(llama_data.model, prompt, tokens) < 0) {
if (n_prompt_tokens < 0) {
return 1; return 1;
} }
// prepare a batch for the prompt // prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
llama_token new_token_id; llama_token new_token_id;
while (true) { while (true) {
check_context_size(llama_data.context, batch); check_context_size(llama_data.context, batch);
if (llama_decode(llama_data.context.get(), batch)) { if (llama_decode(llama_data.context.get(), batch)) {
GGML_ABORT("failed to decode\n"); printe("failed to decode\n");
return 1;
} }
// sample the next token, check is it an end of generation? // sample the next token, check is it an end of generation?
@ -273,22 +603,9 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
return 0; return 0;
} }
static int parse_arguments(const int argc, const char ** argv, Options & opt) {
ArgumentParser parser(argv[0]);
parser.add_argument("-m", opt.model_path, "model");
parser.add_argument("-p", opt.prompt_non_interactive, "prompt");
parser.add_argument("-c", opt.n_ctx, "context_size");
parser.add_argument("-ngl", opt.ngl, "n_gpu_layers");
if (parser.parse(argc, argv)) {
return 1;
}
return 0;
}
static int read_user_input(std::string & user) { static int read_user_input(std::string & user) {
std::getline(std::cin, user); std::getline(std::cin, user);
return user.empty(); // Indicate an error or empty input return user.empty(); // Should have data in happy path
} }
// Function to generate a response based on the prompt // Function to generate a response based on the prompt
@ -296,7 +613,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
// Set response color // Set response color
printf("\033[33m"); printf("\033[33m");
if (generate(llama_data, prompt, response)) { if (generate(llama_data, prompt, response)) {
fprintf(stderr, "failed to generate response\n"); printe("failed to generate response\n");
return 1; return 1;
} }
@ -306,11 +623,10 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
} }
// Helper function to apply the chat template and handle errors // Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const LlamaData & llama_data, std::vector<char> & formatted, static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
const bool is_user_input, int & output_length) { const int new_len = apply_chat_template(llama_data, append);
const int new_len = apply_chat_template(llama_data, formatted, is_user_input);
if (new_len < 0) { if (new_len < 0) {
fprintf(stderr, "failed to apply the chat template\n"); printe("failed to apply the chat template\n");
return -1; return -1;
} }
@ -319,49 +635,56 @@ static int apply_chat_template_with_error_handling(const LlamaData & llama_data,
} }
// Helper function to handle user input // Helper function to handle user input
static bool handle_user_input(std::string & user_input, const std::string & prompt_non_interactive) { static int handle_user_input(std::string & user_input, const std::string & user_) {
if (!prompt_non_interactive.empty()) { if (!user_.empty()) {
user_input = prompt_non_interactive; user_input = user_;
return true; // No need for interactive input return 0; // No need for interactive input
} }
printf("\033[32m> \033[0m"); printf(
return !read_user_input(user_input); // Returns false if input ends the loop "\r "
"\r\033[32m> \033[0m");
return read_user_input(user_input); // Returns true if input ends the loop
} }
// Function to tokenize the prompt // Function to tokenize the prompt
static int chat_loop(LlamaData & llama_data, std::string & prompt_non_interactive) { static int chat_loop(LlamaData & llama_data, const std::string & user_) {
std::vector<char_array_ptr> owned_content;
std::vector<char> fmtted(llama_n_ctx(llama_data.context.get()));
int prev_len = 0; int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
while (true) { while (true) {
// Get user input // Get user input
std::string user_input; std::string user_input;
if (!handle_user_input(user_input, prompt_non_interactive)) { while (handle_user_input(user_input, user_)) {
break;
} }
add_message("user", prompt_non_interactive.empty() ? user_input : prompt_non_interactive, llama_data, add_message("user", user_.empty() ? user_input : user_, llama_data);
owned_content);
int new_len; int new_len;
if (apply_chat_template_with_error_handling(llama_data, fmtted, true, new_len) < 0) { if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
return 1; return 1;
} }
std::string prompt(fmtted.begin() + prev_len, fmtted.begin() + new_len); std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len);
std::string response; std::string response;
if (generate_response(llama_data, prompt, response)) { if (generate_response(llama_data, prompt, response)) {
return 1; return 1;
} }
if (!user_.empty()) {
break;
} }
add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
return 1;
}
}
return 0; return 0;
} }
static void log_callback(const enum ggml_log_level level, const char * text, void *) { static void log_callback(const enum ggml_log_level level, const char * text, void *) {
if (level == GGML_LOG_LEVEL_ERROR) { if (level == GGML_LOG_LEVEL_ERROR) {
fprintf(stderr, "%s", text); printe("%s", text);
} }
} }
@ -382,17 +705,20 @@ static std::string read_pipe_data() {
} }
int main(int argc, const char ** argv) { int main(int argc, const char ** argv) {
Options opt; Opt opt;
if (parse_arguments(argc, argv, opt)) { const int ret = opt.init(argc, argv);
if (ret == 2) {
return 0;
} else if (ret) {
return 1; return 1;
} }
if (!is_stdin_a_terminal()) { if (!is_stdin_a_terminal()) {
if (!opt.prompt_non_interactive.empty()) { if (!opt.user_.empty()) {
opt.prompt_non_interactive += "\n\n"; opt.user_ += "\n\n";
} }
opt.prompt_non_interactive += read_pipe_data(); opt.user_ += read_pipe_data();
} }
llama_log_set(log_callback, nullptr); llama_log_set(log_callback, nullptr);
@ -401,7 +727,7 @@ int main(int argc, const char ** argv) {
return 1; return 1;
} }
if (chat_loop(llama_data, opt.prompt_non_interactive)) { if (chat_loop(llama_data, opt.user_)) {
return 1; return 1;
} }