mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
Tidied to now only use CUDA runtime (not mixed with driver calls)
This commit is contained in:
parent
c8dd0e7c1c
commit
800f4fe48e
22
ggml-cuda.cu
22
ggml-cuda.cu
@ -2418,8 +2418,7 @@ struct ggml_cudaGraph {
|
|||||||
size_t numNodes = 0;
|
size_t numNodes = 0;
|
||||||
int softmax_ne0 = 0;
|
int softmax_ne0 = 0;
|
||||||
cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
|
cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
|
||||||
CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH];
|
cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
|
||||||
cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];
|
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -2523,12 +2522,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||||||
|
|
||||||
// Loop over nodes, and extract kernel parameters fro each node
|
// Loop over nodes, and extract kernel parameters fro each node
|
||||||
for(size_t i=0; i<cudaGraph.numNodes; i++) {
|
for(size_t i=0; i<cudaGraph.numNodes; i++) {
|
||||||
CUgraphNodeType nodeType;
|
cudaGraphNodeType nodeType;
|
||||||
CU_CHECK(cuGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
|
CUDA_CHECK(cudaGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
|
||||||
if (nodeType == CU_GRAPH_NODE_TYPE_KERNEL) {
|
if (nodeType == cudaGraphNodeTypeKernel) {
|
||||||
// We currently get a set of params using both driver and runtime, to work around an issue (see below)
|
auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.params[i]); // Get params using runtime
|
||||||
CU_CHECK(cuGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i])); // Get params using driver
|
|
||||||
auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsRuntime[i]); // Get params using runtime
|
|
||||||
if(statRT == cudaErrorInvalidDeviceFunction) {
|
if(statRT == cudaErrorInvalidDeviceFunction) {
|
||||||
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
||||||
// We don't need to update blas nodes, so clear error and move on.
|
// We don't need to update blas nodes, so clear error and move on.
|
||||||
@ -2539,16 +2536,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update copy kernel param (required every token)
|
// Update copy kernel param (required every token)
|
||||||
// Currently uses runtime copy of params to identify copy function node,
|
|
||||||
// and driver copy of params to perform the update
|
|
||||||
// TO DO work out how to do it only using runtime copy.
|
|
||||||
if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
|
if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
|
||||||
int k=0;
|
int k=0;
|
||||||
for(size_t i=0; i<cudaGraph.numNodes; i++) {
|
for(size_t i=0; i<cudaGraph.numNodes; i++) {
|
||||||
if(cudaGraph.paramsRuntime[i].func == ggmlCudaCpyFn) {
|
if(cudaGraph.params[i].func == ggmlCudaCpyFn) {
|
||||||
char** updatedKernelArgPointer = updatedKernelArg[k++];
|
char** updatedKernelArgPointer = updatedKernelArg[k++];
|
||||||
cudaGraph.paramsDriver[i].kernelParams[1] = updatedKernelArgPointer;
|
cudaGraph.params[i].kernelParams[1] = updatedKernelArgPointer;
|
||||||
CU_CHECK(cuGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i]));
|
CUDA_CHECK(cudaGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.params[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user