From 9cb317f77e53067f7a138cc89ef7657148eae8e6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 11 May 2024 10:32:41 +0300 Subject: [PATCH] ggml : full ALiBi support (#7192) * ggml : full ALiBi support * ggml : update ggml_soft_max_ext() CUDA, SYCL * ggml : ggml_flash_attn_ext() support ALiBi (CPU) * ggml : ggml_flash_attn_ext() support ALiBi (Metal) * ggml : fix warning * ggml : ggml_flash_attn_ext() support ALiBi (CUDA) ggml-ci * ggml : fix assert message * vulkan : add dev notes * ggml : require mask when using ALiBi ggml-ci * convert : fix convert for refact models --- convert-hf-to-gguf.py | 12 ++ ggml-cuda.cu | 5 - ggml-cuda/alibi.cu | 63 ------- ggml-cuda/alibi.cuh | 5 - ggml-cuda/fattn.cu | 72 ++++++-- ggml-cuda/softmax.cu | 55 +++--- ggml-kompute.cpp | 12 +- ggml-metal.m | 148 ++++++---------- ggml-metal.metal | 120 ++++++------- ggml-sycl.cpp | 138 ++------------- ggml-vulkan.cpp | 6 +- ggml.c | 309 +++++---------------------------- ggml.h | 18 +- gguf-py/gguf/tensor_mapping.py | 4 + llama.cpp | 178 +++++++------------ tests/test-backend-ops.cpp | 30 ++-- 16 files changed, 350 insertions(+), 825 deletions(-) delete mode 100644 ggml-cuda/alibi.cu delete mode 100644 ggml-cuda/alibi.cuh diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 1dc18b2a5..3315ca74b 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1013,6 +1013,18 @@ class StarCoderModel(Model): class RefactModel(Model): model_arch = gguf.MODEL_ARCH.REFACT + def set_vocab(self): + super().set_vocab() + + # TODO: how to determine special FIM tokens automatically? + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, + special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot']) + special_vocab._set_special_token("prefix", 1) + special_vocab._set_special_token("suffix", 3) + special_vocab._set_special_token("middle", 2) + special_vocab._set_special_token("fsep", 4) # is this correct? + special_vocab.add_to_gguf(self.gguf_writer) + def set_gguf_parameters(self): hidden_dim = self.hparams["n_embd"] inner_dim = 4 * hidden_dim diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6f89a7cc3..c5c778796 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4,7 +4,6 @@ #include "ggml-cuda/common.cuh" #include "ggml-cuda/acc.cuh" -#include "ggml-cuda/alibi.cuh" #include "ggml-cuda/arange.cuh" #include "ggml-cuda/argsort.cuh" #include "ggml-cuda/binbcast.cuh" @@ -2277,9 +2276,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ROPE: ggml_cuda_op_rope(ctx, dst); break; - case GGML_OP_ALIBI: - ggml_cuda_op_alibi(ctx, dst); - break; case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; @@ -2829,7 +2825,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: - case GGML_OP_ALIBI: case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: diff --git a/ggml-cuda/alibi.cu b/ggml-cuda/alibi.cu deleted file mode 100644 index 6c7f1fd95..000000000 --- a/ggml-cuda/alibi.cu +++ /dev/null @@ -1,63 +0,0 @@ -#include "alibi.cuh" - -static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, - const int n_heads_log2_floor, const float m0, const float m1) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - - if (col >= ncols) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - - const int k = row/k_rows; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - dst[i] = col * m_k + x[i]; -} - -static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, - const int k_rows, const int n_heads_log2_floor, const float m0, - const float m1, cudaStream_t stream) { - const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1); - const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE); - const dim3 block_nums(num_blocks_x, nrows, 1); - alibi_f32<<>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); -} - -void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const float * src0_d = (const float *)src0->data; - float * dst_d = (float *)dst->data; - cudaStream_t stream = ctx.stream(); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t nrows = ggml_nrows(src0); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - //GGML_ASSERT(ne01 + n_past == ne00); - GGML_ASSERT(n_head == ne02); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream); -} diff --git a/ggml-cuda/alibi.cuh b/ggml-cuda/alibi.cuh deleted file mode 100644 index 630adfc7f..000000000 --- a/ggml-cuda/alibi.cuh +++ /dev/null @@ -1,5 +0,0 @@ -#include "common.cuh" - -#define CUDA_ALIBI_BLOCK_SIZE 32 - -void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 7c486f482..ac5d6672b 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16( float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, @@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16( const int stride_KV = nb11 / sizeof(half); const int stride_KV2 = nb11 / sizeof(half2); + half slopeh = __float2half(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slopeh = __float2half(powf(base, exph)); + } + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); constexpr int nwarps = D / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; @@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { sum2[j] = warp_reduce_sum(sum2[j]); half sum = __low2half(sum2[j]) + __high2half(sum2[j]); - sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { kqmax_new = ggml_cuda_hmax(kqmax_new, sum); @@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16( float * __restrict__ dst, float2 * __restrict__ dst_meta, const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, const int ne00, const int ne01, const int ne02, @@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16( const int stride_Q = nb01 / sizeof(float); const int stride_KV = nb11 / sizeof(half); + half slopeh = __float2half(1.0f); + half2 slope2 = make_half2(1.0f, 1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const int h = blockIdx.y; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slopeh = __float2half(powf(base, exph)); + slope2 = make_half2(slopeh, slopeh); + } + frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: @@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); } KQ_max_new = warp_reduce_max(KQ_max_new); @@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16( for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { const int k = k0 + threadIdx.x; - KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); @@ -710,8 +744,17 @@ template void launch_fattn_vec_ const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); flash_attn_vec_ext_f16 <<>> ( @@ -720,7 +763,7 @@ template void launch_fattn_vec_ (const char *) V->data, mask ? ((const char *) mask->data) : nullptr, parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, + scale, max_bias, m0, m1, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, @@ -761,8 +804,17 @@ template ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale; - memcpy(&scale, KQV->op_params, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); flash_attn_ext_f16 <<>> ( @@ -771,7 +823,7 @@ template data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, + scale, max_bias, m0, m1, n_head_log2, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, @@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; - const int32_t precision = KQV->op_params[1]; + const int32_t precision = KQV->op_params[2]; if (!fp16_mma_available(cc)) { GGML_ASSERT(precision == GGML_PREC_DEFAULT); diff --git a/ggml-cuda/softmax.cu b/ggml-cuda/softmax.cu index 6ed225999..ca85285a3 100644 --- a/ggml-cuda/softmax.cu +++ b/ggml-cuda/softmax.cu @@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32(half val) { } template -static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; @@ -23,16 +23,16 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p const int warp_id = threadIdx.x / WARP_SIZE; const int lane_id = threadIdx.x % WARP_SIZE; - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { const int h = rowx/nrows_y; // head index const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - slope = powf(base, exp); + slope = powf(base, exph); } extern __shared__ float data_soft_max_f32[]; @@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p const int64_t ix = (int64_t)rowx*ncols + col; const int64_t iy = (int64_t)rowy*ncols + col; - const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f); + const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p } template -static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { +static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) { int nth = WARP_SIZE; while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); @@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } } void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const ggml_tensor * src2 = dst->src[2]; const float * src0_d = (const float *)src0->data; const void * src1_d = src1 ? (const void *)src1->data : nullptr; @@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - // positions tensor - void * src2_d = nullptr; - - const bool use_src2 = src2 != nullptr; - - if (use_src2) { - src2_d = (void *)src2->data; - } - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (use_f16) { const half * src1_dd = (const half *)src1_d; - const half * src2_dd = (const half *)src2_d; - soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } else { const float * src1_dd = (const float *)src1_d; - const float * src2_dd = (const float *)src2_d; - soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } } diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 9a469821d..3f033d58b 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml case GGML_OP_SOFT_MAX: { float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float max_bias; -#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support") + memcpy(&scale, (float *)dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float)); + +#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); - GGML_ASSERT(src2 == nullptr); + +#pragma message("TODO: add ALiBi support") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192") + GGML_ASSERT(max_bias == 0.0f); ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); } break; diff --git a/ggml-metal.m b/ggml-metal.m index 18ce5b88a..1bbb8fb4f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -169,7 +169,6 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, GGML_METAL_KERNEL_TYPE_ROPE_F32, GGML_METAL_KERNEL_TYPE_ROPE_F16, - GGML_METAL_KERNEL_TYPE_ALIBI_F32, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, @@ -623,7 +622,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); @@ -759,7 +757,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_GROUP_NORM: return ctx->support_simdgroup_reduction; case GGML_OP_NORM: - case GGML_OP_ALIBI: case GGML_OP_ROPE: case GGML_OP_IM2COL: return true; @@ -1358,13 +1355,12 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_SOFT_MAX: { GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); int nth = 32; // SIMD width id pipeline = nil; - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { @@ -1395,8 +1391,8 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_y = src0->ne[1]; - const uint32_t n_head_kv = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -1408,20 +1404,15 @@ static enum ggml_status ggml_metal_graph_compute( } else { [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; } - if (id_src2) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:7]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:9]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:10]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -2226,49 +2217,6 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; - case GGML_OP_ALIBI: - { - GGML_ASSERT((src0t == GGML_TYPE_F32)); - - const int nth = MIN(1024, ne00); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - [encoder setBytes:&m1 length:sizeof( float) atIndex:19]; - [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; case GGML_OP_ROPE: { GGML_ASSERT(ne10 == ne02); @@ -2566,7 +2514,7 @@ static enum ggml_status ggml_metal_graph_compute( "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - const int64_t ne31 = src3 ? src3->ne[1] : 0; + //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); @@ -2578,7 +2526,16 @@ static enum ggml_status ggml_metal_graph_compute( const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float max_bias; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); id pipeline = nil; @@ -2615,34 +2572,37 @@ static enum ggml_status ggml_metal_graph_compute( } [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; - [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&scale length:sizeof( float) atIndex:26]; + [encoder setBytes:&max_bias length:sizeof( float) atIndex:27]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:28]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:29]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml-metal.metal b/ggml-metal.metal index 46c7d5039..ee9de57a3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -356,7 +356,6 @@ template kernel void kernel_soft_max( device const char * src0, device const char * src1, - device const char * src2, device char * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -378,10 +377,9 @@ kernel void kernel_soft_max( device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; - device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { @@ -397,7 +395,7 @@ kernel void kernel_soft_max( float lmax = -INFINITY; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)); + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block @@ -422,7 +420,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val); + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; } @@ -461,7 +459,6 @@ template kernel void kernel_soft_max_4( device const char * src0, device const char * src1, - device const char * src2, device char * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -483,10 +480,9 @@ kernel void kernel_soft_max_4( device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; - device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr; device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - float slope = 0.0f; + float slope = 1.0f; if (max_bias > 0.0f) { const int64_t h = i02; @@ -501,7 +497,7 @@ kernel void kernel_soft_max_4( float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -527,7 +523,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -1595,60 +1591,6 @@ kernel void kernel_mul_mv_f16_f32_l4( } } -kernel void kernel_alibi_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & m0, - constant float & m1, - constant int & n_heads_log2_floor, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - const int64_t k = i3*ne3 + i2; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = pow(m0, k + 1); - } else { - m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; - device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - const float src_v = *(device float *)(src_row + i00*nb00); - device float * dst_v = (device float *)(dst_row + i00*nb0); - *dst_v = i00 * m_k + src_v; - } -} - static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -2116,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2154,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2257,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16( // prepare diagonal scale matrix simdgroup_float8x8 mscale(scale); + // prepare diagonal slope matrix + simdgroup_float8x8 mslope(1.0f); + + // ALiBi + if (max_bias > 0.0f) { + const short h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + mslope = simdgroup_float8x8(pow(base, exph)); + } + // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { @@ -2279,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } - // mqk = mqk*scale + mask + // mqk = mqk*scale + mask*slope simdgroup_half8x8 mm; simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply(mm, mslope, mm); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_store(mqk, ss + 8*cc, TF, 0, false); @@ -2472,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16( constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant int64_t & ne31, constant uint64_t & nb31, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -2497,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16( const short T = D + 2*nsg*SH; // shared memory size per query in (half) + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const short h = iq2; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix @@ -2603,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16( mqk += simd_shuffle_down(mqk, 2); mqk += simd_shuffle_down(mqk, 1); - // mqk = mqk*scale + mask + // mqk = mqk*scale + mask*slope if (tiisg == 0) { float4 mm = (float4) mp4[ic/4 + cc]; - mqk = mqk*scale + mm; + mqk = mqk*scale + mm*slope; ss4[cc] = mqk; } @@ -2840,7 +2817,8 @@ kernel void kernel_cpy_f32_f16( for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; + // TODO: is there a better way to handle -INFINITY? + dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0]; } } diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 79aec4d9f..e93d2af63 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)( #define SYCL_SCALE_BLOCK_SIZE 256 #define SYCL_CLAMP_BLOCK_SIZE 256 #define SYCL_ROPE_BLOCK_SIZE 256 -#define SYCL_ALIBI_BLOCK_SIZE 32 #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32 #define SYCL_QUANTIZE_BLOCK_SIZE 256 #define SYCL_DEQUANTIZE_BLOCK_SIZE 256 @@ -9316,32 +9315,6 @@ static void rope_glm_f32( dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; } -static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, - const int n_heads_log2_floor, const float m0, const float m1, - const sycl::nd_item<3> &item_ct1) { - const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (col >= ncols) { - return; - } - - const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1); - const int i = row*ncols + col; - - const int k = row/k_rows; - - float m_k; - if (k < n_heads_log2_floor) { - m_k = dpct::pow(m0, k + 1); - } else { - m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - dst[i] = col * m_k + x[i]; -} - static void k_sum_rows_f32(const float * x, float * dst, const int ncols, const sycl::nd_item<3> &item_ct1) { const int row = item_ct1.get_group(1); @@ -9443,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con template -static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par, +static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; @@ -9457,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos, const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - float slope = 0.0f; + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { @@ -9482,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos, const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f); + const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f); vals[col] = val; max_val = sycl::max(max_val, val); @@ -12964,20 +12937,6 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows, }); } -static void alibi_f32_sycl(const float *x, float *dst, const int ncols, - const int nrows, const int k_rows, - const int n_heads_log2_floor, const float m0, - const float m1, dpct::queue_ptr stream) { - const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE); - const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE); - const sycl::range<3> block_nums(1, nrows, num_blocks_x); - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - alibi_f32(x, dst, ncols, k_rows, - n_heads_log2_floor, m0, m1, item_ct1); - }); -} - static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols, const int nrows, dpct::queue_ptr stream) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -13058,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst, } template -static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par, +static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const size_t n_local_scratch, dpct::queue_ptr stream) { @@ -13068,7 +13027,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - soft_max_f32(x, mask, pos, dst, ncols_par, + soft_max_f32(x, mask, dst, ncols_par, nrows_y, scale, max_bias, m0, m1, n_head_log2, item_ct1, local_buf_acc.get_pointer()); @@ -13076,7 +13035,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl }); } -static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos, +static void soft_max_f32_sycl(const float * x, const float * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, dpct::queue_ptr stream) { @@ -13098,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float * const size_t local_mem_size = stream->get_device().get_info(); if (n_local_scratch*sizeof(float) < local_mem_size) { if (ncols_x > max_block_size) { - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); return; } switch (ncols_x) { case 32: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 64: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 128: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 256: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 512: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 1024: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 2048: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; case 4096: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; default: - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, n_local_scratch, stream); break; } } else { - soft_max_f32_submitter(x, mask, pos, dst, ncols_x, nrows_y, scale, + soft_max_f32_submitter(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, block_nums, block_dims, WARP_SIZE, stream); } @@ -14562,36 +14521,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1, (void) src1_dd; } -inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const dpct::queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne); - const int64_t nrows = ggml_nrows(src0); - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - //GGML_ASSERT(ne01 + n_past == ne00); - GGML_ASSERT(n_head == ne02); - - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream); - - (void) src1; - (void) src1_dd; -} - static void ggml_sycl_op_pool2d(const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const float *src0_dd, const float *src1_dd, @@ -14746,12 +14675,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const ggml_tensor * src2 = dst->src[2]; - -#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support") +#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support") #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional const int64_t ne00 = src0->ne[0]; const int64_t nrows_x = ggml_nrows(src0); @@ -14763,25 +14689,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0, memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - // positions tensor - float * src2_dd = nullptr; - sycl_pool_alloc src2_f; - - const bool use_src2 = src2 != nullptr; - - if (use_src2) { - const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU; - - if (src2_on_device) { - ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra; - src2_dd = (float *) src2_extra->data_device[g_main_device]; - } else { - src2_dd = src2_f.alloc(ggml_nelements(src2)); - SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream)); - } - } - - soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, + soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream); } @@ -16232,10 +16140,6 @@ static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, g ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope); } -static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi); -} - static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d); } @@ -16612,9 +16516,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_ROPE: func = ggml_sycl_rope; break; - case GGML_OP_ALIBI: - func = ggml_sycl_alibi; - break; case GGML_OP_IM2COL: func = ggml_sycl_im2col; break; @@ -17744,7 +17645,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: - case GGML_OP_ALIBI: case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 95f718974..b9449be03 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SOFT_MAX: GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16); - if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_soft_max_f32; } if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { @@ -4286,6 +4285,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); +#pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated") +#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192") + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { ncols, src1 != nullptr ? nrows_y : (uint32_t)0, diff --git a/ggml.c b/ggml.c index 093d38d00..4ee5d24af 100644 --- a/ggml.c +++ b/ggml.c @@ -2185,7 +2185,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SOFT_MAX_BACK", "ROPE", "ROPE_BACK", - "ALIBI", "CLAMP", "CONV_TRANSPOSE_1D", "IM2COL", @@ -2227,7 +2226,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2276,7 +2275,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "soft_max_back(x)", "rope(x)", "rope_back(x)", - "alibi(x)", "clamp(x)", "conv_transpose_1d(x)", "im2col(x)", @@ -2318,7 +2316,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5646,7 +5644,6 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias, bool inplace) { @@ -5660,18 +5657,8 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask->ne[1] >= a->ne[1]); } - if (pos) { - GGML_ASSERT(ggml_is_vector(pos)); - GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32); - GGML_ASSERT(pos->ne[0] == a->ne[0]); - } - - if (pos && mask) { - GGML_ASSERT(pos->type == mask->type); - } - if (max_bias > 0.0f) { - GGML_ASSERT(pos); + GGML_ASSERT(mask); } bool is_node = false; @@ -5689,7 +5676,6 @@ static struct ggml_tensor * ggml_soft_max_impl( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = mask; - result->src[2] = pos; return result; } @@ -5697,23 +5683,22 @@ static struct ggml_tensor * ggml_soft_max_impl( struct ggml_tensor * ggml_soft_max( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false); } struct ggml_tensor * ggml_soft_max_inplace( struct ggml_context * ctx, struct ggml_tensor * a) { - return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true); + return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true); } struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias) { - return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false); + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } // ggml_soft_max_back @@ -5928,37 +5913,6 @@ struct ggml_tensor * ggml_rope_back( return result; } -// ggml_alibi - -struct ggml_tensor * ggml_alibi( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_head, - float bias_max) { - GGML_ASSERT(n_past >= 0); - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward - is_node = true; - } - - // TODO: when implement backward, fix this: - //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_view_tensor(ctx, a); - - int32_t op_params[3] = { n_past, n_head }; - memcpy(op_params + 2, &bias_max, sizeof(float)); - ggml_set_op_params(result, op_params, sizeof(op_params)); - - result->op = GGML_OP_ALIBI; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - - return result; -} - // ggml_clamp struct ggml_tensor * ggml_clamp( @@ -6486,9 +6440,11 @@ struct ggml_tensor * ggml_flash_attn_ext( struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor * mask, - float scale) { + float scale, + float max_bias) { GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) + if (mask) { GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(mask->ne[2] == 1); @@ -6498,6 +6454,10 @@ struct ggml_tensor * ggml_flash_attn_ext( //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); } + if (max_bias > 0.0f) { + GGML_ASSERT(mask); + } + bool is_node = false; if (q->grad || k->grad || v->grad) { @@ -6508,7 +6468,7 @@ struct ggml_tensor * ggml_flash_attn_ext( int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - float params[] = { scale }; + float params[] = { scale, max_bias }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_FLASH_ATTN_EXT; @@ -6528,7 +6488,7 @@ void ggml_flash_attn_ext_set_prec( const int32_t prec_i32 = (int32_t) prec; - ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos + ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second } // ggml_flash_ff @@ -13333,7 +13293,6 @@ static void ggml_compute_forward_soft_max_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src2 = dst->src[2]; assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); @@ -13359,8 +13318,8 @@ static void ggml_compute_forward_soft_max_f32( // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 - const uint32_t n_head_kv = ne02; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv)); + const uint32_t n_head = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); @@ -13377,13 +13336,13 @@ static void ggml_compute_forward_soft_max_f32( float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; - // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching - ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data; - float * pos_f32 = src2 ? (float *) src2->data : src0->data; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); for (int i1 = ir0; i1 < ir1; i1++) { + // ALiBi + const uint32_t h = (i1/ne01)%ne02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); @@ -13396,27 +13355,11 @@ static void ggml_compute_forward_soft_max_f32( if (mp_f32) { if (use_f16) { for (int i = 0; i < nc; ++i) { - wp[i] += GGML_FP16_TO_FP32(mp_f16[i]); + wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); } } else { for (int i = 0; i < nc; ++i) { - wp[i] += mp_f32[i]; - } - } - } - - // ALiBi bias - if (max_bias > 0.0f) { - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); - - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*pos_f32[i]; + wp[i] += slope*mp_f32[i]; } } } @@ -13578,178 +13521,6 @@ static void ggml_compute_forward_soft_max_back( } } -// ggml_compute_forward_alibi - -static void ggml_compute_forward_alibi_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int64_t ne1 = src0->ne[1]; // seq_len_without_past - const int64_t ne2 = src0->ne[2]; // n_head -> this is k - //const int64_t ne3 = src0->ne[3]; // 1 -> bsz - - const int64_t n = ggml_nrows(src0); - const int64_t ne2_ne3 = n/ne1; // ne2*ne3 - - const size_t nb0 = src0->nb[0]; - const size_t nb1 = src0->nb[1]; - const size_t nb2 = src0->nb[2]; - //const int nb3 = src0->nb[3]; - - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(n_head == ne2); - - // add alibi to src0 (KQ_scaled) - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - for (int64_t k = 0; k < ne2_ne3; k++) { - // TODO: k*nb2 or k*nb3 - float m_k; - - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - for (int64_t i = 0; i < ne0; i++) { - for (int64_t j = 0; j < ne1; j++) { - float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); - float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); - pdst[0] = i * m_k + src[0]; - } - } - } -} - -static void ggml_compute_forward_alibi_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_head = ((int32_t *) dst->op_params)[1]; - float max_bias; - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int ne1 = src0->ne[1]; // seq_len_without_past - const int ne2 = src0->ne[2]; // n_head -> this is k - //const int ne3 = src0->ne[3]; // 1 -> bsz - - const int n = ggml_nrows(src0); - const int ne2_ne3 = n/ne1; // ne2*ne3 - - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; - //const int nb3 = src0->nb[3]; - - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; - GGML_ASSERT(n_head == ne2); - - // add alibi to src0 (KQ_scaled) - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - - for (int k = 0; k < ne2_ne3; k++) { - // TODO: k*nb2 or k*nb3 - float m_k; - - if (k < n_heads_log2_floor) { - m_k = powf(m0, k + 1); - } else { - m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); - } - - for (int i = 0; i < ne0; i++) { - for (int j = 0; j < ne1; j++) { - ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); - float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); - - // we return F32 - pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]); - } - } - } -} - -static void ggml_compute_forward_alibi( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_alibi_f16(params, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_alibi_f32(params, dst); - } break; - case GGML_TYPE_BF16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q8_K: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_I64: - case GGML_TYPE_F64: - case GGML_TYPE_COUNT: - { - GGML_ASSERT(false); - } break; - } -} - // ggml_compute_forward_clamp static void ggml_compute_forward_clamp_f32( @@ -15763,8 +15534,17 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - float scale = 1.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { @@ -15773,6 +15553,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + const uint32_t h = iq2; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + float S = 0.0f; float M = -INFINITY; @@ -15796,7 +15579,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; } @@ -15867,7 +15650,7 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (dst->op_params[1]) { + switch (dst->op_params[2]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { @@ -17630,10 +17413,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rope_back(params, tensor); } break; - case GGML_OP_ALIBI: - { - ggml_compute_forward_alibi(params, tensor); - } break; case GGML_OP_CLAMP: { ggml_compute_forward_clamp(params, tensor); @@ -18652,10 +18431,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; - case GGML_OP_ALIBI: - { - GGML_ASSERT(false); // TODO: not implemented - } break; case GGML_OP_CLAMP: { GGML_ASSERT(false); // TODO: not implemented @@ -19428,10 +19203,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ { n_tasks = n_threads; } break; - case GGML_OP_ALIBI: - { - n_tasks = 1; //TODO - } break; case GGML_OP_CLAMP: { n_tasks = 1; //TODO diff --git a/ggml.h b/ggml.h index fe6053822..76c332831 100644 --- a/ggml.h +++ b/ggml.h @@ -468,7 +468,6 @@ extern "C" { GGML_OP_SOFT_MAX_BACK, GGML_OP_ROPE, GGML_OP_ROPE_BACK, - GGML_OP_ALIBI, GGML_OP_CLAMP, GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, @@ -1428,15 +1427,13 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // fused soft_max(a*scale + mask + pos[i]*(ALiBi slope)) + // fused soft_max(a*scale + mask*(ALiBi slope)) // mask is optional - // pos is required when max_bias > 0.0f // max_bias = 0.0f for no ALiBi GGML_API struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * mask, - struct ggml_tensor * pos, float scale, float max_bias); @@ -1538,16 +1535,6 @@ extern "C" { float xpos_base, bool xpos_down); - // alibi position embedding - // in-place, returns view(a) - GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_past, - int n_head, - float bias_max), - "use ggml_soft_max_ext instead (will be removed in Mar 2024)"); - // clamp // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_clamp( @@ -1744,7 +1731,8 @@ extern "C" { struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor * mask, - float scale); + float scale, + float max_bias); GGML_API void ggml_flash_attn_ext_set_prec( struct ggml_tensor * a, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e5750d419..990fe63c2 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -137,6 +137,7 @@ class TensorNameMap: "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert "transformer.h.{bid}.attn.k_proj", # gpt-j + "transformer.h.{bid}.attn.k", # refact "model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.{bid}.attention.wk", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok @@ -148,6 +149,7 @@ class TensorNameMap: "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert "transformer.h.{bid}.attn.v_proj", # gpt-j + "transformer.h.{bid}.attn.v", # refact "model.layers.layers.{bid}.self_attn.v_proj", # plamo "model.layers.{bid}.attention.wv", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok @@ -229,6 +231,7 @@ class TensorNameMap: "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert "transformer.h.{bid}.mlp.fc_in", # gpt-j + "transformer.h.{bid}.mlp.linear_3", # refact "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon "transformer.h.{bid}.mlp.w1", # qwen @@ -266,6 +269,7 @@ class TensorNameMap: "model.layers.layers.{bid}.mlp.gate_proj", # plamo "model.layers.{bid}.feed_forward.w1", # internlm2 "encoder.layers.{bid}.mlp.fc12", # nomic-bert + "transformer.h.{bid}.mlp.linear_1", # refact ), MODEL_TENSOR.FFN_GATE_EXP: ( diff --git a/llama.cpp b/llama.cpp index 2f1123d4e..dede68cb5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1845,7 +1845,7 @@ struct llama_hparams { float f_logit_scale = 0.0f; bool causal_attn = true; - bool use_alibi = false; // currently, we need KQ_pos data for ALiBi-based models + bool use_alibi = false; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -2317,7 +2317,6 @@ struct llama_context { struct ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_out_ids; // I32 [n_outputs] struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] - struct ggml_tensor * inp_KQ_pos; // F32 [n_kv] struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] @@ -6500,7 +6499,6 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * wo_b, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, - struct ggml_tensor * kq_pos, int32_t n_tokens, int32_t n_kv, float kq_scale, @@ -6530,10 +6528,6 @@ static struct ggml_tensor * llm_build_kqv( GGML_UNUSED(model); GGML_UNUSED(n_ctx); - // note: if this assert triggers, then some check has failed earlier - // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention - GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); - // split cached v into n_head heads (not transposed) struct ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il], @@ -6543,7 +6537,7 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); @@ -6574,28 +6568,8 @@ static struct ggml_tensor * llm_build_kqv( kq = ggml_scale(ctx, kq, 30); } -#if defined(GGML_USE_KOMPUTE) -#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute") -#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488") - if (hparams.use_alibi) { - kq = ggml_scale(ctx, kq, kq_scale); - cb(kq, "kq_scaled", il); - - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias); - cb(kq, "kq_scaled_alibi", il); - - kq = ggml_add(ctx, kq, kq_mask); - cb(kq, "kq_masked", il); - - kq = ggml_soft_max(ctx, kq); - cb(kq, "kq_soft_max", il); - } else -#endif - { - kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias); - cb(kq, "kq_soft_max_ext", il); - } + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cb(kq, "kq_soft_max_ext", il); GGML_ASSERT(kv.size == n_ctx); @@ -6645,7 +6619,6 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * v_cur, struct ggml_tensor * q_cur, struct ggml_tensor * kq_mask, - struct ggml_tensor * kq_pos, int32_t n_tokens, int32_t kv_head, int32_t n_kv, @@ -6664,7 +6637,7 @@ static struct ggml_tensor * llm_build_kv( struct ggml_tensor * cur; cur = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b, - q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il); + q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il); cb(cur, "kqv_out", il); return cur; @@ -6771,18 +6744,17 @@ struct llm_build_context { ctx0 = ggml_init(params); - lctx.inp_tokens = nullptr; - lctx.inp_embd = nullptr; - lctx.inp_pos = nullptr; + lctx.inp_tokens = nullptr; + lctx.inp_embd = nullptr; + lctx.inp_pos = nullptr; lctx.inp_out_ids = nullptr; lctx.inp_KQ_mask = nullptr; - lctx.inp_KQ_pos = nullptr; lctx.inp_K_shift = nullptr; - lctx.inp_mean = nullptr; - lctx.inp_cls = nullptr; - lctx.inp_s_copy = nullptr; - lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; + lctx.inp_mean = nullptr; + lctx.inp_cls = nullptr; + lctx.inp_s_copy = nullptr; + lctx.inp_s_mask = nullptr; + lctx.inp_s_seq = nullptr; } void free() { @@ -6932,19 +6904,6 @@ struct llm_build_context { return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask; } - struct ggml_tensor * build_inp_KQ_pos(bool causal = true) { - if (causal) { - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); - } else { - // TODO: this will be needed for ALiBi-based BERT models - // https://github.com/ggerganov/llama.cpp/pull/6826 - lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens); - } - cb(lctx.inp_KQ_pos, "KQ_pos", -1); - ggml_set_input(lctx.inp_KQ_pos); - return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos; - } - struct ggml_tensor * build_inp_mean() { lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); cb(lctx.inp_mean, "inp_mean", -1); @@ -7050,7 +7009,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7143,9 +7102,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -7190,7 +7146,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7260,9 +7216,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -7297,7 +7250,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7417,7 +7370,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7542,7 +7495,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -7694,7 +7647,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -7806,7 +7759,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8010,7 +7963,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Q, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8076,9 +8029,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -8106,7 +8056,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8246,7 +8196,7 @@ struct llm_build_context { struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); cb(kq, "kq", il); - kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); + kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias); cb(kq, "kq_soft_max_ext", il); struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens))); @@ -8363,9 +8313,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, @@ -8399,7 +8346,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8464,9 +8411,6 @@ struct llm_build_context { // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - // positions of the tokens in the KV cache - struct ggml_tensor * KQ_pos = build_inp_KQ_pos(); - if (model.pos_embd) { // inp_pos - contains the positions struct ggml_tensor * inp_pos = build_inp_pos(); @@ -8530,13 +8474,13 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } else { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } } @@ -8680,7 +8624,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8798,7 +8742,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -8911,7 +8855,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9025,7 +8969,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9180,7 +9124,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9297,7 +9241,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -9410,7 +9354,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } struct ggml_tensor * sa_out = cur; @@ -9513,7 +9457,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9620,7 +9564,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9736,7 +9680,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9853,7 +9797,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -9983,7 +9927,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10104,7 +10048,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, NULL, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -10223,7 +10167,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10513,7 +10457,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -10644,7 +10588,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, nullptr, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { @@ -11032,11 +10976,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { f = -INFINITY; } else { - f = 0.0f; + if (hparams.use_alibi) { + f = -fabs(lctx.kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } } data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } } } else { // when using kv cache, the mask needs to match the kv cache size @@ -11055,7 +11009,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float f = -INFINITY; for (int s = 0; s < batch.n_seq_id[i]; ++s) { if (batch.seq_id[i][s] == seq_id) { - f = 0.0f; + if (hparams.use_alibi) { + f = -fabs(batch.pos[i] - batch.pos[j]); + } else { + f = 0.0f; + } break; } } @@ -11071,21 +11029,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch - // this allows to process multiple sequences in parallel with ALiBi-based models - if (hparams.use_alibi) { - const int64_t n_kv = kv_self.n; - - GGML_ASSERT(lctx.inp_KQ_pos); - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer)); - - float * data = (float *) lctx.inp_KQ_pos->data; - - for (int i = 0; i < n_kv; ++i) { - data[i] = float(lctx.kv_self.cells[i].pos); - } - } - if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = batch.n_tokens; @@ -15509,11 +15452,6 @@ struct llama_context * llama_new_context_with_model( } } - if (cparams.flash_attn && hparams.use_alibi) { - LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__); - cparams.flash_attn = false; - } - if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) { LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__); cparams.flash_attn = false; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0d66de5d9..731788b95 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1111,11 +1111,7 @@ struct test_soft_max : public test_case { if (this->mask) { mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]); } - ggml_tensor * pos = nullptr; - if (max_bias > 0.0f) { - pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); - } - ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias); + ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias); return out; } }; @@ -1490,23 +1486,25 @@ struct test_flash_attn_ext : public test_case { const int64_t kv; // kv size const int64_t nb; // batch size + const float max_bias; // ALiBi + std::string vars() override { - return VARS_TO_STR4(hs, nh, kv, nb); + return VARS_TO_STR5(hs, nh, kv, nb, max_bias); } double max_nmse_err() override { return 5e-4; } - test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) - : hs(hs), nh(nh), kv(kv), nb(nb) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f) + : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); - ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias); return out; } }; @@ -1611,7 +1609,7 @@ public: struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f); + kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f); // split cached v into n_head heads struct ggml_tensor * v = @@ -2128,6 +2126,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op #endif for (bool mask : {false, true}) { for (float max_bias : {0.0f, 8.0f}) { + if (!mask && max_bias > 0.0f) continue; for (float scale : {1.0f, 0.1f}) { for (int64_t ne0 : {16, 1024}) { for (int64_t ne1 : {16, 1024}) { @@ -2141,7 +2140,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f)); for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { @@ -2180,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op #else for (int hs : { 64, 80, 128, 256, }) { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - for (int nh : { 32, }) { - for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, }) { - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + for (float max_bias : {0.0f, 8.0f}) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, }) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias)); + } } } }