batched : add NGL arg

This commit is contained in:
Georgi Gerganov 2023-10-23 20:36:12 +03:00
parent 8fb1be642e
commit 6a30bf3e51
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -11,7 +11,7 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
if (argc == 1 || argv[1][0] == '-') { if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]); printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]);
return 1 ; return 1 ;
} }
@ -21,6 +21,9 @@ int main(int argc, char ** argv) {
// total length of the sequences including the prompt // total length of the sequences including the prompt
int n_len = 32; int n_len = 32;
// number of layers to offload to the GPU
int n_gpu_layers = 0;
if (argc >= 2) { if (argc >= 2) {
params.model = argv[1]; params.model = argv[1];
} }
@ -37,6 +40,10 @@ int main(int argc, char ** argv) {
n_len = std::atoi(argv[4]); n_len = std::atoi(argv[4]);
} }
if (argc >= 6) {
n_gpu_layers = std::atoi(argv[5]);
}
if (params.prompt.empty()) { if (params.prompt.empty()) {
params.prompt = "Hello my name is"; params.prompt = "Hello my name is";
} }
@ -49,7 +56,7 @@ int main(int argc, char ** argv) {
llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
// model_params.n_gpu_layers = 99; // offload all layers to the GPU model_params.n_gpu_layers = n_gpu_layers;
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);