DRAFT: Introduction of CUDA Graphs to llama.cpp
This commit is contained in:
parent 637e9a86c2
commit cec409aa98

ggml-cuda.cu: 112 changed lines
@@ -2405,11 +2405,63 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
    GGML_UNUSED(backend);
}

struct ggml_cudaGraph {
    int count = 0;
    cudaGraph_t graph = nullptr;
    cudaGraphExec_t instance = nullptr;
    size_t numNodes = 0;
    int softmax_ne0 = 0;
};

GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;

    ggml_cuda_set_device(cuda_ctx->device);

    // Objects required for CUDA Graph
#define MAX_NODES_IN_CUDA_GRAPH 10000
    static ggml_cudaGraph cudaGraph; // TODO: move this to a suitable persistent location (and avoid use of static memory)
    bool useCudaGraph = (cudaGraph.count >= 2); // avoid CUDA graphs on first 2 steps due to incompatible initialisations.
    char ** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH];
    bool cudaGraphUpdateRequired = false;
    // pointer to CUDA cpy kernel, which is required to identify
    // kernel parameters which need to be updated in the graph for each token
    void * ggmlCudaCpyFn = nullptr;
    if (useCudaGraph) {

        if (cudaGraph.instance == nullptr) cudaGraphUpdateRequired = true;

        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
        int k = 0;
        for (int i = 0; i < cgraph->n_nodes; i++) {
            ggml_tensor * node = cgraph->nodes[i];
            // Identify if the graph needs to be updated for this token due to the number of elements changing
            // (identified by inspecting soft max op parameters)
            if (node->op == GGML_OP_SOFT_MAX) {
                if (node->src[0]->ne[0] != cudaGraph.softmax_ne0) {
                    cudaGraphUpdateRequired = true;
                    cudaGraph.softmax_ne0 = node->src[0]->ne[0];
                }
            }
            if (node->op == GGML_OP_CPY) {
                // store the copy op parameter which changes with each token.
                updatedKernelArg[k++] = (char **) &(node->src[1]->data);
                if (ggmlCudaCpyFn == nullptr) {
                    // store a pointer to the copy op CUDA kernel to identify it later
                    ggmlCudaCpyFn = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
                }
            }
        }
    }

    if (useCudaGraph && cudaGraphUpdateRequired) { // Start CUDA graph capture
        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal));
    }

    // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
    // With use of CUDA graphs, the execution will be performed by the graph launch.
    if (!useCudaGraph || cudaGraphUpdateRequired) {
        // temporarily avoid indenting here to make code review easier
        for (int i = 0; i < cgraph->n_nodes; i++) {
            ggml_tensor * node = cgraph->nodes[i];

@@ -2432,7 +2484,67 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
            }
            GGML_ASSERT(ok);
        }
    }

    if (useCudaGraph && cudaGraphUpdateRequired) { // End CUDA graph capture
        CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph));
    }
    if (useCudaGraph) {

        if (cudaGraph.instance == nullptr) { // Create executable graph from captured graph.
            CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0));
        }

        // Perform update to graph (if required for this token), and change copy parameter (required for every token)

        cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
        CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH];
        cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];

        if (cudaGraphUpdateRequired) {
            // Extract nodes from graph
            if (cudaGraph.numNodes == 0) {
                CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes));
            }
            CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nodes, &cudaGraph.numNodes));

            // Loop over nodes, and extract kernel parameters for each node
            for (size_t i = 0; i < cudaGraph.numNodes; i++) {
                // We currently get a set of params using both driver and runtime, to work around an issue (see below)
                CU_CHECK(cuGraphKernelNodeGetParams(nodes[i], &paramsDriver[i])); // Get params using driver
                cudaError_t statRT = cudaGraphKernelNodeGetParams(nodes[i], &paramsRuntime[i]); // Get params using runtime
                if (statRT == 98) {
                    // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
                    // We don't need to update blas nodes, so clear error and move on.
                    cudaGetLastError();
                }
            }
        }

        // Update copy kernel param (required every token)
        // Currently uses runtime copy of params to identify copy function node,
        // and driver copy of params to perform the update.
        // TODO: work out how to do it only using runtime copy.
        if (!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
            int k = 0;
            for (size_t i = 0; i < cudaGraph.numNodes; i++) {
                if (paramsRuntime[i].func == ggmlCudaCpyFn) {
                    char ** updatedKernelArgPointer = updatedKernelArg[k++];
                    paramsDriver[i].kernelParams[1] = updatedKernelArgPointer;
                    CU_CHECK(cuGraphKernelNodeSetParams(nodes[i], &paramsDriver[i]));
                }
            }
        }

        // Update graph executable
        cudaGraphExecUpdateResultInfo resultInfo;
        CUDA_CHECK(cudaGraphExecUpdate(cudaGraph.instance, cudaGraph.graph, &resultInfo));

        // Launch graph
        CUDA_CHECK(cudaGraphLaunch(cudaGraph.instance, cuda_ctx->stream()));
    }
    cudaGraph.count++;
    return GGML_STATUS_SUCCESS;
}

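For orientation, below is a minimal, self-contained sketch (not part of this commit) of the stream-capture lifecycle the code above follows: capture work into a cudaGraph_t, instantiate it once, refresh the executable with cudaGraphExecUpdate on later steps rather than re-instantiating, then launch. It assumes a CUDA 12.x toolkit; the scale kernel, buffer size, step count, and the simplified CUDA_CHECK macro are illustrative placeholders, not names from this patch.

// Hedged sketch of the capture / instantiate / update / launch cycle.
#include <cuda_runtime.h>
#include <cstdio>

#define CUDA_CHECK(call) do { cudaError_t e_ = (call); \
    if (e_ != cudaSuccess) { std::printf("CUDA error %d at %s:%d\n", e_, __FILE__, __LINE__); } } while (0)

__global__ void scale(float * x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
}

int main() {
    const int n = 1024;
    float * d = nullptr;
    CUDA_CHECK(cudaMalloc(&d, n * sizeof(float)));

    cudaStream_t stream;
    CUDA_CHECK(cudaStreamCreate(&stream));

    cudaGraph_t graph = nullptr;
    cudaGraphExec_t instance = nullptr;

    for (int step = 0; step < 4; ++step) {
        // Capture the work for this step into a graph; kernels are enqueued, not executed.
        CUDA_CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
        scale<<<(n + 255) / 256, 256, 0, stream>>>(d, 2.0f, n);
        CUDA_CHECK(cudaStreamEndCapture(stream, &graph));

        if (instance == nullptr) {
            // First step: build the executable graph (CUDA 12 signature).
            CUDA_CHECK(cudaGraphInstantiate(&instance, graph, 0));
        } else {
            // Later steps: update the existing executable in place, which is cheaper
            // than re-instantiating when the topology is unchanged.
            cudaGraphExecUpdateResultInfo info;
            CUDA_CHECK(cudaGraphExecUpdate(instance, graph, &info));
        }
        CUDA_CHECK(cudaGraphLaunch(instance, stream));
        CUDA_CHECK(cudaStreamSynchronize(stream));
        CUDA_CHECK(cudaGraphDestroy(graph)); // the executable keeps its own copy
    }

    CUDA_CHECK(cudaGraphExecDestroy(instance));
    CUDA_CHECK(cudaStreamDestroy(stream));
    CUDA_CHECK(cudaFree(d));
    return 0;
}

Updating the executable rather than re-instantiating it is the common fast path between tokens, since the graph topology normally does not change.
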
@@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    ggml_cuda_cpy(ctx, src0, dst);
}

void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
    } else {
        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ASSERT(false);
    }
}

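ggml_cuda_cpy_fn exposes the host-side address of the copy kernel so that, after capture, the corresponding kernel nodes can be found by comparing against cudaKernelNodeParams::func and their destination argument patched for each token. Below is a hedged sketch of that matching idea using only the runtime API (the commit itself mixes driver and runtime calls to work around a runtime issue with BLAS nodes). The names patch_copy_dst, my_graph, copy_kernel_fn, and new_dst_ptr are illustrative, and treating kernelParams[1] as the destination pointer is an assumption carried over from the code above.

// Hedged sketch: find a kernel node by its func pointer and patch one launch argument.
#include <cuda_runtime.h>
#include <vector>

void patch_copy_dst(cudaGraph_t my_graph, void * copy_kernel_fn, void ** new_dst_ptr) {
    size_t num_nodes = 0;
    cudaGraphGetNodes(my_graph, nullptr, &num_nodes);       // first call: query the node count
    std::vector<cudaGraphNode_t> nodes(num_nodes);
    cudaGraphGetNodes(my_graph, nodes.data(), &num_nodes);  // second call: fill the node array

    for (size_t i = 0; i < num_nodes; ++i) {
        cudaKernelNodeParams p = {};
        if (cudaGraphKernelNodeGetParams(nodes[i], &p) != cudaSuccess) {
            cudaGetLastError();  // non-kernel nodes (e.g. BLAS/memcpy) fail here: clear and skip
            continue;
        }
        if (p.func == copy_kernel_fn) {
            // kernelParams is an array of pointers to the kernel's arguments;
            // index 1 is assumed (in this sketch) to be the destination pointer.
            p.kernelParams[1] = (void *) new_dst_ptr;
            cudaGraphKernelNodeSetParams(nodes[i], &p);
        }
    }
}

After patching the captured graph in this way, the executable graph still has to be refreshed with cudaGraphExecUpdate before the next launch, as the hunk above does.
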
@@ -5,3 +5,5 @@
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);

void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);