From 4410aa09fb8b5175542dfeeceee6e1f96be36378 Mon Sep 17 00:00:00 2001 From: "Gilad S." <7817232+giladgd@users.noreply.github.com> Date: Thu, 12 Sep 2024 03:57:46 +0300 Subject: [PATCH] fix: return removed sampler --- include/llama.h | 2 +- src/llama-sampling.cpp | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/include/llama.h b/include/llama.h index fdebdf26a..744ef9d90 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1057,7 +1057,7 @@ extern "C" { LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed - LLAMA_API void llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); + LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); // available samplers: diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 58f76ef26..d02828c5a 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -356,15 +356,17 @@ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chai return p->samplers[i]; } -void llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { +struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { auto * p = (llama_sampler_chain *) chain->ctx; if (i < 0 || i >= (int32_t) p->samplers.size()) { - return; + return nullptr; } auto * result = p->samplers[i]; p->samplers.erase(p->samplers.begin() + i); + + return result; } int llama_sampler_chain_n(const struct llama_sampler * chain) {