diff --git a/examples/common.cpp b/examples/common.cpp index aaf6e27a9..da09853bd 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -362,6 +362,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.steering_mul = std::stof(argv[i]); + } else if (arg == "--steering-source") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.steering_source = std::stoi(argv[i]); } else if (arg == "--steering-layer") { if (++i >= argc) { invalid_param = true; @@ -456,8 +462,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " number of layers to store in VRAM\n"); fprintf(stderr, " --steering-add add positive steering prompt\n"); fprintf(stderr, " --steering-sub add negativ steering prompt\n"); - fprintf(stderr, " --steering-mul set steering strength (negative is reverse, default %.1f)\n", params.steering_mul); - fprintf(stderr, " --steering-layer set layer for steering (default %d)\n", params.steering_layer); + fprintf(stderr, " --steering-mul steering strength (negative is reverse, default %.1f)\n", params.steering_mul); + fprintf(stderr, " --steering-source layer for steering source (default %d)\n", params.steering_source); + fprintf(stderr, " --steering-layer layer for steering insertion (default %d)\n", params.steering_layer); fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); diff --git a/examples/common.h b/examples/common.h index e56ad648e..04883dcf3 100644 --- a/examples/common.h +++ b/examples/common.h @@ -77,6 +77,7 @@ struct gpt_params { std::string steering_sub; float steering_mul = 1.0f; int steering_layer = 15; + int steering_source = 2; }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ffa779e05..18280bde1 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -191,7 +191,7 @@ int main(int argc, char ** argv) { // } //} //const int N = embd_inp.size(); - llama_set_steering_write(ctx, params.steering_layer, +1.0f); + llama_set_steering_write(ctx, params.steering_source, +1.0f); llama_eval(ctx, add_tokens.data(), std::min((int)add_tokens.size(), n_ctx), 0, params.n_threads); llama_set_steering_write(ctx, params.steering_layer, -1.0f);