separate source layer for steering vector.

This commit is contained in:
Henri Vasserman 2023-05-16 18:43:40 +03:00
parent 8388aaa604
commit c90059fba6
No known key found for this signature in database
GPG Key ID: 2995FC0F58B1A986
3 changed files with 11 additions and 3 deletions

View File

@ -362,6 +362,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.steering_mul = std::stof(argv[i]); params.steering_mul = std::stof(argv[i]);
} else if (arg == "--steering-source") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.steering_source = std::stoi(argv[i]);
} else if (arg == "--steering-layer") { } else if (arg == "--steering-layer") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -456,8 +462,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " number of layers to store in VRAM\n"); fprintf(stderr, " number of layers to store in VRAM\n");
fprintf(stderr, " --steering-add add positive steering prompt\n"); fprintf(stderr, " --steering-add add positive steering prompt\n");
fprintf(stderr, " --steering-sub add negativ steering prompt\n"); fprintf(stderr, " --steering-sub add negativ steering prompt\n");
fprintf(stderr, " --steering-mul set steering strength (negative is reverse, default %.1f)\n", params.steering_mul); fprintf(stderr, " --steering-mul steering strength (negative is reverse, default %.1f)\n", params.steering_mul);
fprintf(stderr, " --steering-layer set layer for steering (default %d)\n", params.steering_layer); fprintf(stderr, " --steering-source layer for steering source (default %d)\n", params.steering_source);
fprintf(stderr, " --steering-layer layer for steering insertion (default %d)\n", params.steering_layer);
fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --verbose-prompt print prompt before generation\n"); fprintf(stderr, " --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");

View File

@ -77,6 +77,7 @@ struct gpt_params {
std::string steering_sub; std::string steering_sub;
float steering_mul = 1.0f; float steering_mul = 1.0f;
int steering_layer = 15; int steering_layer = 15;
int steering_source = 2;
}; };
bool gpt_params_parse(int argc, char ** argv, gpt_params & params); bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

View File

@ -191,7 +191,7 @@ int main(int argc, char ** argv) {
// } // }
//} //}
//const int N = embd_inp.size(); //const int N = embd_inp.size();
llama_set_steering_write(ctx, params.steering_layer, +1.0f); llama_set_steering_write(ctx, params.steering_source, +1.0f);
llama_eval(ctx, add_tokens.data(), std::min((int)add_tokens.size(), n_ctx), 0, params.n_threads); llama_eval(ctx, add_tokens.data(), std::min((int)add_tokens.size(), n_ctx), 0, params.n_threads);
llama_set_steering_write(ctx, params.steering_layer, -1.0f); llama_set_steering_write(ctx, params.steering_layer, -1.0f);