Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2024-12-25 10:54:36 +00:00
Fix FlashAttention debug test, FP32 assert (#7684)
parent 2e666832e6
commit e141ce624a
@@ -278,14 +278,10 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
 
 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * KQV = dst;
     ggml_tensor * Q = dst->src[0];
     ggml_tensor * K = dst->src[1];
     ggml_tensor * V = dst->src[2];
 
-    const int32_t precision = KQV->op_params[2];
-    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);
 
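Note on this hunk: the f32 vector kernel is exactly the path taken when the caller requests full FP32 precision (or when fast FP16 arithmetic is unavailable), so asserting precision == GGML_PREC_DEFAULT inside it could fire on valid inputs; the fix drops the assert along with the now-unused KQV/precision locals. A minimal C++ sketch of that branch shape, with hypothetical names (flash_attn_dispatch, launch_vec_f16, launch_vec_f32 are illustrative, not the real ggml-cuda entry points):

    // Hypothetical sketch; only the branch structure mirrors the ggml-cuda dispatch.
    #include <cstdio>

    enum prec_sketch { PREC_DEFAULT, PREC_F32 };

    static void launch_vec_f16() { std::puts("f16 kernel"); } // placeholder
    static void launch_vec_f32() { std::puts("f32 kernel"); } // placeholder

    static void flash_attn_dispatch(prec_sketch precision, bool fast_fp16) {
        if (precision == PREC_DEFAULT && fast_fp16) {
            launch_vec_f16(); // half-precision accumulators
        } else {
            launch_vec_f32(); // f32 accumulators: reached whenever PREC_F32 is
                              // requested, so an assert on DEFAULT cannot hold here
        }
    }

    int main() {
        flash_attn_dispatch(PREC_F32, /*fast_fp16=*/true); // f32 path by design
    }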
@@ -1584,9 +1584,11 @@ struct test_flash_attn_ext : public test_case {
         : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
-        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs, kv, nh, 1);
-        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs, kv, nh, 1);
+        const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
+
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
         ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
         return out;
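Note on this hunk: quantized K/V types store elements in fixed-size blocks, so a row of head size hs must span a whole number of blocks; the test now pads hs up to a multiple of ggml_blck_size(type_KV) (1 for F16/F32, 32 for e.g. Q4_0 or Q8_0) before creating the tensors. A minimal sketch of the padding arithmetic, assuming GGML_PAD's round-up-to-a-multiple behavior (PAD_SKETCH is a stand-in; the values are illustrative):

    #include <cstdint>
    #include <cstdio>

    // Round x up to the next multiple of n, as GGML_PAD does.
    #define PAD_SKETCH(x, n) (((x) + (n) - 1) / (n) * (n))

    int main() {
        const int64_t hs   = 40; // hypothetical head size, not a multiple of 32
        const int64_t blck = 32; // e.g. ggml_blck_size(GGML_TYPE_Q4_0) == 32
        const int64_t hs_padded = PAD_SKETCH(hs, blck);
        std::printf("hs=%lld -> hs_padded=%lld\n",
                    (long long) hs, (long long) hs_padded); // prints 40 -> 64
    }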