llama.cpp/common
Georgi Gerganov 9c67c2773d
ggml : add Flash Attention (#5021)
* ggml : add ggml_flash_attn_ext API
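
  (For reference, a minimal sketch of how the new fused op is intended to be called, replacing the separate K*Q, soft_max and V*KQ graph nodes. The (ctx, q, k, v, mask, scale) parameter list reflects the API as added in this PR and may differ in later revisions; the helper name below and the tensor layout details are illustrative only.)

    // Sketch: one fused ggml_flash_attn_ext call replaces the
    // mul_mat -> soft_max -> mul_mat chain used for attention so far.
    #include "ggml.h"

    static struct ggml_tensor * build_attn_sketch(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,        // F32 query (no ggml_cast needed, see below)
            struct ggml_tensor  * k,        // K view from the KV cache
            struct ggml_tensor  * v,        // V view from the KV cache
            struct ggml_tensor  * kq_mask,  // padded F16 KQ mask (see later commits)
            float                 kq_scale) {
        // previously: KQ = mul_mat(k, q); KQ = soft_max(KQ*scale + mask); out = mul_mat(v, KQ)
        return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
    }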

* ggml : fix GQA support in ggml_flash_attn_ext

* ggml : online attention (CPU)
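
  (The following is a self-contained sketch of the "online" streaming softmax accumulation that flash attention builds on: scores are consumed in order while a running max and running sum are maintained, and the partial output is rescaled whenever the max grows. It illustrates the general technique only, not the actual ggml CPU kernel; the function name is hypothetical.)

    #include <math.h>
    #include <stddef.h>

    // Accumulate one attention output row without materializing the full
    // softmax: M is the running max, S the running sum of exp(score - M),
    // and `out` holds the running weighted sum of V rows, rescaled as M grows.
    // Assumes at least one position is not masked out.
    void online_attention_row(const float *scores,  // scores[i] = q.k_i*scale + mask_i
                              const float *v,       // V rows, v[i*dv .. i*dv + dv - 1]
                              float       *out,     // output row of length dv
                              size_t n_kv, size_t dv) {
        float M = -INFINITY;
        float S = 0.0f;

        for (size_t j = 0; j < dv; ++j) out[j] = 0.0f;

        for (size_t i = 0; i < n_kv; ++i) {
            if (scores[i] == -INFINITY) continue;     // fully masked position, skip

            const float m_new = scores[i] > M ? scores[i] : M;
            const float ms    = expf(M - m_new);      // rescale factor for the old state
            const float vs    = expf(scores[i] - m_new);

            S = S*ms + vs;
            for (size_t j = 0; j < dv; ++j) {
                out[j] = out[j]*ms + vs*v[i*dv + j];
            }
            M = m_new;
        }

        for (size_t j = 0; j < dv; ++j) out[j] /= S;  // final normalization
    }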

* metal : initial implementation

* metal : f16 precision

* metal : reduce branches

* metal : specialize for head size

* wip : 8 rows per simd group

* wip : 4 rows per simd group

* wip : template for rows per warp

* metal : parallelize across KV size

* metal : parallel reduce across heads

* metal : efficient flash_attn_f16 implementation

* metal : avoid redundant loads of the attention

* metal : scale and mask in matrix form

* metal : fix comment

* llama : avoid ggml_cast, use F32 query

* metal : add parallel reduce version (disabled)

* metal : move output into local memory + optimize

- the result from each simdgroup now stays in the registers
- significantly reduced SRAM usage
- more efficient skipping of -INF blocks
- avoid simdgroup barrier in hot loop
- add comments

* metal : add tests, fix scaling, support C > 32

* metal : improve precision

* ggml : fix f16 mad

* metal : minor

* metal : support Q > 8

* tests : add ATTN tests

* metal : disable buffer allocation logs

* tests : more

* metal : faster inner loop for C == 32

* metal : fix array initialization

* tests : ifdef

* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
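
  (Background on the padding, as a rough sketch: ggml rounds sizes up with a round-up-to-multiple helper, so the F16 mask buffer is over-allocated and the soft_max / flash_attn kernels can presumably process it in whole SIMD-sized chunks without per-element bounds checks. The macro below matches the GGML_PAD definition in ggml.h; the 32-element multiple and which dimension gets padded are illustrative assumptions, not the exact llama.cpp code.)

    // Round-up-to-multiple helper (same definition as GGML_PAD in ggml.h):
    #define PAD_TO(x, n) (((x) + (n) - 1) / (n) * (n))

    // e.g. a mask sized for 70 rows would be allocated with
    // PAD_TO(70, 32) == 96 rows, so block-wise kernels never read past
    // the end of the buffer.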

* ggml : fix ggml_soft_max mask requirement

* cuda : fix soft_max to use correct mask size

* cuda : add flash_attn kernel (wip)

* metal : optimize softmax for C > 32

* metal : optimize softmax

* tests : minor fix

* cuda : avoid zeroing fragments

* tests : update dims

* cuda : fix __hisinf() result check

* cuda : avoid warp_reduce for smax

* cuda : use int instead of int64_t

Noticeably improves performance (thanks to Johannes)

* cuda : make loops use the same loop values

Thanks Johannes again for the tip

* cuda : unroll some of the loops

* cuda : avoid __hisinf branches

* cuda : use half2 in softmax

* cuda : switch to 1 warp for bs > 16

* cuda : speed-up reduce part of the kernel

* cuda : unroll Q*K^T loop

* cuda : fix -INF block check

* cuda : simplify softmax

* cuda : fix matrix names

* cuda : minor

* llama : adapt to F16 KQ_pos

* llama : adapt new models to F16 KQ_mask

* ggml : fix F16 store (ARM NEON)

* llama : fix type of KQ_mask and KQ_pos

* ggml : fix CPU soft_max

* tests : add hs=256

* cuda : fix build

* metal : improve perf via smaller int registers

* cuda : adapt soft_max to F16 mask and pos

* CUDA: faster FlashAttention, kernel for bs == 1

* 16 cols for Phi-2

* no vec for hs, no hs==256 ncols==32 for Volta

* adjust kernel selection logic

* 4 warps, 256 stride for all D

* no ncols == 64

* Multiple parallel blocks for batch size 1

* fix compile warnings

* fix excessive KQ_b loads

* fix cmake build

* fix KV cache padding, NaN from INFINITY (#6438)

* llama : flash_attn cparam + fix defrag

* server: support flash_attn param

* server: bench: enable flash_attn param

* CUDA: refactor host code, dyn. par. blocks

* fix flash_attn_vec_f16 race condition

* flush softmax exp below threshold to 0
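
  (A sketch of the flush-to-zero idea: once a score falls far enough below the running max, its exp() contribution is numerically negligible, so the kernel writes an exact 0 instead of evaluating a value that would underflow anyway. The -20 threshold and the helper name here are illustrative; see the CUDA kernel for the actual constant.)

    #include <math.h>

    // Flush negligible softmax terms to exactly 0: d = score - running_max is
    // always <= 0, and exp(-20) ~ 2e-9 is already far below fp16 resolution.
    static inline float softmax_exp_ftz(float score, float running_max) {
        const float d = score - running_max;
        return d < -20.0f ? 0.0f : expf(d);
    }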

* store temp KQ in registers

* Calculate KQ as FP32 if KQV has GGML_PREC_F32

* Add __hgt2_mask implementation for CUDA 11

* fix KQ FP32 precision for parallel_blocks > 1

* llama-bench : add -fa,--flash-attn arg

* metal : add BS=1 kernel for flash attention (#6508)

* metal : add BS=1 kernel for flash attention (wip)

* metal : support more than 1 warps

* metal : opts

* metal : opt

* metal : switch to parallel reduce

* metal : reduce registers

* metal : simplify

* metal : initial FA vec kernel

* metal : use F32 attention accumulators

* batched-bench : add fattn arg

* llama : simplify llama_build_kv_store

ggml-ci

* llama : adapt build_olmo to changes

* ggml : fix arm fp16 store on windows

* metal : clean-up

* metal : clean-up kernel code

* metal : minor

* tests : remove benchmarks

ggml-ci

* ggml : fix avx512 const correctness

ggml-ci

* ggml : fix soft_max with bias on CPU

ggml-ci

* common : print --flash-attn in help

* ggml : fix num dimensions in ggml_flash_attn_ext

* llama : force disable flash attention for incompatible models

* ggml : ggml_soft_max support F16/F32 mask/pos

ggml-ci

* cuda : uint -> uint32_t

* cuda : "constexpr dim3" -> "const dim3"

ggml-ci

* cuda : try to fix __hgt2_mask

ggml-ci

* ggml : add TODO's for F16/F32 mask/pos support in other backends

* llama : replace bool need_kq_pos with use_alibi

* llama : prep ALiBi support for BERT models

ggml-ci

* llama : fix n_batch requirements

ggml-ci

* cont

* server : add help for --flash-attn arg

* llama : disable FA for AMD

* tests : remove TMP_ATTN_BENCH

ggml-ci

* llama : support save/load state with FA enabled

ggml-ci

* ci : add CUDA save-load-state tests

ggml-ci

* llama : llama_kv_cache_clear zeroes data + fix save-load seq

ggml-ci

* llama : fix copy-paste errors, add TODO

* llama : disallow incompatible states

* llama : update llama_state_get_size after v_trans field

* metal : remove tmp log

* llama : add static reminder for llama_state_get_size

* metal : fix max nsg

ggml-ci

* ci : fix arg order

ggml-ci

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>
2024-04-30 12:16:08 +03:00
Name | Last commit message | Last commit date
base64.hpp | llava : expose as a shared library for downstream projects (#3613) | 2023-11-07 00:36:23 +03:00
build-info.cpp.in | build : link against build info instead of compiling against it (#3879) | 2023-11-02 08:50:16 +02:00
CMakeLists.txt | main: add --json-schema / -j flag (#6659) | 2024-04-15 18:35:21 +01:00
common.cpp | ggml : add Flash Attention (#5021) | 2024-04-30 12:16:08 +03:00
common.h | ggml : add Flash Attention (#5021) | 2024-04-30 12:16:08 +03:00
console.cpp | check C++ code with -Wmissing-declarations (#3184) | 2023-09-15 15:38:27 -04:00
console.h | gguf : new file format with flexible meta data (beta) (#2398) | 2023-08-21 23:07:43 +03:00
grammar-parser.cpp | grammar : verify parsed state (#5950) | 2024-03-10 17:17:43 +02:00
grammar-parser.h | gguf : new file format with flexible meta data (beta) (#2398) | 2023-08-21 23:07:43 +03:00
json-schema-to-grammar.cpp | JSON schema conversion: ⚡️ faster repetitions, min/maxLength for strings, cap number length (#6555) | 2024-04-12 19:43:38 +01:00
json-schema-to-grammar.h | json-schema-to-grammar : fix order of props + non-str const/enum (#6232) | 2024-03-22 15:07:44 +02:00
json.hpp | json-schema-to-grammar improvements (+ added to server) (#5978) | 2024-03-21 11:50:43 +00:00
log.h | Replace "alternative" boolean operator in conditional compilation directive (#6949) | 2024-04-27 21:02:06 +02:00
ngram-cache.cpp | Fixed lookup compilation issues on Windows (#6273) | 2024-03-24 14:21:17 +01:00
ngram-cache.h | lookup: complement data from context with general text statistics (#5479) | 2024-03-23 01:24:36 +01:00
sampling.cpp | sampling : use std::random_device{}() for default random seed (#6962) | 2024-04-29 16:35:45 +03:00
sampling.h | Server: fix seed for multiple slots (#6835) | 2024-04-24 11:08:36 +02:00
stb_image.h | examples: support LLaVA v1.5 (multimodal model) (#3436) | 2023-10-12 18:23:18 +03:00
train.cpp | code : normalize enum names (#5697) | 2024-02-25 12:09:09 +02:00
train.h | sync : ggml (backend v2) (#3912) | 2023-11-13 14:16:23 +02:00