Merge branch 'ggerganov:master' into npu01

This commit is contained in:
Xinpeng Dou 2024-09-11 14:36:05 +08:00 committed by GitHub
commit fb24e846a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 120 additions and 107 deletions

View File

@ -173,7 +173,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
std::string arg; std::string arg;
const std::string arg_prefix = "--"; const std::string arg_prefix = "--";
gpt_params & params = ctx_arg.params; gpt_params & params = ctx_arg.params;
gpt_sampler_params & sparams = params.sparams;
std::unordered_map<std::string, llama_arg *> arg_to_options; std::unordered_map<std::string, llama_arg *> arg_to_options;
for (auto & opt : ctx_arg.options) { for (auto & opt : ctx_arg.options) {
@ -283,10 +282,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
params.kv_overrides.back().key[0] = 0; params.kv_overrides.back().key[0] = 0;
} }
if (sparams.seed == LLAMA_DEFAULT_SEED) {
sparams.seed = time(NULL);
}
return true; return true;
} }
@ -823,7 +818,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
[](gpt_params & params) { [](gpt_params & params) {
params.special = true; params.special = true;
} }
).set_examples({LLAMA_EXAMPLE_MAIN})); ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
add_opt(llama_arg( add_opt(llama_arg(
{"-cnv", "--conversation"}, {"-cnv", "--conversation"},
format( format(
@ -909,7 +904,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
).set_sparam()); ).set_sparam());
add_opt(llama_arg( add_opt(llama_arg(
{"-s", "--seed"}, "SEED", {"-s", "--seed"}, "SEED",
format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed), format("RNG seed (default: %u, use random seed for %u)", params.sparams.seed, LLAMA_DEFAULT_SEED),
[](gpt_params & params, const std::string & value) { [](gpt_params & params, const std::string & value) {
params.sparams.seed = std::stoul(value); params.sparams.seed = std::stoul(value);
} }
@ -1422,20 +1417,18 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.split_mode = LLAMA_SPLIT_MODE_NONE; params.split_mode = LLAMA_SPLIT_MODE_NONE;
} else if (arg_next == "layer") { } else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_MODE_LAYER; params.split_mode = LLAMA_SPLIT_MODE_LAYER;
} } else if (arg_next == "row") {
else if (arg_next == "row") {
#ifdef GGML_USE_SYCL #ifdef GGML_USE_SYCL
fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n"); fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n");
exit(1); exit(1);
#endif // GGML_USE_SYCL #endif // GGML_USE_SYCL
params.split_mode = LLAMA_SPLIT_MODE_ROW; params.split_mode = LLAMA_SPLIT_MODE_ROW;
} } else {
else {
throw std::invalid_argument("invalid value"); throw std::invalid_argument("invalid value");
} }
#ifndef GGML_USE_CUDA_SYCL_VULKAN if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the split mode has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the split mode has no effect.\n");
#endif // GGML_USE_CUDA_SYCL_VULKAN }
} }
)); ));
add_opt(llama_arg( add_opt(llama_arg(
@ -1455,14 +1448,14 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
} }
for (size_t i = 0; i < llama_max_devices(); ++i) { for (size_t i = 0; i < llama_max_devices(); ++i) {
if (i < split_arg.size()) { if (i < split_arg.size()) {
params.tensor_split[i] = std::stof(split_arg[i]); params.tensor_split[i] = std::stof(split_arg[i]);
} else { } else {
params.tensor_split[i] = 0.0f; params.tensor_split[i] = 0.0f;
} }
} }
#ifndef GGML_USE_CUDA_SYCL_VULKAN if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting a tensor split has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting a tensor split has no effect.\n");
#endif // GGML_USE_CUDA_SYCL_VULKAN }
} }
)); ));
add_opt(llama_arg( add_opt(llama_arg(
@ -1470,9 +1463,9 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
format("the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu), format("the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu),
[](gpt_params & params, int value) { [](gpt_params & params, int value) {
params.main_gpu = value; params.main_gpu = value;
#ifndef GGML_USE_CUDA_SYCL_VULKAN if (!llama_supports_gpu_offload()) {
fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n"); fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the main GPU has no effect.\n");
#endif // GGML_USE_CUDA_SYCL_VULKAN }
} }
)); ));
add_opt(llama_arg( add_opt(llama_arg(

View File

@ -56,14 +56,6 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
#if (defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL))
#define GGML_USE_CUDA_SYCL
#endif
#if (defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)) || defined(GGML_USE_VULKAN)
#define GGML_USE_CUDA_SYCL_VULKAN
#endif
#if defined(LLAMA_USE_CURL) #if defined(LLAMA_USE_CURL)
#ifdef __linux__ #ifdef __linux__
#include <linux/limits.h> #include <linux/limits.h>

View File

@ -310,6 +310,10 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
return cur_p.data[cur_p.selected].id; return cur_p.data[cur_p.selected].id;
} }
uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}
// helpers // helpers
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {

View File

@ -60,6 +60,8 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
// //
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl);
// helpers // helpers
// access the internal list of current candidate tokens // access the internal list of current candidate tokens

View File

@ -90,8 +90,6 @@ int main(int argc, char ** argv) {
print_build_info(); print_build_info();
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
llama_backend_init(); llama_backend_init();
llama_numa_init(params.numa); llama_numa_init(params.numa);

View File

@ -159,8 +159,6 @@ int main(int argc, char ** argv) {
print_build_info(); print_build_info();
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
LOG("%s: llama backend init\n", __func__); LOG("%s: llama backend init\n", __func__);
llama_backend_init(); llama_backend_init();
llama_numa_init(params.numa); llama_numa_init(params.numa);
@ -301,6 +299,9 @@ int main(int argc, char ** argv) {
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
} }
} }
smpl = gpt_sampler_init(model, sparams);
LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
LOG_TEE("sampling: \n%s\n", sparams.print().c_str()); LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n"); LOG_TEE("\n\n");
@ -340,8 +341,6 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd; std::vector<llama_token> embd;
smpl = gpt_sampler_init(model, sparams);
while (n_remain != 0 || params.interactive) { while (n_remain != 0 || params.interactive) {
// predict // predict
if (!embd.empty()) { if (!embd.empty()) {

View File

@ -191,8 +191,6 @@ int main(int argc, char ** argv) {
print_build_info(); print_build_info();
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
LOG("%s: llama backend init\n", __func__); LOG("%s: llama backend init\n", __func__);
llama_backend_init(); llama_backend_init();
llama_numa_init(params.numa); llama_numa_init(params.numa);
@ -470,8 +468,10 @@ int main(int argc, char ** argv) {
exit(1); exit(1);
} }
LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
LOG_TEE("sampling params: \n%s\n", sparams.print().c_str()); LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str()); LOG_TEE("sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
// group-attention state // group-attention state

View File

@ -2007,8 +2007,6 @@ int main(int argc, char ** argv) {
print_build_info(); print_build_info();
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
llama_backend_init(); llama_backend_init();
llama_numa_init(params.numa); llama_numa_init(params.numa);

View File

@ -1266,6 +1266,7 @@ struct server_context {
{"n_predict", slot.n_predict}, // Server configured n_predict {"n_predict", slot.n_predict}, // Server configured n_predict
{"model", params.model_alias}, {"model", params.model_alias},
{"seed", slot.sparams.seed}, {"seed", slot.sparams.seed},
{"seed_cur", slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0},
{"temperature", slot.sparams.temp}, {"temperature", slot.sparams.temp},
{"dynatemp_range", slot.sparams.dynatemp_range}, {"dynatemp_range", slot.sparams.dynatemp_range},
{"dynatemp_exponent", slot.sparams.dynatemp_exponent}, {"dynatemp_exponent", slot.sparams.dynatemp_exponent},

View File

@ -5,11 +5,11 @@
"nixpkgs-lib": "nixpkgs-lib" "nixpkgs-lib": "nixpkgs-lib"
}, },
"locked": { "locked": {
"lastModified": 1725024810, "lastModified": 1725234343,
"narHash": "sha256-ODYRm8zHfLTH3soTFWE452ydPYz2iTvr9T8ftDMUQ3E=", "narHash": "sha256-+ebgonl3NbiKD2UD0x4BszCZQ6sTfL4xioaM49o5B3Y=",
"owner": "hercules-ci", "owner": "hercules-ci",
"repo": "flake-parts", "repo": "flake-parts",
"rev": "af510d4a62d071ea13925ce41c95e3dec816c01d", "rev": "567b938d64d4b4112ee253b9274472dc3a346eb6",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -20,11 +20,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1724819573, "lastModified": 1725634671,
"narHash": "sha256-GnR7/ibgIH1vhoy8cYdmXE6iyZqKqFxQSVkFgosBh6w=", "narHash": "sha256-v3rIhsJBOMLR8e/RNWxr828tB+WywYIoajrZKFM+0Gg=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "71e91c409d1e654808b2621f28a327acfdad8dc2", "rev": "574d1eac1c200690e27b8eb4e24887f8df7ac27c",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -36,14 +36,14 @@
}, },
"nixpkgs-lib": { "nixpkgs-lib": {
"locked": { "locked": {
"lastModified": 1722555339, "lastModified": 1725233747,
"narHash": "sha256-uFf2QeW7eAHlYXuDktm9c25OxOyCoUOQmh5SZ9amE5Q=", "narHash": "sha256-Ss8QWLXdr2JCBPcYChJhz4xJm+h/xjl4G0c0XlP6a74=",
"type": "tarball", "type": "tarball",
"url": "https://github.com/NixOS/nixpkgs/archive/a5d394176e64ab29c852d03346c1fc9b0b7d33eb.tar.gz" "url": "https://github.com/NixOS/nixpkgs/archive/356624c12086a18f2ea2825fed34523d60ccc4e3.tar.gz"
}, },
"original": { "original": {
"type": "tarball", "type": "tarball",
"url": "https://github.com/NixOS/nixpkgs/archive/a5d394176e64ab29c852d03346c1fc9b0b7d33eb.tar.gz" "url": "https://github.com/NixOS/nixpkgs/archive/356624c12086a18f2ea2825fed34523d60ccc4e3.tar.gz"
} }
}, },
"root": { "root": {

View File

@ -130,42 +130,3 @@
#define cudaKernelNodeParams musaKernelNodeParams #define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture #define cudaStreamEndCapture musaStreamEndCapture
// XXX: Clang builtins mapping
#define __vsub4 __vsub4_musa
#define __vcmpeq4 __vcmpeq4_musa
#define __vcmpne4 __vcmpne4_musa
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
return __vsubss4(a, b);
}
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
unsigned int c;
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
for (int i = 0; i < 4; ++i) {
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
}
return c;
}
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
unsigned int c;
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
for (int i = 0; i < 4; ++i) {
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
}
return c;
}

View File

@ -5137,13 +5137,17 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
case GGML_OP_SCALE: case GGML_OP_SCALE:
case GGML_OP_SQR: case GGML_OP_SQR:
case GGML_OP_CLAMP: case GGML_OP_CLAMP:
return true;
case GGML_OP_CONT: case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
return true; return true;
case GGML_OP_ROPE: case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]); return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
// TODO: add support for the new F32 operations
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:

View File

@ -1127,6 +1127,10 @@ extern "C" {
int32_t n_logit_bias, int32_t n_logit_bias,
const llama_logit_bias * logit_bias); const llama_logit_bias * logit_bias);
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
/// @details Sample and accept a token from the idx-th output of the last evaluation /// @details Sample and accept a token from the idx-th output of the last evaluation
// //
// Shorthand for: // Shorthand for:

View File

@ -8,6 +8,7 @@
#include <cstring> #include <cstring>
#include <ctime> #include <ctime>
#include <cfloat> #include <cfloat>
#include <chrono>
#include <cmath> #include <cmath>
#include <numeric> #include <numeric>
#include <random> #include <random>
@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
cur_p->size = k; cur_p->size = k;
} }
static uint32_t get_rng_seed(uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
// use system clock if std::random_device is not a true RNG
static bool is_rd_prng = std::random_device().entropy() == 0;
if (is_rd_prng) {
return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
}
std::random_device rd;
return rd();
}
return seed;
}
// llama_sampler API // llama_sampler API
const char * llama_sampler_name(const struct llama_sampler * smpl) { const char * llama_sampler_name(const struct llama_sampler * smpl) {
@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() {
struct llama_sampler_dist { struct llama_sampler_dist {
const uint32_t seed; const uint32_t seed;
uint32_t seed_cur;
std::mt19937 rng; std::mt19937 rng;
}; };
@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
static void llama_sampler_dist_reset(struct llama_sampler * smpl) { static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_dist *) smpl->ctx; auto * ctx = (llama_sampler_dist *) smpl->ctx;
ctx->rng = std::mt19937(ctx->seed); ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
} }
static void llama_sampler_dist_free(struct llama_sampler * smpl) { static void llama_sampler_dist_free(struct llama_sampler * smpl) {
@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
}; };
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_dist_i, /* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist { /* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed, /* .seed = */ seed,
/* .rng = */ std::mt19937(seed), /* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
}, },
}; };
} }
@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat {
const int32_t n_vocab; const int32_t n_vocab;
const uint32_t seed; const uint32_t seed;
uint32_t seed_cur;
const float tau; const float tau;
const float eta; const float eta;
@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) { static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_mirostat *) smpl->ctx; auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
ctx->mu = 2.0f*ctx->tau; ctx->mu = 2.0f*ctx->tau;
ctx->rng = std::mt19937(ctx->seed); ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
} }
static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
}; };
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_i, /* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat { /* .ctx = */ new llama_sampler_mirostat {
/* .n_vocab = */ n_vocab, /* .n_vocab = */ n_vocab,
/* .seed = */ seed, /* .seed = */ seed,
/* .tau = */ tau, /* .seed_cur = */ seed_cur,
/* .eta = */ eta, /* .tau = */ tau,
/* .m = */ m, /* .eta = */ eta,
/* .mu = */ 2.0f*tau, /* .m = */ m,
/* .rng = */ std::mt19937(seed), /* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
}, },
}; };
} }
@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
struct llama_sampler_mirostat_v2 { struct llama_sampler_mirostat_v2 {
const uint32_t seed; const uint32_t seed;
uint32_t seed_cur;
const float tau; const float tau;
const float eta; const float eta;
@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) { static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
ctx->mu = 2.0f*ctx->tau; ctx->mu = 2.0f*ctx->tau;
ctx->rng = std::mt19937(ctx->seed); ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
} }
static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
}; };
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_v2_i, /* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_mirostat_v2 { /* .ctx = */ new llama_sampler_mirostat_v2 {
/* .seed = */ seed, /* .seed = */ seed,
/* .tau = */ tau, /* .seed_cur = */ seed_cur,
/* .eta = */ eta, /* .tau = */ tau,
/* .mu = */ 2.0f*tau, /* .eta = */ eta,
/* .rng = */ std::mt19937(seed), /* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
}, },
}; };
} }
@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties(
ignore_eos = false; ignore_eos = false;
} }
penalty_last_n = std::max(penalty_last_n, 0);
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_penalties_i, /* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties { /* .ctx = */ new llama_sampler_penalties {
@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
} }
} }
} }
static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias(
}, },
}; };
} }
// utils
uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
if (smpl->iface == &llama_sampler_dist_i) {
return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
}
if (smpl->iface == &llama_sampler_mirostat_i) {
return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
}
if (smpl->iface == &llama_sampler_mirostat_v2_i) {
return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
}
if (smpl->iface == &llama_sampler_chain_i) {
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
const uint32_t seed = llama_sampler_get_seed(*it);
if (seed != LLAMA_DEFAULT_SEED) {
return seed;
}
}
}
return LLAMA_DEFAULT_SEED;
}