mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
Tidied to now only use CUDA runtime (not mixed with driver calls)
This commit is contained in:
parent
c8dd0e7c1c
commit
800f4fe48e
22
ggml-cuda.cu
22
ggml-cuda.cu
@ -2418,8 +2418,7 @@ struct ggml_cudaGraph {
|
||||
size_t numNodes = 0;
|
||||
int softmax_ne0 = 0;
|
||||
cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
|
||||
CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH];
|
||||
cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];
|
||||
cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -2523,12 +2522,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
||||
|
||||
// Loop over nodes, and extract kernel parameters fro each node
|
||||
for(size_t i=0; i<cudaGraph.numNodes; i++) {
|
||||
CUgraphNodeType nodeType;
|
||||
CU_CHECK(cuGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
|
||||
if (nodeType == CU_GRAPH_NODE_TYPE_KERNEL) {
|
||||
// We currently get a set of params using both driver and runtime, to work around an issue (see below)
|
||||
CU_CHECK(cuGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i])); // Get params using driver
|
||||
auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsRuntime[i]); // Get params using runtime
|
||||
cudaGraphNodeType nodeType;
|
||||
CUDA_CHECK(cudaGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
|
||||
if (nodeType == cudaGraphNodeTypeKernel) {
|
||||
auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.params[i]); // Get params using runtime
|
||||
if(statRT == cudaErrorInvalidDeviceFunction) {
|
||||
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
||||
// We don't need to update blas nodes, so clear error and move on.
|
||||
@ -2539,16 +2536,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
||||
}
|
||||
|
||||
// Update copy kernel param (required every token)
|
||||
// Currently uses runtime copy of params to identify copy function node,
|
||||
// and driver copy of params to perform the update
|
||||
// TO DO work out how to do it only using runtime copy.
|
||||
if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
|
||||
int k=0;
|
||||
for(size_t i=0; i<cudaGraph.numNodes; i++) {
|
||||
if(cudaGraph.paramsRuntime[i].func == ggmlCudaCpyFn) {
|
||||
if(cudaGraph.params[i].func == ggmlCudaCpyFn) {
|
||||
char** updatedKernelArgPointer = updatedKernelArg[k++];
|
||||
cudaGraph.paramsDriver[i].kernelParams[1] = updatedKernelArgPointer;
|
||||
CU_CHECK(cuGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i]));
|
||||
cudaGraph.params[i].kernelParams[1] = updatedKernelArgPointer;
|
||||
CUDA_CHECK(cudaGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.params[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user