This commit is contained in:
Jared Van Bortel 2023-11-14 15:54:26 -05:00
parent 9c4dfd06e8
commit 02c3309f6d

View File

@ -3506,6 +3506,10 @@ struct llm_build_context {
llama_buffer & buf_compute;
#if defined(GGML_USE_KOMPUTE)
ggml_kompute_context * ctx_kompute;
#endif
struct ggml_context * ctx0 = nullptr;
// TODO: consider making the entire interface noexcept
@ -3535,7 +3539,11 @@ struct llm_build_context {
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
do_rope_shift (worst_case || kv_self.has_shift),
cb (cb),
buf_compute (lctx.buf_compute) {
buf_compute (lctx.buf_compute)
#if defined(GGML_USE_KOMPUTE)
, ctx_kompute (lctx.ctx_kompute)
#endif
{
GGML_ASSERT(!!kv_self.ctx);
// all initializations should be done in init()
@ -3662,15 +3670,15 @@ struct llm_build_context {
ggml_build_forward_expand(gf, cur);
#if defined(GGML_USE_KOMPUTE)
if (lctx.ctx_kompute) {
if (!ggml_vk_has_h2d_all(lctx.ctx_kompute)) {
ggml_vk_h2d_all(lctx.ctx_kompute);
if (ctx_kompute) {
if (!ggml_vk_has_h2d_all(ctx_kompute)) {
ggml_vk_h2d_all(ctx_kompute);
} else {
ggml_vk_h2d_tensor(lctx.ctx_kompute, to_device_tensor);
ggml_vk_h2d_tensor(lctx.ctx_kompute, inp_pos);
ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_mask);
ggml_vk_h2d_tensor(ctx_kompute, to_device_tensor);
ggml_vk_h2d_tensor(ctx_kompute, inp_pos);
ggml_vk_h2d_tensor(ctx_kompute, KQ_mask);
if (K_shift) {
ggml_vk_h2d_tensor(lctx.ctx_kompute, K_shift);
ggml_vk_h2d_tensor(ctx_kompute, K_shift);
}
}
}
@ -3907,15 +3915,15 @@ struct llm_build_context {
ggml_build_forward_expand(gf, cur);
#if defined(GGML_USE_KOMPUTE)
if (lctx.ctx_kompute) {
if (!ggml_vk_has_h2d_all(lctx.ctx_kompute)) {
ggml_vk_h2d_all(lctx.ctx_kompute);
if (ctx_kompute) {
if (!ggml_vk_has_h2d_all(ctx_kompute)) {
ggml_vk_h2d_all(ctx_kompute);
} else {
ggml_vk_h2d_tensor(lctx.ctx_kompute, to_device_tensor);
ggml_vk_h2d_tensor(lctx.ctx_kompute, inp_pos);
ggml_vk_h2d_tensor(lctx.ctx_kompute, KQ_mask);
ggml_vk_h2d_tensor(ctx_kompute, to_device_tensor);
ggml_vk_h2d_tensor(ctx_kompute, inp_pos);
ggml_vk_h2d_tensor(ctx_kompute, KQ_mask);
if (K_shift) {
ggml_vk_h2d_tensor(lctx.ctx_kompute, K_shift);
ggml_vk_h2d_tensor(ctx_kompute, K_shift);
}
}
}