further addressed comments

This commit is contained in:
Alan Gray 2024-04-24 06:31:08 -07:00
parent d403b180a6
commit 408759687f

View File

@@ -2460,7 +2460,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
int k=0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
// Identify if the graph needs updated for this token due to the number of elements changing
// Identify if the graph needs to be updated for this token due to the number of elements changing
// (identified by inspecting soft max op parameters)
if(node->op == GGML_OP_SOFT_MAX) {
if(node->src[1]->ne[1] > 1){
@@ -2489,10 +2489,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
#else
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
#endif
#endif // USE_CUDA_GRAPH
// Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph.
// With use of CUDA graphs, the execution will be performed by the graph launch.
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch.
if(!use_cuda_graph || cuda_graph_update_required) {
//temporarily avoid indenting here to make code review easier
for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -2519,7 +2519,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
}
}
#ifdef USE_CUDA_GRAPH
#ifdef USE_CUDA_GRAPH
if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph));
}
@@ -2541,7 +2541,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
// Subsequent call with non-null argument gets nodes
CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, cuda_graph.nodes, &cuda_graph.num_nodes));
// Loop over nodes, and extract kernel parameters fro each node
// Loop over nodes, and extract kernel parameters from each node
for(size_t i=0; i<cuda_graph.num_nodes; i++) {
cudaGraphNodeType node_type;
CUDA_CHECK(cudaGraphNodeGetType(cuda_graph.nodes[i], &node_type));
@@ -2588,7 +2588,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
CUDA_CHECK(cudaGraphLaunch(cuda_graph.instance, cuda_ctx->stream()));
}
cuda_graph.count++;
#endif
#endif // USE_CUDA_GRAPH
return GGML_STATUS_SUCCESS;
}