mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 19:21:46 +00:00
examples : make n_ctx warning work again (#3066)
This was broken by commit e36ecdcc
("build : on Mac OS enable Metal by
default (#2901)").
This commit is contained in:
parent
94f10b91ed
commit
e64f5b5578
@ -17,11 +17,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
params.embedding = true;
|
||||
|
||||
if (params.n_ctx > 2048) {
|
||||
fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
|
||||
"expect poor results\n", __func__, params.n_ctx);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
@ -47,6 +42,12 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int n_ctx_train = llama_n_ctx_train(ctx);
|
||||
if (params.n_ctx > n_ctx_train) {
|
||||
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
|
||||
__func__, n_ctx_train, params.n_ctx);
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
|
@ -182,8 +182,10 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.n_ctx > llama_n_ctx(ctx)) {
|
||||
LOG_TEE("%s: warning: base model only supports context sizes no greater than %d tokens (%d specified)\n", __func__, llama_n_ctx(ctx), params.n_ctx);
|
||||
const int n_ctx_train = llama_n_ctx_train(ctx);
|
||||
if (params.n_ctx > n_ctx_train) {
|
||||
LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
|
||||
__func__, n_ctx_train, params.n_ctx);
|
||||
} else if (params.n_ctx < 8) {
|
||||
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
|
||||
params.n_ctx = 8;
|
||||
|
@ -693,9 +693,10 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (params.n_ctx > llama_n_ctx(ctx)) {
|
||||
fprintf(stderr, "%s: warning: model might not support context sizes greater than %d tokens (%d specified);"
|
||||
"expect poor results\n", __func__, llama_n_ctx(ctx), params.n_ctx);
|
||||
const int n_ctx_train = llama_n_ctx_train(ctx);
|
||||
if (params.n_ctx > n_ctx_train) {
|
||||
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
|
||||
__func__, n_ctx_train, params.n_ctx);
|
||||
}
|
||||
|
||||
// print system information
|
||||
|
14
llama.cpp
14
llama.cpp
@ -5633,15 +5633,19 @@ void llama_free(struct llama_context * ctx) {
|
||||
}
|
||||
|
||||
int llama_n_vocab(const struct llama_context * ctx) {
|
||||
return ctx->model.vocab.id_to_token.size();
|
||||
return llama_model_n_vocab(&ctx->model);
|
||||
}
|
||||
|
||||
int llama_n_ctx(const struct llama_context * ctx) {
|
||||
return ctx->model.hparams.n_ctx;
|
||||
return llama_model_n_ctx(&ctx->model);
|
||||
}
|
||||
|
||||
int llama_n_ctx_train(const struct llama_context * ctx) {
|
||||
return llama_model_n_ctx_train(&ctx->model);
|
||||
}
|
||||
|
||||
int llama_n_embd(const struct llama_context * ctx) {
|
||||
return ctx->model.hparams.n_embd;
|
||||
return llama_model_n_embd(&ctx->model);
|
||||
}
|
||||
|
||||
enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) {
|
||||
@ -5656,6 +5660,10 @@ int llama_model_n_ctx(const struct llama_model * model) {
|
||||
return model->hparams.n_ctx;
|
||||
}
|
||||
|
||||
int llama_model_n_ctx_train(const struct llama_model * model) {
|
||||
return model->hparams.n_ctx_train;
|
||||
}
|
||||
|
||||
int llama_model_n_embd(const struct llama_model * model) {
|
||||
return model->hparams.n_embd;
|
||||
}
|
||||
|
2
llama.h
2
llama.h
@ -247,12 +247,14 @@ extern "C" {
|
||||
|
||||
LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
|
||||
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
|
||||
|
||||
LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
|
||||
LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
|
||||
LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int llama_model_n_embd (const struct llama_model * model);
|
||||
|
||||
// Get a string describing the model type
|
||||
|
Loading…
Reference in New Issue
Block a user