mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
falcon : fix CUDA inference by making K and Q contiguous (#2830)
* falcon : fix CUDA inference by making K and Q contiguous ggml-ci * cuda : add assert to guard from non-cont ropes
This commit is contained in:
parent
da7455d046
commit
eaa13a48ff
@ -6337,9 +6337,11 @@ void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml
|
||||
|
||||
void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
|
||||
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const bool is_glm = mode & 4;
|
||||
|
||||
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
|
||||
}
|
||||
|
||||
|
10
llama.cpp
10
llama.cpp
@ -2642,18 +2642,20 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||
|
||||
const size_t wsize = ggml_type_size(cur->type);
|
||||
|
||||
struct ggml_tensor * tmpq = ggml_view_3d(
|
||||
// TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for
|
||||
// non-contiguous views is added for the rope operator
|
||||
struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head, N,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
0);
|
||||
0));
|
||||
offload_func_kq(tmpq);
|
||||
|
||||
struct ggml_tensor * tmpk = ggml_view_3d(
|
||||
struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, N,
|
||||
wsize * n_embd_head,
|
||||
wsize * n_embd_head * (n_head + 2 * n_head_kv),
|
||||
wsize * n_embd_head * n_head);
|
||||
wsize * n_embd_head * n_head));
|
||||
offload_func_kq(tmpk);
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_view_3d(
|
||||
|
Loading…
Reference in New Issue
Block a user