From 8564f79036c724615f1677138d5e6ed5f61075ae Mon Sep 17 00:00:00 2001 From: Aaron Miller Date: Wed, 4 Oct 2023 21:03:27 -0700 Subject: [PATCH] falcon h2d + reenable vulkan --- llama.cpp | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index e79251194..858494244 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3210,6 +3210,9 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * cur; struct ggml_tensor * inpL; +#if defined(GGML_USE_KOMPUTE) + struct ggml_tensor * toDeviceTensor = nullptr; +#endif if (tokens) { struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); @@ -3219,7 +3222,9 @@ static struct ggml_cgraph * llm_build_falcon( memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); - +#if defined(GGML_USE_KOMPUTE) + toDeviceTensor = inp_tokens; +#endif inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); } else { #ifdef GGML_USE_MPI @@ -3232,6 +3237,9 @@ static struct ggml_cgraph * llm_build_falcon( if (!ggml_allocr_is_measure(lctx.alloc)) { memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); } +#if defined(GGML_USE_KOMPUTE) + toDeviceTensor = inpL; +#endif } const int i_gpu_start = n_layer - n_gpu_layers; @@ -3463,6 +3471,16 @@ static struct ggml_cgraph * llm_build_falcon( ggml_build_forward_expand(gf, cur); ggml_free(ctx0); + +#if defined(GGML_USE_KOMPUTE) + if (lctx.ctx_kompute) { + if (!ggml_vk_has_h2d_all(lctx.ctx_kompute)) { + ggml_vk_h2d_all(lctx.ctx_kompute); + } else { + ggml_vk_h2d_tensor(lctx.ctx_kompute, toDeviceTensor); + } + } +#endif return gf; } @@ -6494,7 +6512,7 @@ struct llama_context * llama_new_context_with_model( #elif defined(GGML_USE_KOMPUTE) // TODO(cebtenzzre): we need to check the type of each tensor because Q8_0 is not currently supported if (ggml_vk_has_device() && params.n_gpu_layers > 0 - && model->arch == LLM_ARCH_LLAMA + && (model->arch == LLM_ARCH_LLAMA || model->arch == LLM_ARCH_FALCON) && (model->ftype == LLAMA_FTYPE_ALL_F32 || model->ftype == LLAMA_FTYPE_MOSTLY_F16 || model->ftype == LLAMA_FTYPE_MOSTLY_Q4_0