mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
cuda : add batched cuBLAS GEMM for faster attention (#3749)
* cmake : add helper for faster CUDA builds * batched : add NGL arg * ggml : skip nops in compute_forward * cuda : minor indentation * cuda : batched cuBLAS GEMMs for src0 F16 and src1 F32 (attention ops) * Apply suggestions from code review These changes plus: ```c++ #define cublasGemmBatchedEx hipblasGemmBatchedEx ``` are needed to compile with ROCM. I haven't done performance testing, but it seems to work. I couldn't figure out how to propose a change for lines outside what the pull changed, also this is the first time trying to create a multi-part review so please forgive me if I mess something up. * cuda : add ROCm / hipBLAS cublasGemmBatchedEx define * cuda : add cublasGemmStridedBatchedEx for non-broadcasted cases * cuda : reduce mallocs in cublasGemmBatchedEx branch * cuda : add TODO for calling cublas from kernel + using mem pool --------- Co-authored-by: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
This commit is contained in:
parent
daab3d7f45
commit
2b4ea35e56
@ -331,6 +331,7 @@ if (LLAMA_CUBLAS)
|
||||
set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
|
||||
else()
|
||||
set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
|
||||
#set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
@ -11,7 +11,7 @@ int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
if (argc == 1 || argv[1][0] == '-') {
|
||||
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]);
|
||||
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]);
|
||||
return 1 ;
|
||||
}
|
||||
|
||||
@ -21,6 +21,9 @@ int main(int argc, char ** argv) {
|
||||
// total length of the sequences including the prompt
|
||||
int n_len = 32;
|
||||
|
||||
// number of layers to offload to the GPU
|
||||
int n_gpu_layers = 0;
|
||||
|
||||
if (argc >= 2) {
|
||||
params.model = argv[1];
|
||||
}
|
||||
@ -37,6 +40,10 @@ int main(int argc, char ** argv) {
|
||||
n_len = std::atoi(argv[4]);
|
||||
}
|
||||
|
||||
if (argc >= 6) {
|
||||
n_gpu_layers = std::atoi(argv[5]);
|
||||
}
|
||||
|
||||
if (params.prompt.empty()) {
|
||||
params.prompt = "Hello my name is";
|
||||
}
|
||||
@ -49,7 +56,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
||||
// model_params.n_gpu_layers = 99; // offload all layers to the GPU
|
||||
model_params.n_gpu_layers = n_gpu_layers;
|
||||
|
||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
|
||||
|
||||
|
190
ggml-cuda.cu
190
ggml-cuda.cu
@ -29,6 +29,8 @@
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
||||
#define cublasHandle_t hipblasHandle_t
|
||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetStream hipblasSetStream
|
||||
@ -4326,13 +4328,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
||||
|
||||
const half * x = (const half *) vx;
|
||||
|
||||
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
|
||||
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
|
||||
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
|
||||
const int channel_x = channel / channel_x_divisor;
|
||||
|
||||
const int nrows_y = ncols_x;
|
||||
const int nrows_y = ncols_x;
|
||||
const int nrows_dst = nrows_x;
|
||||
const int row_dst = row_x;
|
||||
const int row_dst = row_x;
|
||||
|
||||
const int idst = channel*nrows_dst + row_dst;
|
||||
|
||||
@ -4345,13 +4347,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
||||
break;
|
||||
}
|
||||
|
||||
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
|
||||
const float xi = __half2float(x[ix]);
|
||||
|
||||
const int row_y = col_x;
|
||||
|
||||
const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
|
||||
const int iy = channel*nrows_y + row_y;
|
||||
|
||||
const float xi = __half2float(x[ix]);
|
||||
|
||||
tmp += xi * y[iy];
|
||||
}
|
||||
|
||||
@ -7013,7 +7015,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
||||
GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(!ggml_is_transposed(src0));
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
GGML_ASSERT(!ggml_is_permuted(src0));
|
||||
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
@ -7023,11 +7026,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
|
||||
const int64_t nb01 = src0->nb[1];
|
||||
const int64_t nb02 = src0->nb[2];
|
||||
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||
|
||||
@ -7046,6 +7049,159 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
|
||||
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
|
||||
GGML_ASSERT(!ggml_is_transposed(src0));
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
const int64_t nb01 = src0->nb[1];
|
||||
const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
|
||||
const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t ne13 = src1->ne[3];
|
||||
|
||||
const int64_t nb11 = src1->nb[1];
|
||||
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
|
||||
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
|
||||
|
||||
const int64_t ne1 = ggml_nelements(src1);
|
||||
const int64_t ne = ggml_nelements(dst);
|
||||
|
||||
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
|
||||
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||
|
||||
int id;
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
|
||||
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||
half * src0_as_f16 = (half *) src0_ddq;
|
||||
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
// convert src1 to fp16
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
|
||||
size_t src1_as = 0;
|
||||
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
|
||||
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
|
||||
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t r2 = ne12/ne02;
|
||||
const int64_t r3 = ne13/ne03;
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
|
||||
#if 0
|
||||
// use cublasGemmEx
|
||||
{
|
||||
for (int i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int i12 = 0; i12 < ne12; ++i12) {
|
||||
int i03 = i13 / r3;
|
||||
int i02 = i12 / r2;
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
|
||||
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
|
||||
&beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
// use cublasGemmStridedBatchedEx
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
||||
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
||||
&beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC
|
||||
ne12*ne13,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
} else {
|
||||
// use cublasGemmBatchedEx
|
||||
// TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
|
||||
const int ne23 = ne12*ne13;
|
||||
|
||||
// TODO: avoid this alloc
|
||||
void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
|
||||
|
||||
for (int i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int i12 = 0; i12 < ne12; ++i12) {
|
||||
int i03 = i13 / r3;
|
||||
int i02 = i12 / r2;
|
||||
|
||||
ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3];
|
||||
ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
|
||||
ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
|
||||
}
|
||||
}
|
||||
|
||||
// allocate device memory for pointers
|
||||
void ** ptrs_as = nullptr;
|
||||
CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
|
||||
|
||||
// TODO: this does not work for some reason -- not sure why?
|
||||
//size_t ptrs_s = 0;
|
||||
//ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
|
||||
|
||||
// copy pointers to device
|
||||
CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
|
||||
|
||||
free(ptrs);
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
||||
(const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
||||
&beta_f16, ( void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
|
||||
ne23,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
// free device memory for pointers
|
||||
CUDA_CHECK(cudaFree(ptrs_as));
|
||||
//ggml_cuda_pool_free(ptrs_as, ptrs_s);
|
||||
}
|
||||
#endif
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
|
||||
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
|
||||
src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
|
||||
@ -7058,10 +7214,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
||||
}
|
||||
}
|
||||
|
||||
// debug helpers
|
||||
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
|
||||
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
|
||||
//printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
|
||||
//printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
|
||||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// KQ
|
||||
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
|
||||
} else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
|
||||
} else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// KQV
|
||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||
} else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
||||
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
||||
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
||||
|
4
ggml.c
4
ggml.c
@ -16602,6 +16602,10 @@ static void ggml_compute_forward_cross_entropy_loss_back(
|
||||
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(params);
|
||||
|
||||
if (tensor->op == GGML_OP_NONE) {
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
|
||||
if (skip_cpu) {
|
||||
|
Loading…
Reference in New Issue
Block a user