Allow setting the rng seed after initialization. (#1184)

The llama_set_state_data function restores the rng state to what it
was at the time llama_copy_state_data was called. But users may want
to restore the state and proceed with a different seed.
This commit is contained in:
Ásgeir Bjarni Ingvarsson 2023-04-26 20:08:43 +00:00 committed by GitHub
parent ea3ad7eb60
commit 87a6f846d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 0 deletions

View File

@ -2082,6 +2082,13 @@ int llama_get_kv_cache_token_count(struct llama_context * ctx) {
#define LLAMA_MAX_RNG_STATE 64*1024 #define LLAMA_MAX_RNG_STATE 64*1024
void llama_set_rng_seed(struct llama_context * ctx, int seed) {
if (seed <= 0) {
seed = time(NULL);
}
ctx->rng.seed(seed);
}
// Returns the size of the state // Returns the size of the state
size_t llama_get_state_size(struct llama_context * ctx) { size_t llama_get_state_size(struct llama_context * ctx) {
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.

View File

@ -116,6 +116,9 @@ extern "C" {
// Returns the number of tokens in the KV cache // Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx); LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
// Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
// Returns the size in bytes of the state (rng, logits, embedding and kv_cache) // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx); LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);