Tidied to now only use CUDA runtime (not mixed with driver calls)

This commit is contained in:
Alan Gray 2024-04-22 04:50:39 -07:00
parent c8dd0e7c1c
commit 800f4fe48e

View File

@ -2418,8 +2418,7 @@ struct ggml_cudaGraph {
size_t numNodes = 0; size_t numNodes = 0;
int softmax_ne0 = 0; int softmax_ne0 = 0;
cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH]; cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];
}; };
#endif #endif
@ -2523,12 +2522,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
// Loop over nodes, and extract kernel parameters fro each node // Loop over nodes, and extract kernel parameters fro each node
for(size_t i=0; i<cudaGraph.numNodes; i++) { for(size_t i=0; i<cudaGraph.numNodes; i++) {
CUgraphNodeType nodeType; cudaGraphNodeType nodeType;
CU_CHECK(cuGraphNodeGetType(cudaGraph.nodes[i], &nodeType)); CUDA_CHECK(cudaGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
if (nodeType == CU_GRAPH_NODE_TYPE_KERNEL) { if (nodeType == cudaGraphNodeTypeKernel) {
// We currently get a set of params using both driver and runtime, to work around an issue (see below) auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.params[i]); // Get params using runtime
CU_CHECK(cuGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i])); // Get params using driver
auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsRuntime[i]); // Get params using runtime
if(statRT == cudaErrorInvalidDeviceFunction) { if(statRT == cudaErrorInvalidDeviceFunction) {
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
// We don't need to update blas nodes, so clear error and move on. // We don't need to update blas nodes, so clear error and move on.
@ -2539,16 +2536,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
} }
// Update copy kernel param (required every token) // Update copy kernel param (required every token)
// Currently uses runtime copy of params to identify copy function node,
// and driver copy of params to perform the update
// TO DO work out how to do it only using runtime copy.
if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
int k=0; int k=0;
for(size_t i=0; i<cudaGraph.numNodes; i++) { for(size_t i=0; i<cudaGraph.numNodes; i++) {
if(cudaGraph.paramsRuntime[i].func == ggmlCudaCpyFn) { if(cudaGraph.params[i].func == ggmlCudaCpyFn) {
char** updatedKernelArgPointer = updatedKernelArg[k++]; char** updatedKernelArgPointer = updatedKernelArg[k++];
cudaGraph.paramsDriver[i].kernelParams[1] = updatedKernelArgPointer; cudaGraph.params[i].kernelParams[1] = updatedKernelArgPointer;
CU_CHECK(cuGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i])); CUDA_CHECK(cudaGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.params[i]));
} }
} }
} }