From 41f477879fd5ccc31211634292e8e293ec700e85 Mon Sep 17 00:00:00 2001 From: agray3 Date: Sat, 21 Sep 2024 01:41:07 +0100 Subject: [PATCH] Update CUDA graph on scale change plus clear nodes/params (#9550) * Avoid using saved CUDA graph if scale changes and reset nodes/params on update Fixes https://github.com/ggerganov/llama.cpp/issues/9451 * clear before resize --- ggml/src/ggml-cuda.cu | 9 +++++++++ ggml/src/ggml-cuda/common.cuh | 1 + 2 files changed, 10 insertions(+) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index b0843dc62..895ba4794 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2478,6 +2478,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p for (int i = 0; i < GGML_MAX_SRC; i++) { graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; } + memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); } static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { @@ -2509,6 +2510,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return false; } } + + if (node->op == GGML_OP_SCALE && + memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + return false; + } + return true; } @@ -2720,7 +2727,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // First call with null argument gets number of nodes in graph CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); // Subsequent call with non-null argument gets nodes + cuda_ctx->cuda_graph->nodes.clear(); cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); + cuda_ctx->cuda_graph->params.clear(); cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); if (cuda_ctx->cuda_graph->num_nodes > 0) { CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index eb39b6d23..85eb200f0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -569,6 +569,7 @@ struct ggml_graph_node_properties { int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; void * src_address[GGML_MAX_SRC]; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; }; struct ggml_cuda_graph {