Mirror of https://github.com/ggerganov/llama.cpp.git
Fix issues raised in comments
commit c8dd0e7c1c
parent cec409aa98

ggml-cuda.cu (54 changed lines)
@@ -2405,23 +2405,33 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
     GGML_UNUSED(backend);
 }
 
+#if (CUDART_VERSION >= 12000)
+#define USE_CUDA_GRAPH
+#endif
+
+#ifdef USE_CUDA_GRAPH
+#define MAX_NODES_IN_CUDA_GRAPH 10000
 struct ggml_cudaGraph {
     int count=0;
     cudaGraph_t graph = nullptr;
     cudaGraphExec_t instance = nullptr;
     size_t numNodes = 0;
     int softmax_ne0 = 0;
+    cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
+    CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH];
+    cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];
 };
+#endif
 
 GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
     ggml_cuda_set_device(cuda_ctx->device);
 
+#ifdef USE_CUDA_GRAPH
     // Objects required for CUDA Graph
-    #define MAX_NODES_IN_CUDA_GRAPH 10000
-    static ggml_cudaGraph cudaGraph; //TO DO move this to a suitable persistant location (and avoid use of static memory)
-    bool useCudaGraph = (cudaGraph.count>=2); //avoid CUDA graphs on first 2 steps due to incompatible initialisations.
+    static ggml_cudaGraph cudaGraph;
+    bool useCudaGraph = (cudaGraph.count>=7); //avoid CUDA graphs on first few steps due to incompatible initialisations.
     char** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH];
     bool cudaGraphUpdateRequired = false;
     // pointer to CUDA cpy kernel, which is required to identify
@@ -2458,6 +2468,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal));
     }
 
+#else
+    bool useCudaGraph = false;
+    bool cudaGraphUpdateRequired = false;
+#endif
+
     // Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph.
     // With use of CUDA graphs, the execution will be performed by the graph launch.
     if(!useCudaGraph || cudaGraphUpdateRequired) {
@@ -2486,6 +2501,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
         }
     }
 
+#ifdef USE_CUDA_GRAPH
     if(useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture
         CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph));
     }
@@ -2498,29 +2514,29 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
 
         // Perform update to graph (if required for this token), and change copy parameter (required for every token)
 
-        cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
-        CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH];
-        cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];
-
         if(cudaGraphUpdateRequired) {
             // Extract nodes from graph
             if(cudaGraph.numNodes == 0) {
                 CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes));
             }
-            CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nodes, &cudaGraph.numNodes));
+            CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, cudaGraph.nodes, &cudaGraph.numNodes));
 
             // Loop over nodes, and extract kernel parameters fro each node
             for(size_t i=0; i<cudaGraph.numNodes; i++) {
+                CUgraphNodeType nodeType;
+                CU_CHECK(cuGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
+                if (nodeType == CU_GRAPH_NODE_TYPE_KERNEL) {
                 // We currently get a set of params using both driver and runtime, to work around an issue (see below)
-                CU_CHECK(cuGraphKernelNodeGetParams(nodes[i], &paramsDriver[i])); // Get params using driver
-                cudaError_t statRT = cudaGraphKernelNodeGetParams(nodes[i], &paramsRuntime[i]); // Get params using runtime
-                if(statRT == 98) {
+                CU_CHECK(cuGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i])); // Get params using driver
+                auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsRuntime[i]); // Get params using runtime
+                if(statRT == cudaErrorInvalidDeviceFunction) {
                     // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
                     // We don't need to update blas nodes, so clear error and move on.
                     cudaGetLastError();
                 }
+                }
             }
         }
 
         // Update copy kernel param (required every token)
         // Currently uses runtime copy of params to identify copy function node,
@@ -2529,22 +2545,30 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
         if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
             int k=0;
             for(size_t i=0; i<cudaGraph.numNodes; i++) {
-                if(paramsRuntime[i].func == ggmlCudaCpyFn) {
+                if(cudaGraph.paramsRuntime[i].func == ggmlCudaCpyFn) {
                     char** updatedKernelArgPointer = updatedKernelArg[k++];
-                    paramsDriver[i].kernelParams[1] = updatedKernelArgPointer;
-                    CU_CHECK(cuGraphKernelNodeSetParams(nodes[i], &paramsDriver[i]));
+                    cudaGraph.paramsDriver[i].kernelParams[1] = updatedKernelArgPointer;
+                    CU_CHECK(cuGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i]));
                 }
             }
         }
 
        // Update graph executable
        cudaGraphExecUpdateResultInfo resultInfo;
-       CUDA_CHECK(cudaGraphExecUpdate(cudaGraph.instance, cudaGraph.graph, &resultInfo));
+       auto stat = cudaGraphExecUpdate(cudaGraph.instance, cudaGraph.graph, &resultInfo);
+       if(stat == cudaErrorGraphExecUpdateFailure)
+       {
+           // The pre-existing graph exec cannot be updated due to violated constraints
+           // so instead clar error and re-instantiate
+           cudaGetLastError();
+           CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0));
+       }
 
         // Launch graph
         CUDA_CHECK(cudaGraphLaunch(cudaGraph.instance, cuda_ctx->stream()));
     }
     cudaGraph.count++;
+#endif
     return GGML_STATUS_SUCCESS;
 }
 
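For readers skimming this commit, the following is a minimal, self-contained sketch (not part of the commit, and not llama.cpp code) of the CUDA Graph life cycle the new code relies on: capture work from a stream into a graph, instantiate an executable graph once, and on later iterations attempt a cheap cudaGraphExecUpdate, falling back to re-instantiation when the update reports cudaErrorGraphExecUpdateFailure, which is the recovery path added in the last hunk. The kernel, sizes, and CHECK macro are illustrative placeholders, and it assumes CUDA 12 or newer for the cudaGraphExecUpdateResultInfo overload.

// Standalone sketch of capture -> instantiate/update -> launch (assumes CUDA 12+).
#include <cstdio>
#include <cuda_runtime.h>

#define CHECK(call) do { cudaError_t e = (call); if (e != cudaSuccess) { \
    fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(e), __FILE__, __LINE__); return 1; } } while (0)

__global__ void scale(float * x, float a, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= a;
}

int main() {
    const int n = 1 << 20;
    float * x; CHECK(cudaMalloc(&x, n * sizeof(float)));
    cudaStream_t stream; CHECK(cudaStreamCreate(&stream));

    cudaGraph_t graph = nullptr;
    cudaGraphExec_t instance = nullptr;

    for (int step = 0; step < 10; ++step) {
        // Re-capture the work into a graph each step (a real integration would
        // only re-capture when the graph topology may have changed).
        CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
        scale<<<(n + 255) / 256, 256, 0, stream>>>(x, 1.001f, n);
        CHECK(cudaStreamEndCapture(stream, &graph));

        if (instance == nullptr) {
            // First step: instantiate an executable graph.
            CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
        } else {
            // Later steps: try a cheap in-place update of the executable graph.
            cudaGraphExecUpdateResultInfo info;
            cudaError_t stat = cudaGraphExecUpdate(instance, graph, &info);
            if (stat == cudaErrorGraphExecUpdateFailure) {
                // The update could not be applied: clear the sticky error and
                // fall back to re-instantiating the executable graph.
                cudaGetLastError();
                CHECK(cudaGraphExecDestroy(instance));
                CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
            } else {
                CHECK(stat);
            }
        }

        CHECK(cudaGraphLaunch(instance, stream));
        CHECK(cudaStreamSynchronize(stream));
        CHECK(cudaGraphDestroy(graph));
    }

    CHECK(cudaGraphExecDestroy(instance));
    CHECK(cudaFree(x));
    CHECK(cudaStreamDestroy(stream));
    return 0;
}

Built with nvcc, each iteration reuses the executable graph instead of re-launching individual kernels, which is the same capture/update/launch pattern ggml_backend_cuda_graph_compute follows here under USE_CUDA_GRAPH.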