From 7f59af52a90b8011005a8a4aefa109f612cb490d Mon Sep 17 00:00:00 2001 From: Laura Date: Thu, 18 May 2023 23:47:10 +0200 Subject: [PATCH] Steer with inpSA instead of with inpL Signed-off-by: Henri Vasserman --- examples/main/main.cpp | 25 ++++++++++++------------- llama.cpp | 5 +++-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 18280bde1..974e1277b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -176,28 +176,27 @@ int main(int argc, char ** argv) { if (!params.steering_add.empty() || !params.steering_sub.empty()) { - params.steering_add.insert(0, 1, ' '); - params.steering_sub.insert(0, 1, ' '); - auto add_tokens = ::llama_tokenize(ctx, params.steering_add, true); auto sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true); - //if (add_tokens.size() != sub_tokens.size()) { - // while (add_tokens.size() < sub_tokens.size()) { - // add_tokens.push_back(llama_token_nl()); - // } - // while (sub_tokens.size() < add_tokens.size()) { - // sub_tokens.push_back(llama_token_nl()); - // } - //} - //const int N = embd_inp.size(); + + if (add_tokens.size() != sub_tokens.size()) { + while (add_tokens.size() < sub_tokens.size()) { + add_tokens.push_back(llama_token_nl()); + } + while (sub_tokens.size() < add_tokens.size()) { + sub_tokens.push_back(llama_token_nl()); + } + } + llama_set_steering_write(ctx, params.steering_source, +1.0f); llama_eval(ctx, add_tokens.data(), std::min((int)add_tokens.size(), n_ctx), 0, params.n_threads); - llama_set_steering_write(ctx, params.steering_layer, -1.0f); + llama_set_steering_write(ctx, params.steering_source, -1.0f); llama_eval(ctx, sub_tokens.data(), std::min((int)sub_tokens.size(), n_ctx), 0, params.n_threads); llama_set_steering_read(ctx, params.steering_layer, params.steering_mul); + std::cout << "Steering: `" << params.steering_add << "` - `" << params.steering_sub << "` * " << params.steering_mul << "\n"; } // debug message about similarity of saved session, if applicable diff --git a/llama.cpp b/llama.cpp index 5e85e55d5..a02ef4cb8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -1187,8 +1188,8 @@ static bool llama_eval_internal( ggml_add(ctx0, ggml_scale(ctx0, inpL, scal), steer), steer)); break; } - - inpL = ggml_add(ctx0, ggml_scale(ctx0, steer, scal), inpL); + // std::cout << "\nAdding steering vector to inpL " << il << "\n"; + inpSA = ggml_add(ctx0, ggml_scale(ctx0, steer, scal), inpSA); } // norm