From 82da9efc02ac973ba1ccd55dee03189bb89b3a6f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 21 Oct 2024 09:00:57 +0300 Subject: [PATCH] ggml : add asserts for type conversion in fattn kernels ggml-ci --- common/common.cpp | 4 ++-- ggml/src/ggml.c | 6 +++++- src/llama.cpp | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 2bc0b8800..a8eebb68b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1035,7 +1035,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { return GGML_TYPE_Q5_1; } - throw std::runtime_error("Invalid cache type: " + s); + throw std::runtime_error("Unsupported cache type: " + s); } struct llama_context_params common_context_params_to_llama(const common_params & params) { @@ -1047,7 +1047,7 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? - params.cpuparams.n_threads : params.cpuparams_batch.n_threads; + params.cpuparams.n_threads : params.cpuparams_batch.n_threads; cparams.logits_all = params.logits_all; cparams.embeddings = params.embedding; cparams.rope_scaling_type = params.rope_scaling_type; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7e24313ed..b16c462fa 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -324,8 +324,9 @@ struct ggml_logger_state { static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL}; static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) { - if (format == NULL) + if (format == NULL) { return; + } va_list args_copy; va_copy(args_copy, args); char buffer[128]; @@ -15723,6 +15724,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type"); + GGML_ASSERT(v_to_float && "fattn: unsupported V-type"); + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices diff --git a/src/llama.cpp b/src/llama.cpp index 1813dd29b..98ec123c1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -19243,7 +19243,7 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) { + if (ggml_is_quantized(params.type_v) && !params.flash_attn) { LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); return nullptr; }