diff --git a/examples/batched-bench/README.md b/examples/batched-bench/README.md index 28dbbdca9..34b343f66 100644 --- a/examples/batched-bench/README.md +++ b/examples/batched-bench/README.md @@ -10,7 +10,7 @@ There are 2 modes of operation: - `prompt is shared` - there is a common prompt of size `PP` used by all batches (i.e. `N_KV = PP + B*TG`) ```bash -./batched-bench MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] +./batched-bench MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] [MMQ] # LLaMA 7B, F16, N_KV_MAX = 16384 (8GB), prompt not shared ./batched-bench ./models/llama-7b/ggml-model-f16.gguf 16384 0 99 @@ -19,7 +19,7 @@ There are 2 modes of operation: ./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 16384 1 99 # custom set of batches -./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 2048 0 999 128,256,512 128,256 1,2,4,8,16,32 +./batched-bench ./models/llama-7b/ggml-model-q8_0.gguf 2048 0 999 0 128,256,512 128,256 1,2,4,8,16,32 ``` ## Sample results diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 675651316..3e1e0716d 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -32,15 +32,16 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] \n" , argv[0]); + printf("usage: %s MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] [MMQ] \n" , argv[0]); printf(" , and PL are comma-separated lists of numbers without spaces\n\n"); - printf(" example: %s ggml-model-f16.gguf 2048 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]); + printf(" example: %s ggml-model-f16.gguf 2048 0 999 0 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]); return 1 ; } int n_kv_max = 2048; int is_pp_shared = 0; int n_gpu_layers = 0; + int mmq = 0; std::vector n_pp = { 128, 256, 512, 1024, 2048, 3584, 7680, }; std::vector n_tg = { 128, 256, }; @@ -64,15 +65,19 @@ int main(int argc, char ** argv) { } if (argc >= 6) { - n_pp = parse_list(argv[5]); + mmq = std::atoi(argv[5]); } if (argc >= 7) { - n_tg = parse_list(argv[6]); + n_pp = parse_list(argv[6]); } if (argc >= 8) { - n_pl = parse_list(argv[7]); + n_tg = parse_list(argv[7]); + } + + if (argc >= 9) { + n_pl = parse_list(argv[8]); } // init LLM @@ -94,9 +99,11 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); - ctx_params.seed = 1234; - ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = 512; + ctx_params.seed = 1234; + ctx_params.n_ctx = n_kv_max; + ctx_params.n_batch = 512; + ctx_params.mul_mat_q = mmq; + ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;