mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 19:04:35 +00:00
cuda : try to fix main device write
This commit is contained in:
parent
1a0843c493
commit
706ff4c2e0
51
ggml-cuda.cu
51
ggml-cuda.cu
@ -6355,11 +6355,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
int id;
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
|
||||
const int compute_capability = g_compute_capabilities[id];
|
||||
|
||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||
// ldc == nrows of the matrix that cuBLAS writes into
|
||||
int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
|
||||
const int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
|
||||
|
||||
const int compute_capability = g_compute_capabilities[id];
|
||||
const bool is_split = row_diff != src0->ne[1];
|
||||
|
||||
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0)) {
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
@ -6385,26 +6387,41 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
}
|
||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
|
||||
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ldc*src1_ncols * sizeof(half), &dst_as);
|
||||
if (!is_split) {
|
||||
const half alpha = 1.0f;
|
||||
const half beta = 0.0f;
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha, src0_ptr, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta_f16, dst_f16, CUDA_R_16F, ldc,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
&beta, dst_f16, CUDA_R_16F, ldc,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_dd_i, ldc*src1_ncols, stream);
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
||||
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
} else {
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha, src0_ptr, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta, dst_dd_i, CUDA_R_32F, ldc,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free(src0_as_f16, src0_as);
|
||||
|
Loading…
Reference in New Issue
Block a user