This commit is contained in:
Henri Vasserman 2023-05-19 01:47:26 +03:00
parent 7f59af52a9
commit 7df9ab9687
No known key found for this signature in database
GPG Key ID: 2995FC0F58B1A986

View File

@ -176,17 +176,23 @@ int main(int argc, char ** argv) {
if (!params.steering_add.empty() || !params.steering_sub.empty()) if (!params.steering_add.empty() || !params.steering_sub.empty())
{ {
fprintf(stderr, "%s: steering: ('%s' - '%s') * %f\n",
__func__, params.steering_add.c_str(), params.steering_sub.c_str(), params.steering_mul);
params.steering_add.insert(0, 1, ' ');
params.steering_sub.insert(0, 1, ' ');
auto add_tokens = ::llama_tokenize(ctx, params.steering_add, true); auto add_tokens = ::llama_tokenize(ctx, params.steering_add, true);
auto sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true); auto sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true);
if (add_tokens.size() != sub_tokens.size()) { if (add_tokens.size() != sub_tokens.size()) {
while (add_tokens.size() < sub_tokens.size()) { while (add_tokens.size() < sub_tokens.size()) {
add_tokens.push_back(llama_token_nl()); add_tokens.push_back(llama_token_nl());
} }
while (sub_tokens.size() < add_tokens.size()) { while (sub_tokens.size() < add_tokens.size()) {
sub_tokens.push_back(llama_token_nl()); sub_tokens.push_back(llama_token_nl());
} }
} }
llama_set_steering_write(ctx, params.steering_source, +1.0f); llama_set_steering_write(ctx, params.steering_source, +1.0f);
@ -196,7 +202,6 @@ int main(int argc, char ** argv) {
llama_eval(ctx, sub_tokens.data(), std::min((int)sub_tokens.size(), n_ctx), 0, params.n_threads); llama_eval(ctx, sub_tokens.data(), std::min((int)sub_tokens.size(), n_ctx), 0, params.n_threads);
llama_set_steering_read(ctx, params.steering_layer, params.steering_mul); llama_set_steering_read(ctx, params.steering_layer, params.steering_mul);
std::cout << "Steering: `" << params.steering_add << "` - `" << params.steering_sub << "` * " << params.steering_mul << "\n";
} }
// debug message about similarity of saved session, if applicable // debug message about similarity of saved session, if applicable