diff --git a/llama.cpp b/llama.cpp index cb0a1227a..a196b428f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3486,7 +3486,7 @@ static struct ggml_cgraph * llm_build_falcon( ggml_build_forward_expand(gf, cur); ggml_free(ctx0); - + #if defined(GGML_USE_KOMPUTE) if (lctx.ctx_kompute) { if (!ggml_vk_has_h2d_all(lctx.ctx_kompute)) { @@ -3870,11 +3870,19 @@ static bool llama_eval_internal( ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); } #elif defined(GGML_USE_KOMPUTE) - if (lctx.ctx_kompute) { + if (lctx.ctx_kompute && N == 1) { ggml_vk_graph_compute(lctx.ctx_kompute, gf); ggml_vk_d2h_tensor(lctx.ctx_kompute, res); } else { + if (lctx.ctx_kompute) { + ggml_vk_d2h_tensor(lctx.ctx_kompute, kv_self.k); + ggml_vk_d2h_tensor(lctx.ctx_kompute, kv_self.v); + } ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); + if (lctx.ctx_kompute) { + ggml_vk_h2d_tensor(lctx.ctx_kompute, kv_self.k); + ggml_vk_h2d_tensor(lctx.ctx_kompute, kv_self.v); + } } #else ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);