diff --git a/ci/run.sh b/ci/run.sh index a1cd9908f..bf21b6b31 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -336,7 +336,8 @@ function gg_run_open_llama_3b_v2 { (time ./bin/imatrix --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -fa --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -517,7 +518,10 @@ function gg_run_open_llama_7b_v2 { (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -fa -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/save-load-state -fa -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" diff --git a/common/common.cpp b/common/common.cpp index 099d0356f..243b88abf 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -947,6 +947,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cont_batching = true; return true; } + if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; + return true; + } if (arg == "--color") { params.use_color = true; return true; @@ -1513,6 +1517,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); + printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); printf(" --image IMAGE_FILE path to an image file. use with multimodal models. Specify multiple times for batching\n"); if (llama_supports_mlock()) { @@ -1885,6 +1890,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; + cparams.flash_attn = params.flash_attn; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); @@ -2707,6 +2713,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); + fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); diff --git a/common/common.h b/common/common.h index 8afdf2bdf..0618eea74 100644 --- a/common/common.h +++ b/common/common.h @@ -150,6 +150,7 @@ struct gpt_params { bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly + bool flash_attn = false; // flash attention bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 1e34de620..2924d8116 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -32,7 +32,7 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] \n" , argv[0]); + printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] \n" , argv[0]); printf(" , and PL are comma-separated lists of numbers without spaces\n\n"); printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]); return 1 ; @@ -41,6 +41,7 @@ int main(int argc, char ** argv) { int n_kv_max = 2048; int n_batch = 2048; int n_ubatch = 512; + bool flash_attn = false; int is_pp_shared = 0; int n_gpu_layers = 0; @@ -66,23 +67,27 @@ int main(int argc, char ** argv) { } if (argc >= 6) { - is_pp_shared = std::atoi(argv[5]); + flash_attn = std::atoi(argv[5]); } if (argc >= 7) { - n_gpu_layers = std::atoi(argv[6]); + is_pp_shared = std::atoi(argv[6]); } if (argc >= 8) { - n_pp = parse_list(argv[7]); + n_gpu_layers = std::atoi(argv[7]); } if (argc >= 9) { - n_tg = parse_list(argv[8]); + n_pp = parse_list(argv[8]); } if (argc >= 10) { - n_pl = parse_list(argv[9]); + n_tg = parse_list(argv[9]); + } + + if (argc >= 11) { + n_pl = parse_list(argv[10]); } // init LLM @@ -108,10 +113,11 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); - ctx_params.seed = 1234; - ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = n_batch; - ctx_params.n_ubatch = n_ubatch; + ctx_params.seed = 1234; + ctx_params.n_ctx = n_kv_max; + ctx_params.n_batch = n_batch; + ctx_params.n_ubatch = n_ubatch; + ctx_params.flash_attn = flash_attn; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; @@ -169,7 +175,7 @@ int main(int argc, char ** argv) { } LOG_TEE("\n"); - LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); LOG_TEE("\n"); LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s"); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 8b532c8b6..95c3095dd 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -174,6 +174,7 @@ struct cmd_params { std::vector split_mode; std::vector main_gpu; std::vector no_kv_offload; + std::vector flash_attn; std::vector> tensor_split; std::vector use_mmap; std::vector embeddings; @@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = { /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, /* main_gpu */ {0}, /* no_kv_offload */ {false}, + /* flash_attn */ {false}, /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* use_mmap */ {true}, /* embeddings */ {false}, @@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); + printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); @@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = split(argv[i], split_delim); params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end()); + } else if (arg == "-fa" || arg == "--flash-attn") { + if (++i >= argc) { + invalid_param = true; + break; + } + auto p = split(argv[i], split_delim); + params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end()); } else if (arg == "-mmp" || arg == "--mmap") { if (++i >= argc) { invalid_param = true; @@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; } if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } + if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } @@ -498,6 +509,7 @@ struct cmd_params_instance { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -532,6 +544,7 @@ struct cmd_params_instance { cparams.type_k = type_k; cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; + cparams.flash_attn = flash_attn; cparams.embeddings = embeddings; return cparams; @@ -554,6 +567,7 @@ static std::vector get_cmd_params_instances(const cmd_param for (const auto & tk : params.type_k) for (const auto & tv : params.type_v) for (const auto & nkvo : params.no_kv_offload) + for (const auto & fa : params.flash_attn) for (const auto & nt : params.n_threads) { for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { @@ -572,6 +586,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -596,6 +611,7 @@ static std::vector get_cmd_params_instances(const cmd_param /* .split_mode = */ sm, /* .main_gpu = */ mg, /* .no_kv_offload= */ nkvo, + /* .flash_attn = */ fa, /* .tensor_split = */ ts, /* .use_mmap = */ mmp, /* .embeddings = */ embd, @@ -633,6 +649,7 @@ struct test { llama_split_mode split_mode; int main_gpu; bool no_kv_offload; + bool flash_attn; std::vector tensor_split; bool use_mmap; bool embeddings; @@ -657,6 +674,7 @@ struct test { split_mode = inst.split_mode; main_gpu = inst.main_gpu; no_kv_offload = inst.no_kv_offload; + flash_attn = inst.flash_attn; tensor_split = inst.tensor_split; use_mmap = inst.use_mmap; embeddings = inst.embeddings; @@ -731,7 +749,7 @@ struct test { "n_batch", "n_ubatch", "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", + "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap", "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -753,7 +771,7 @@ struct test { } if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "use_mmap" || field == "embeddings") { + field == "flash_attn" || field == "use_mmap" || field == "embeddings") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -787,7 +805,7 @@ struct test { std::to_string(n_batch), std::to_string(n_ubatch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), + std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -955,6 +973,9 @@ struct markdown_printer : public printer { if (field == "no_kv_offload") { return "nkvo"; } + if (field == "flash_attn") { + return "fa"; + } if (field == "use_mmap") { return "mmap"; } @@ -1001,6 +1022,9 @@ struct markdown_printer : public printer { if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) { fields.emplace_back("no_kv_offload"); } + if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) { + fields.emplace_back("flash_attn"); + } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.emplace_back("tensor_split"); } diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index 6ca637bdd..86c5de101 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -268,6 +268,7 @@ def start_server_background(args): server_args.extend(['--defrag-thold', "0.1"]) server_args.append('--cont-batching') server_args.append('--metrics') + server_args.append('--flash-attn') server_args.extend(['--log-format', "text"]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 01453af2c..f60530cf3 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2377,6 +2377,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)\n"); + printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); printf(" -spf FNAME, --system-prompt-file FNAME\n"); printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n"); printf(" -ctk TYPE, --cache-type-k TYPE\n"); @@ -2742,6 +2743,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, params.embedding = true; } else if (arg == "-cb" || arg == "--cont-batching") { params.cont_batching = true; + } else if (arg == "-fa" || arg == "--flash-attn") { + params.flash_attn = true; } else if (arg == "-np" || arg == "--parallel") { if (++i >= argc) { invalid_param = true; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d277104d1..c30554f0c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -14,6 +14,7 @@ #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/dmmv.cuh" +#include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" @@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { @@ -2290,6 +2292,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_cuda_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -2564,6 +2569,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 481065b2a..156eba6d1 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -142,6 +142,7 @@ #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 +#define CC_AMPERE 800 #define CC_OFFSET_AMD 1000000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) @@ -271,7 +272,6 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -#ifdef GGML_CUDA_F16 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll @@ -284,7 +284,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { NO_DEVICE_CODE; #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -294,19 +293,26 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); -// } -// return x; -//#else -// GGML_UNUSED(x); -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX -//} +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} +#if CUDART_VERSION < 12000 +static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { + const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); + const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); + return mask_low | mask_high; +} +#endif // CUDART_VERSION < 12000 #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -391,6 +397,11 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { } #endif // defined(GGML_USE_HIPBLAS) +#define FP16_AVAILABLE defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \ + defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL + +#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA + // TODO: move to ggml-common.h static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -404,6 +415,7 @@ struct ggml_cuda_device_info { struct cuda_device_info { int cc; // compute capability + int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu new file mode 100644 index 000000000..df1e80068 --- /dev/null +++ b/ggml-cuda/fattn.cu @@ -0,0 +1,944 @@ +#include "common.cuh" +#include "fattn.cuh" + +#include + +#if FP16_MMA_AVAILABLE +#include +#endif + +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. + +template // D == head size +__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1) +static __global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic = blockIdx.x / parallel_blocks; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic); + const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + ne11*ic; + + const int stride_KV = nb11 / sizeof(half); + const int stride_KV2 = nb11 / sizeof(half2); + + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < nwarps*WARP_SIZE); + + __shared__ half KQ[nwarps*WARP_SIZE]; + KQ[tid] = -INFINITY; + half2 * KQ2 = (half2 *) KQ; + + half kqmax = -HALF_MAX_HALF; + half kqsum = 0.0f; + + __shared__ half kqmax_shared[WARP_SIZE]; + __shared__ half kqsum_shared[WARP_SIZE]; + if (threadIdx.y == 0) { + kqmax_shared[threadIdx.x] = -HALF_MAX_HALF; + kqsum_shared[threadIdx.x] = 0.0f; + } + __syncthreads(); + + // Convert Q to half2 and store in registers: + half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE]; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y); + } + + half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value. + + const int k_start = parallel_blocks == 1 ? 0 : ip*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + // Calculate KQ tile and keep track of new maximum KQ values: + half kqmax_new = kqmax; +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } + + half2 sum2 = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) { + break; + } + + const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + } + + sum2 = warp_reduce_sum(sum2); + half sum = __low2half(sum2) + __high2half(sum2); + sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f); + kqmax_new = __hmax(kqmax_new, sum); + if (threadIdx.x == 0) { + KQ[i_KQ] = sum; + } + } + + kqmax_new = warp_reduce_max(kqmax_new); + if (threadIdx.x == 0) { + kqmax_shared[threadIdx.y] = kqmax_new; + } + __syncthreads(); + kqmax_new = kqmax_shared[threadIdx.x]; + kqmax_new = warp_reduce_max(kqmax_new); + + const half KQ_max_scale = hexp(kqmax - kqmax_new); + kqmax = kqmax_new; + + const half val = hexp(KQ[tid] - kqmax); + kqsum = kqsum*KQ_max_scale + val; + KQ[tid] = val; + + VKQ *= __half2half2(KQ_max_scale); + + __syncthreads(); + + if (tid < D) { +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } + + half2 V_k; + reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; + reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + VKQ += V_k*KQ2[k0/2]; + } + } + + __syncthreads(); + } + + if (tid >= D) { + kqsum = 0.0f; + } + + kqsum = warp_reduce_sum(kqsum); + if (threadIdx.x == 0) { + kqsum_shared[threadIdx.y] = kqsum; + } + __syncthreads(); + kqsum = kqsum_shared[threadIdx.x]; + kqsum = warp_reduce_sum(kqsum); + + if (tid >= D) { + return; + } + + half dst_val = (__low2half(VKQ) + __high2half(VKQ)); + if (parallel_blocks == 1) { + dst_val /= kqsum; + } + dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val; + + if (parallel_blocks == 1 || tid != 0) { + return; + } + dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum); +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE +} + +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template +__launch_bounds__(nwarps*WARP_SIZE, 1) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; + + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + frag_b Q_b[D/16][ncols/frag_n]; + + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded*kqar; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; + half2 * KQ2 = (half2 *) KQ; + + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + } + } + + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } + } + + __syncthreads(); + + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { + frag_c_KQ KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + } +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + } + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + } + + float KQ_max_new = KQ_max_f[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; + + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } + } + + __syncthreads(); + + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); + } + } + + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); + } + +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + } + } + } + + __syncthreads(); + + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + float dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; + } + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + } + + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; + } + + float2 dst_meta_val; + if (std::is_same::value) { + dst_meta_val.x = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); + } + dst_meta_val.y = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; + } +#else + NO_DEVICE_CODE; +#endif // FP16_MMA_AVAILABLE +} + +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst) { +#if FP16_AVAILABLE + VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x; + VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x; + dst += D * gridDim.y*blockIdx.x; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + __shared__ float2 meta[parallel_blocks]; + if (tid < 2*parallel_blocks) { + ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid]; + } + + __syncthreads(); + + float kqmax = meta[0].x; +#pragma unroll + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; +#pragma unroll + for (int l = 0; l < parallel_blocks; ++l) { + const float diff = meta[l].x - kqmax; + const float KQ_max_scale = expf(diff); + const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); + *((uint32_t *) &KQ_max_scale) &= ftz_mask; + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE +} + +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +template void launch_fattn_vec_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_vec_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if (parallel_blocks == 1) { + return; + } + + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16_impl( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + ggml_cuda_pool & pool, cudaStream_t main_stream +) { + ggml_cuda_pool_alloc dst_tmp(pool); + ggml_cuda_pool_alloc dst_tmp_meta(pool); + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + + constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16; + const dim3 block_dim(WARP_SIZE, nwarps, 1); + const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); + const int shmem = 0; + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + + flash_attn_ext_f16 + <<>> ( + (const char *) Q->data, + (const char *) K->data, + (const char *) V->data, + mask ? ((const char *) mask->data) : nullptr, + (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + CUDA_CHECK(cudaGetLastError()); + + if ((parallel_blocks) == 1) { + return; + } + + const dim3 block_dim_combine(D, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z); + const int shmem_combine = 0; + + flash_attn_combine_results + <<>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); + CUDA_CHECK(cudaGetLastError()); +} + +template void launch_fattn_f16( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask, + const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream +) { + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + + if (4*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); + return; + } + launch_fattn_f16_impl(Q, K, V, KQV, mask, pool, main_stream); +} + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F16); + GGML_ASSERT(V->type == GGML_TYPE_F16); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); + + ggml_cuda_set_device(ctx.device); + + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int32_t precision = KQV->op_params[1]; + + if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + } else { + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + // case 256: + // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + // break; + default: + GGML_ASSERT(false); + break; + } + } + return; + } + + if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { + constexpr int parallel_blocks = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + constexpr int cols_per_block = 32; + constexpr int nwarps = 4; + switch (Q->ne[0]) { + case 64: + launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 80: + launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 96: + launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 112: + launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 128: + launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + case 256: + launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream()); + break; + default: + GGML_ASSERT(false); + break; + } + return; +} diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh new file mode 100644 index 000000000..ad3ca7a8d --- /dev/null +++ b/ggml-cuda/fattn.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index fa8f987cf..6ed225999 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -1,7 +1,17 @@ #include "softmax.cuh" -template -static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +template +static __device__ __forceinline__ float t2f32(T val) { + return (float) val; +} + +template <> +__device__ float __forceinline__ t2f32(half val) { + return __half2float(val); +} + +template +static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f const int64_t ix = (int64_t)rowx*ncols + col; const int64_t iy = (int64_t)rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f } } -static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +template +static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float * void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const float * src0_d = (const float *)src0->data; - const float * src1_d = src1 ? (const float *)src1->data : nullptr; + const void * src1_d = src1 ? (const void *)src1->data : nullptr; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); // positions tensor - float * src2_dd = nullptr; + void * src2_d = nullptr; - ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { - src2_dd = (float *)src2->data; + src2_d = (void *)src2->data; } - soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + + if (use_f16) { + const half * src1_dd = (const half *)src1_d; + const half * src2_dd = (const half *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } else { + const float * src1_dd = (const float *)src1_d; + const float * src2_dd = (const float *)src2_d; + + soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + } } diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 407062e6f..9a469821d 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1427,6 +1427,7 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml for (int i = node_start; i < node_end; ++i) { struct ggml_tensor * src0 = gf->nodes[i]->src[0]; struct ggml_tensor * src1 = gf->nodes[i]->src[1]; + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; GGML_UNUSED(src2); struct ggml_tensor * dst = gf->nodes[i]; GGML_ASSERT(dst->data != nullptr); @@ -1559,6 +1560,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml { float scale; memcpy(&scale, dst->op_params, sizeof(float)); + +#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") + GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); + GGML_ASSERT(src2 == nullptr); + ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); } break; case GGML_OP_DIAG_MASK_INF: diff --git a/ggml-metal.m b/ggml-metal.m index 9cb421988..c6d580b84 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -46,8 +46,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, - GGML_METAL_KERNEL_TYPE_SOFT_MAX, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, @@ -177,6 +179,14 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -443,7 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { } /* - GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ (int) kernel->pipeline.threadExecutionWidth); \ */ @@ -459,172 +469,182 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { return NULL; \ } \ } else { \ - GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \ + GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ } // simd_sum and simd_max requires MTLGPUFamilyApple7 - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } [metal_library release]; @@ -743,6 +763,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -1326,20 +1347,33 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); + int nth = 32; // SIMD width id pipeline = nil; + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { nth *= 2; } - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } } else { while (nth < ne00 && nth < 1024) { nth *= 2; } - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } } float scale; @@ -2503,6 +2537,161 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + GGML_ASSERT(src3); + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = nil; + + bool use_vec_kernel = false; + + if (ne01 >= 4 || (ne00%128 != 0)) { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } else { + use_vec_kernel = true; + + switch (ne00) { + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); + if (smem > ctx->device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half1x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -2706,10 +2895,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe UNUSED(buft); } -static void ggml_backend_metal_log_allocated_size(id device) { +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2719,10 +2911,15 @@ static void ggml_backend_metal_log_allocated_size(id device) { GGML_METAL_LOG_INFO("\n"); } } else { - GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -2756,8 +2953,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - ggml_backend_metal_log_allocated_size(device); + //ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } @@ -2844,7 +3040,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2867,7 +3063,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -2876,8 +3073,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, } } - ggml_backend_metal_log_allocated_size(device); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 191880af1..3d4276ae0 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -352,11 +352,12 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } +template kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device const float * src2, - device float * dst, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -375,10 +376,10 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device const float * ppos = src2 != src0 ? src2 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); float slope = 0.0f; @@ -456,11 +457,12 @@ kernel void kernel_soft_max( } } +template kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device const float * src2, - device float * dst, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -479,10 +481,10 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; float slope = 0.0f; @@ -499,7 +501,7 @@ kernel void kernel_soft_max_4( float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -525,7 +527,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -562,6 +564,14 @@ kernel void kernel_soft_max_4( } } +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + kernel void kernel_diag_mask_inf( device const float * src0, device float * dst, @@ -2084,6 +2094,632 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]); + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + const short Q8 = Q/8; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + const short TF = T/2; // shared memory size per query in (float) + const short T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[D8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) { + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } + + // zero out lo + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix(0.0h); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TF + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S[Q] = { [0 ... Q-1] = 0.0h }; + float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2/rk2; + const short ik3 = iq3/rk3; + + // v indices + const short iv2 = iq2/rv2; + const short iv3 = iq3/rv3; + + // load the queries from shared memory into local memory + simdgroup_half8x8 mq[D8]; + + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, T); + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + // prepare diagonal scale matrix + simdgroup_float8x8 mscale(scale); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + + // mqk = mqk*scale + mask + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + } + } + + // used to detect blocks full of -INF + float smax = -INFINITY; + + // online softmax + { + float ms[Q]; + + for (short j = 0; j < Q; ++j) { + const short p = tiisg; + + const float m = M[j]; + const float s = ss[j*TF + p]; + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = exp(m - M[j]); + const float vs = exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TF + p] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*TF + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + // O = diag(ms)*O + { + simdgroup_float8x8 mm; + simdgroup_load(mm, ss + C, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + for (short i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + simdgroup_float8x8 mv; + simdgroup_load(mv, ss + 8*cc, TF, 0, false); + + simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TF + 0] = S[j]; + ss[j*TF + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (short sg = 1; sg < nsg; ++sg) { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const float S0 = ss[j*TF + 0]; + const float S1 = ss[j*TF + sg*SH + 0]; + + const float M0 = ss[j*TF + 1]; + const float M1 = ss[j*TF + sg*SH + 1]; + + M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TF + 0] = S; + ss[j*TF + 1] = M; + + ss[j*TF + C + j ] = ms0; + ss[j*TF + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + simdgroup_half8x8 t; + simdgroup_float8x8 ms0; + simdgroup_float8x8 ms1; + + simdgroup_load(ms0, ss + C, TF, 0, false); + simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); + + for (short i = 0; i < D8; ++i) { + simdgroup_load (t, sq + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], sq + i*8, T, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + const float S = ss[j*TF + 0]; + + for (short i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; + +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_vec_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const short iq3 = tgpig[2]; + const short iq2 = tgpig[1]; + const short iq1 = tgpig[0]; + + const short D4 = D/4; + const short NW = N_SIMDWIDTH; + const short SH = (C + Q); // shared memory per simdgroup in (half) + + const short T = D + 2*nsg*SH; // shared memory size per query in (half) + + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 + threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + half4 lo[D4/NW]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < ne01) { + sq4[i] = (half4) q4[i]; + } else { + sq4[i] = 0.0h; + } + } + + // zero out lo + for (short i = tiisg; i < D4; i += NW) { + lo[i/NW] = 0.0h; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = 0.0h; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + float S = { 0.0h }; + float M = { -FLT_MAX/2 }; + + // assume K and V are same shape + const short ne22 = ne12; + const short ne23 = ne13; + + const uint nb21 = nb11; + const uint nb22 = nb12; + const uint nb23 = nb13; + + // broadcast + const short rk2 = ne02/ne12; + const short rk3 = ne03/ne13; + + const short rv2 = ne02/ne22; + const short rv3 = ne03/ne23; + + // k indices + const short ik2 = iq2 / rk2; + const short ik3 = iq3 / rk3; + + // v indices + const short iv2 = iq2 / rv2; + const short iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half4 mq[D4]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + mq[i] = sq4[i]; + } + + // pointer to the mask + device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + float4 mqk = { 0.0h }; + + device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + half4x4 mk; + mk[0] = pk4[i + 0*(nb11/8)]; + mk[1] = pk4[i + 1*(nb11/8)]; + mk[2] = pk4[i + 2*(nb11/8)]; + mk[3] = pk4[i + 3*(nb11/8)]; + + mqk += (float4) (mq[i] * mk); + } + + // reduce the results from the threads in the simdgroup + mqk += simd_shuffle_down(mqk, 16); + mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask + if (tiisg == 0) { + float4 mm = (float4) mp4[ic/4 + cc]; + mqk = mqk*scale + mm; + + ss4[cc] = mqk; + } + } + } + + // online softmax + { + const short p = tiisg; + + const float m = M; + const float s = ss[p]; + + M = simd_max(max(M, s)); + + const float ms = exp(m - M); + const float vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[p] = vs; + + // O = diag(ms)*O +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + lo[i/NW] *= ms; + } + } + + // O = O + (Q*K^T)*V + { +#pragma unroll + for (short cc = 0; cc < C/4; ++cc) { + device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); + +#pragma unroll + for (short ii = 0; ii < D4; ii += NW) { + const short i = ii + tiisg; + + lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + } + } + } + + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + } + + // store results to shared memory + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = lo[ii/NW]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const float S0 = ss[ 0]; + const float S1 = ss[r*SH + 0]; + + const float M0 = ss[ 1]; + const float M1 = ss[r*SH + 1]; + + const float M = max(M0, M1); + + const float ms0 = exp(M0 - M); + const float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short ii = 0; ii < D4; ii += NW) { + short i = ii + tiisg; + dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; + } + } +} + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 2b76b3ebd..57fe4ea3d 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -14744,7 +14744,12 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + const ggml_tensor * src2 = dst->src[2]; + +#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -14760,7 +14765,6 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, float * src2_dd = nullptr; sycl_pool_alloc src2_f; - ggml_tensor * src2 = dst->src[2]; const bool use_src2 = src2 != nullptr; if (use_src2) { diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 1736ab736..f712cdd5a 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -3178,6 +3178,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_SOFT_MAX: +#pragma message("TODO: add ggml_vk_soft_max() F16 src1 and src2 support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_soft_max_f32; } diff --git a/ggml.c b/ggml.c index cb273061c..74ecd5927 100644 --- a/ggml.c +++ b/ggml.c @@ -951,7 +951,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -977,7 +977,7 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL @@ -1046,7 +1046,7 @@ do { \ // unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F // so F16C guard isn't required -#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((__m256i *)(x))) +#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) @@ -1144,7 +1144,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1662,6 +1662,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1746,6 +1777,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -2000,6 +2060,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "SSM_CONV", @@ -2026,7 +2087,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2090,6 +2151,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "ssm_conv(x)", @@ -2116,7 +2178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4559,6 +4621,8 @@ struct ggml_tensor * ggml_mul_mat( void ggml_mul_mat_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; ggml_set_op_params_i32(a, 0, prec_i32); @@ -5397,17 +5461,23 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[0] == a->ne[0]); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (pos) { GGML_ASSERT(ggml_is_vector(pos)); - GGML_ASSERT(pos->type == GGML_TYPE_F32); + GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32); GGML_ASSERT(pos->ne[0] == a->ne[0]); } + if (pos && mask) { + GGML_ASSERT(pos->type == mask->type); + } + if (max_bias > 0.0f) { GGML_ASSERT(pos); } @@ -6216,6 +6286,59 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -12255,7 +12378,7 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t ne11 = src1 ? src1->ne[1] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 @@ -12278,19 +12401,31 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - float * pos = src2 ? (float *) src2->data : src0->data; + ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data; + float * pos_f32 = src2 ? (float *) src2->data : src0->data; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); for (int i1 = ir0; i1 < ir1; i1++) { float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); - if (mp) { - ggml_vec_acc_f32(nc, wp, mp); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += mp_f32[i]; + } + } } // ALiBi bias @@ -12298,8 +12433,14 @@ static void ggml_compute_forward_soft_max_f32( const uint32_t h = (i1/ne01)%ne02; // head const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - for (int i = 0; i < nc; i++) { - wp[i] = wp[i] + slope*pos[i]; + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*pos_f32[i]; + } } } @@ -14569,6 +14710,198 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + if (params->type == GGML_TASK_TYPE_INIT) { + return; + } + + if (params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + // loop over n_batch and n_head + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float S = 0.0f; + float M = -INFINITY; + + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); + + memset(V16, 0, D*sizeof(ggml_fp16_t)); + + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + if (mv == -INFINITY) { + continue; + } + + float s; + + // convert Q to F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + + ggml_vec_dot_f16(D, + &s, 0, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, + Q16, 0, 1); + + s = s*scale + mv; + + const float Mold = M; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); + } + + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); + + S = S*ms + vs; + } + + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; + } + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: + case GGML_PREC_F32: + { + // uses F32 accumulators + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -16376,6 +16709,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor); @@ -17388,6 +17725,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -18160,6 +18498,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -18563,6 +18902,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index 86e5a8dc5..a11795973 100644 --- a/ggml.h +++ b/ggml.h @@ -475,6 +475,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_SSM_CONV, @@ -1722,6 +1723,25 @@ extern "C" { struct ggml_tensor * v, bool masked); +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 72c10ffc2..18d6297ce 100644 --- a/llama.cpp +++ b/llama.cpp @@ -108,7 +108,6 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 60 - // // logging // @@ -1846,7 +1845,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool need_kq_pos = false; + bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -1936,6 +1935,7 @@ struct llama_cparams { bool embeddings; bool causal_attn; bool offload_kqv; + bool flash_attn; enum llama_pooling_type pooling_type; @@ -2039,8 +2039,8 @@ struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; bool do_copy = false; - // with recurrent state models, a cell can hold the state for more than one past token - bool recurrent = false; + bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool v_trans = true; // the value tensor is transposed // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_internal also uses it, so it @@ -2339,11 +2339,14 @@ struct llama_context { static bool llama_kv_cache_init( struct llama_kv_cache & cache, - const llama_model & model, + const llama_context * ctx, ggml_type type_k, ggml_type type_v, uint32_t kv_size, bool offload) { + const llama_model & model = ctx->model; + const llama_cparams & cparams = ctx->cparams; + const struct llama_hparams & hparams = model.hparams; const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); @@ -2354,6 +2357,7 @@ static bool llama_kv_cache_init( // TODO: find a nicer way to add other recurrent model architectures cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.v_trans = !cparams.flash_attn; // TODO: support mixed reccurent Transformer architectues // NOTE: (!a || b) is a logical implication (a -> b) @@ -2566,6 +2570,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) { } cache.head = 0; cache.used = 0; + + for (auto & buf : cache.bufs) { + ggml_backend_buffer_clear(buf, 0); + } } static bool llama_kv_cache_seq_rm( @@ -4194,7 +4202,7 @@ static void llm_load_hparams( model.ftype = ml.ftype; if (hparams.f_max_alibi_bias > 0.0f) { - hparams.need_kq_pos = true; + hparams.use_alibi = true; } hparams.rope_type = llama_rope_type(&model); @@ -6203,37 +6211,47 @@ static struct ggml_tensor * llm_build_inp_embd( static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * k_cur, struct ggml_tensor * v_cur, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, const llm_build_cb & cb, int64_t il) { + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(kv.size == n_ctx); - // compute the transposed [n_tokens, n_embd] V matrix - assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); - struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv.v_l[il]), - (kv_head)*ggml_element_size(kv.v_l[il])); + // note: storing RoPE-ed version of K in the KV cache + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + + assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens); + + struct ggml_tensor * v_cache_view = nullptr; + + if (cparams.flash_attn) { + v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)); + } else { + // note: the V cache is transposed when not using flash attention + v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, + ( n_ctx)*ggml_element_size(kv.v_l[il]), + (kv_head)*ggml_element_size(kv.v_l[il])); + + v_cur = ggml_transpose(ctx, v_cur); + } cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); - ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); } static struct ggml_tensor * llm_build_norm( @@ -6453,11 +6471,11 @@ static struct ggml_tensor * llm_build_moe_ffn( return moe_out; } -// if max_alibi_bias > 0 then apply ALiBi static struct ggml_tensor * llm_build_kqv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6465,12 +6483,12 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t n_kv, float kq_scale, const llm_build_cb & cb, int il) { + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -6488,72 +6506,100 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); + struct ggml_tensor * cur; - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); - } + if (cparams.flash_attn) { + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); - if (model.arch == LLM_ARCH_GROK) { - // need to do the following: - // multiply by attn_output_multiplyer of 0.08838834764831845 - // and then : - // kq = 30 * tanh(kq / 30) - // before the softmax below + // note: if this assert triggers, then some check has failed earlier + // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); - //try from phi2 - //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); - kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); - kq = ggml_scale(ctx, kq, 30); - } + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); + } + + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); + } else { + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + } + + if (model.arch == LLM_ARCH_GROK) { + // need to do the following: + // multiply by attn_output_multiplyer of 0.08838834764831845 + // and then : + // kq = 30 * tanh(kq / 30) + // before the softmax below + + //try from phi2 + //ggml_mul_mat_set_prec(kq, GGML_PREC_F32); + + kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f)); + kq = ggml_scale(ctx, kq, 30); + } #if defined(GGML_USE_KOMPUTE) #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") #pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.f_max_alibi_bias > 0.0f) { - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); + if (hparams.use_alibi) { + kq = ggml_scale(ctx, kq, kq_scale); + cb(kq, "kq_scaled", il); - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); + cb(kq, "kq_scaled_alibi", il); - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); + kq = ggml_add(ctx, kq, kq_mask); + cb(kq, "kq_masked", il); - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else + kq = ggml_soft_max(ctx, kq); + cb(kq, "kq_soft_max", il); + } else #endif - { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); + { + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); + } + + GGML_ASSERT(kv.size == n_ctx); + + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head_v, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, + 0); + cb(v, "v", il); + + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); + + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cb(cur, "kqv_merged_cont", il); } - GGML_ASSERT(kv.size == n_ctx); - - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head_v, - 0); - cb(v, "v", il); - - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); - - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); - - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); - cb(cur, "kqv_merged_cont", il); - ggml_build_forward_expand(graph, cur); cur = ggml_mul_mat(ctx, wo, cur); @@ -6572,6 +6618,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_context * ctx, const llama_model & model, const llama_hparams & hparams, + const llama_cparams & cparams, const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * wo, @@ -6581,7 +6628,6 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, struct ggml_tensor * kq_pos, - int64_t n_ctx, int32_t n_tokens, int32_t kv_head, int32_t n_kv, @@ -6595,12 +6641,12 @@ static struct ggml_tensor * llm_build_kv( ggml_build_forward_expand(graph, k_cur); ggml_build_forward_expand(graph, v_cur); - llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il); + llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il); struct ggml_tensor * cur; - cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il); + cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, + q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -6642,6 +6688,8 @@ struct llm_build_context { const int32_t kv_head; // index of where we store new KV data in the cache const int32_t n_orig_ctx; + const bool flash_attn; + const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -6688,6 +6736,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), + flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -6802,15 +6851,31 @@ struct llm_build_context { ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); - ggml_tensor * view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, i)); + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; - ggml_tensor * view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self.v_l[il]->type, kv_self.size), - ggml_row_size(kv_self.v_l[il]->type, id)); + if (flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(kv_self.v_l[il]->type, kv_self.size), + ggml_row_size(kv_self.v_l[il]->type, id)); + } ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); @@ -6840,20 +6905,26 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } else { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return lctx.inp_KQ_mask; + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } - struct ggml_tensor * build_inp_KQ_pos() { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); + struct ggml_tensor * build_inp_KQ_pos(bool causal = true) { + if (causal) { + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); + } else { + // TODO: this will be needed for ALiBi-based BERT models + // https://github.com/ggerganov/llama.cpp/pull/6826 + lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens); + } cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); - return lctx.inp_KQ_pos; + return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos; } struct ggml_tensor * build_inp_mean() { @@ -6959,9 +7030,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7099,9 +7170,9 @@ struct llm_build_context { cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7206,9 +7277,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7326,9 +7397,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7451,9 +7522,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -7603,9 +7674,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7715,9 +7786,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7919,9 +7990,9 @@ struct llm_build_context { ); cb(Vcur, "Vcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8015,9 +8086,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cb(Qcur, "Qcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8308,9 +8379,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8439,14 +8510,15 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -8588,9 +8660,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8706,9 +8778,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8819,9 +8891,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8933,9 +9005,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9088,9 +9160,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9205,9 +9277,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, - model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9318,9 +9390,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct ggml_tensor * sa_out = cur; @@ -9421,9 +9493,9 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9528,9 +9600,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9644,9 +9716,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9761,9 +9833,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9891,9 +9963,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10012,9 +10084,9 @@ struct llm_build_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -10131,9 +10203,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10421,9 +10493,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10552,9 +10624,9 @@ struct llm_build_context { ); cb(Kcur, "Kcur", il); - cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, nullptr, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10981,7 +11053,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (hparams.need_kq_pos) { + // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch + // this allows to process multiple sequences in parallel with ALiBi-based models + if (hparams.use_alibi) { const int64_t n_kv = kv_self.n; GGML_ASSERT(lctx.inp_KQ_pos); @@ -11363,7 +11437,7 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } } @@ -11560,7 +11634,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // each move requires 6*n_layer tensors (see build_defrag) // - source view, destination view, copy operation // - x2 for keys and values - const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -15167,6 +15243,7 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embeddings =*/ false, /*.offload_kqv =*/ true, + /*.flash_attn =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -15333,6 +15410,7 @@ struct llama_context * llama_new_context_with_model( cparams.defrag_thold = params.defrag_thold; cparams.embeddings = params.embeddings; cparams.offload_kqv = params.offload_kqv; + cparams.flash_attn = params.flash_attn; cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -15340,12 +15418,20 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; // this is necessary due to kv_self.n being padded later during inference - cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256); // with causal attention, the batch size is limited by the context size cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + // ref: https://github.com/ggerganov/llama.cpp/pull/5021 + if (cparams.n_batch < GGML_KQ_MASK_PAD) { + LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); + cparams.n_batch = GGML_KQ_MASK_PAD; + } + + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : @@ -15377,6 +15463,23 @@ struct llama_context * llama_new_context_with_model( } } + if (cparams.flash_attn && hparams.use_alibi) { + LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__); + cparams.flash_attn = false; + } + + if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); + cparams.flash_attn = false; + } + +#ifdef GGML_USE_HIPBLAS + if (cparams.flash_attn) { + LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__); + cparams.flash_attn = false; + } +#endif + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } @@ -15384,6 +15487,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); + LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -15512,7 +15616,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -16111,6 +16215,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) { const size_t s_kv_head = sizeof(uint32_t); const size_t s_kv_size = sizeof(uint32_t); const size_t s_kv_used = sizeof(uint32_t); + const size_t s_v_trans = sizeof(uint32_t); const size_t s_kv = ctx->kv_self.total_size(); const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id); const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell; @@ -16128,10 +16233,14 @@ size_t llama_state_get_size(const struct llama_context * ctx) { + s_kv_head + s_kv_size + s_kv_used + + s_v_trans + s_kv + s_kv_cells ); + // on session change it is very likely that the state size has changed - so we need to update this function + static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?"); + return s_total; } @@ -16277,11 +16386,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data const uint32_t kv_size = kv_self.size; const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head; const uint32_t kv_used = kv_self.used; + const uint32_t v_trans = kv_self.v_trans ? 1 : 0; data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); data_ctx->write(&kv_head, sizeof(kv_head)); data_ctx->write(&kv_size, sizeof(kv_size)); data_ctx->write(&kv_used, sizeof(kv_used)); + data_ctx->write(&v_trans, sizeof(v_trans)); if (kv_buf_size) { const size_t pre_kv_buf_size = data_ctx->get_size_written(); @@ -16294,7 +16405,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size()); data_ctx->write(tmp_buf.data(), tmp_buf.size()); - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -16427,11 +16538,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { uint32_t kv_head; uint32_t kv_size; uint32_t kv_used; + uint32_t v_trans; memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); + memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans); + + GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition if (kv_self.size != kv_size) { // the KV cache needs to be big enough to load all the KV cells from the saved state @@ -16441,6 +16556,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { __func__, kv_head, kv_size, kv_self.size); } + llama_kv_cache_clear(ctx); + if (kv_buf_size) { const size_t pre_kv_buf_size = inp - src; @@ -16452,7 +16569,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size); inp += k_size; - if (kv_self.recurrent) { + if (kv_self.recurrent || !kv_self.v_trans) { // v is contiguous for recurrent models // TODO: use other tensors for state models than k and v const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head); @@ -16474,8 +16591,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); } - llama_kv_cache_clear(ctx); - ctx->kv_self.head = kv_head; ctx->kv_self.used = kv_used; @@ -16735,28 +16850,49 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam } } - // For the values, they are transposed, so we also need the element size and get the element ranges from each row - const uint32_t kv_size = kv_self.size; - for (int il = 0; il < (int)n_layer; ++il) { - // Write value type - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - data_ctx.write(&v_type_i, sizeof(v_type_i)); + // TODO: simplify, reduce copy-paste + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); - // Write element size - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - data_ctx.write(&v_size_el, sizeof(v_size_el)); + // Write row size of value + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + data_ctx.write(&v_size_row, sizeof(v_size_row)); - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - tmp_buf.resize(range_size * v_size_el); - ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + tmp_buf.resize(range_size * v_size_row); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row); data_ctx.write(tmp_buf.data(), tmp_buf.size()); } } + } else { + // For the values, they are transposed, so we also need the element size and get the element ranges from each row + const uint32_t kv_size = kv_self.size; + for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + data_ctx.write(&v_size_el, sizeof(v_size_el)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + tmp_buf.resize(range_size * v_size_el); + ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size()); + data_ctx.write(tmp_buf.data(), tmp_buf.size()); + } + } + } } return data_ctx.get_size_written(); @@ -16881,41 +17017,75 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, } } - // For each layer, read the values for each cell (transposed) - for (int il = 0; il < (int)n_layer; ++il) { - // Read type of value - int32_t v_type_i_ref; - memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); - inp += sizeof(v_type_i_ref); - const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; - if (v_type_i != v_type_i_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return 0; - } + // TODO: simplify, reduce copy-paste + if (!kv_self.v_trans) { + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } - // Read element size of value - size_t v_size_el_ref; - memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); - inp += sizeof(v_size_el_ref); - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); - if (v_size_el != v_size_el_ref) { - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); - return 0; - } + // Read row size of value + size_t v_size_row_ref; + memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref)); + inp += sizeof(v_size_row_ref); + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il); + return 0; + } - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; - ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); - inp += cell_count * v_size_el; + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row); + inp += cell_count * v_size_row; + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } + + // Read element size of value + size_t v_size_el_ref; + memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); + inp += sizeof(v_size_el_ref); + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); + return 0; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); + inp += cell_count * v_size_el; + } } } } const size_t nread = inp - src; + return nread; } diff --git a/llama.h b/llama.h index 30835de5f..059d78f11 100644 --- a/llama.h +++ b/llama.h @@ -40,7 +40,7 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 5 +#define LLAMA_SESSION_VERSION 6 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 1 @@ -287,6 +287,7 @@ extern "C" { bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // whether to use flash attention // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -542,7 +543,7 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache + // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_cache_clear( struct llama_context * ctx); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 02daad24b..b27c1291e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1090,6 +1090,12 @@ struct test_soft_max : public test_case { return VARS_TO_STR5(type, ne, mask, scale, max_bias); } + // the 1024 test with bias occasionally fails: + // SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL + virtual double max_nmse_err() override { + return 1e-6; + } + test_soft_max(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, bool mask = false, @@ -1101,7 +1107,7 @@ struct test_soft_max : public test_case { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * mask = nullptr; if (this->mask) { - mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); + mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]); } ggml_tensor * pos = nullptr; if (max_bias > 0.0f) { @@ -1475,6 +1481,34 @@ struct test_leaky_relu : public test_case { } }; +// GGML_OP_FLASH_ATTN_EXT +struct test_flash_attn_ext : public test_case { + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nb; // batch size + + std::string vars() override { + return VARS_TO_STR4(hs, nh, kv, nb); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + return out; + } +}; + enum llm_norm_type { LLM_NORM, LLM_NORM_RMS, @@ -1661,7 +1695,7 @@ struct test_llama : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); @@ -1783,7 +1817,7 @@ struct test_falcon : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); @@ -2095,7 +2129,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (float scale : {1.0f, 0.1f}) { for (int64_t ne0 : {16, 1024}) { for (int64_t ne1 : {16, 1024}) { - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias)); } } @@ -2139,6 +2173,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); + for (int hs : { 64, 80, 128, 256, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, }) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } + // these tests are disabled to save execution time, but they can be handy for debugging #if 0 test_cases.emplace_back(new test_llama(1));