ggml : move rope type enum to ggml.h
This commit moves the `llama_rope_type` enum from `llama.h` to `ggml.h`
and changes its name to `ggml_rope_type`. The motivation for this change
is to address the TODO in `llama.h` and use the enum in ggml.

Note: This commit does not change the `mode` parameter to be of type
`enum ggml_rope_type`. The name `mode` and its usage suggest that it
might be more generic and possibly used as a bit field for multiple
flags. Further investigation/discussion may be needed to determine
if `mode` should be restricted to RoPE types.
parent 3071c0a5f2
commit 14b549c708
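Because `mode` stays a plain integer rather than an `enum ggml_rope_type`,
the new constants act as bit flags tested with a bitwise AND, as in the
`is_neox` checks in the diff below. Here is a minimal standalone sketch of
that pattern; the `main` driver and the sample `mode` value are
illustrative only and not part of the commit:

#include <stdbool.h>
#include <stdio.h>

// Mirrors the enum added to ggml.h by this commit.
enum ggml_rope_type {
    GGML_ROPE_TYPE_NONE = -1,
    GGML_ROPE_TYPE_NORM = 0,
    GGML_ROPE_TYPE_NEOX = 2,
    GGML_ROPE_TYPE_GLM  = 4,
};

int main(void) {
    const int mode = GGML_ROPE_TYPE_NEOX; // hypothetical mode value

    // NEOX and GLM occupy distinct bits (2 and 4), so each can be tested
    // independently; NORM is 0 and is the implied default when neither
    // bit is set, which is why it is never tested with `&`.
    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_glm  = mode & GGML_ROPE_TYPE_GLM;

    printf("is_neox=%d is_glm=%d\n", is_neox, is_glm);
    return 0;
}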
@@ -437,6 +437,14 @@ extern "C" {
         GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
     };
 
+    // Rotary Positional Embedding (RoPE) types
+    enum ggml_rope_type {
+        GGML_ROPE_TYPE_NONE = -1,
+        GGML_ROPE_TYPE_NORM = 0,
+        GGML_ROPE_TYPE_NEOX = 2,
+        GGML_ROPE_TYPE_GLM  = 4,
+    };
+
     // available tensor operations:
     enum ggml_op {
         GGML_OP_NONE = 0,
@@ -6545,7 +6545,7 @@ struct ggml_tensor * ggml_rope_back(
     GGML_ASSERT(a->ne[2] == b->ne[0]);
     GGML_ASSERT(c == NULL && "freq factors not implemented yet");
 
-    GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
+    GGML_ASSERT((mode & GGML_ROPE_TYPE_GLM) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
     bool is_node = false;
 
@@ -14093,7 +14093,7 @@ static void ggml_compute_forward_rope_f32(
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const bool is_neox = mode & 2;
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
     const float * freq_factors = NULL;
     if (src2 != NULL) {
@@ -14218,7 +14218,7 @@ static void ggml_compute_forward_rope_f16(
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const bool is_neox = mode & 2;
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
     const float * freq_factors = NULL;
     if (src2 != NULL) {
@@ -95,15 +95,6 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
     };
 
-    // note: these values should be synchronized with ggml_rope
-    // TODO: maybe move this enum to ggml.h (ggml_rope_type)
-    enum llama_rope_type {
-        LLAMA_ROPE_TYPE_NONE = -1,
-        LLAMA_ROPE_TYPE_NORM = 0,
-        LLAMA_ROPE_TYPE_NEOX = 2,
-        LLAMA_ROPE_TYPE_GLM  = 4,
-    };
-
     enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
         LLAMA_TOKEN_TYPE_UNDEFINED = 0,
         LLAMA_TOKEN_TYPE_NORMAL    = 1,
@@ -462,7 +453,7 @@ extern "C" {
     LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type  llama_rope_type  (const struct llama_model * model);
+    LLAMA_API enum ggml_rope_type   ggml_rope_type   (const struct llama_model * model);
 
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@@ -2201,7 +2201,7 @@ struct llama_hparams {
     llama_token dec_start_token_id = -1;
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
-    enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
+    enum ggml_rope_type     rope_type    = GGML_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
     bool operator!=(const llama_hparams & other) const {
@@ -5219,7 +5219,7 @@ static void llm_load_hparams(
         hparams.use_alibi = true;
     }
 
-    hparams.rope_type = llama_rope_type(&model);
+    hparams.rope_type = ggml_rope_type(&model);
 }
 
 static void llm_load_vocab(
@@ -8331,7 +8331,7 @@ struct llm_build_context {
     const bool flash_attn;
 
     const enum llama_pooling_type pooling_type;
-    const enum llama_rope_type    rope_type;
+    const enum ggml_rope_type     rope_type;
 
     const llm_build_cb & cb;
 
@@ -15105,7 +15105,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     bool need_reserve = false;
 
     // apply K-shift if needed
-    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
+    if (lctx.model.hparams.rope_type != GGML_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
         if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
             GGML_ABORT("Deepseek2 does not support K-shift");
         }
@@ -16881,7 +16881,7 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }
 
-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+enum ggml_rope_type ggml_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
         case LLM_ARCH_GPT2:
@@ -16893,7 +16893,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_T5:
         case LLM_ARCH_JAIS:
-            return LLAMA_ROPE_TYPE_NONE;
+            return GGML_ROPE_TYPE_NONE;
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
         case LLM_ARCH_LLAMA:
@@ -16909,7 +16909,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_ARCTIC:
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_CHATGLM:
-            return LLAMA_ROPE_TYPE_NORM;
+            return GGML_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
         case LLM_ARCH_FALCON:
@@ -16930,14 +16930,14 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
         case LLM_ARCH_CODESHELL:
-            return LLAMA_ROPE_TYPE_NEOX;
+            return GGML_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here
         case LLM_ARCH_UNKNOWN:
             GGML_ABORT("unknown architecture");
     }
 
-    return LLAMA_ROPE_TYPE_NONE;
+    return GGML_ROPE_TYPE_NONE;
 }
 
 enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {