From 16b60dd75c8c89b726da5e9252454791fa1300b7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 25 Oct 2023 14:00:21 +0300 Subject: [PATCH] cuda : add F32 sgemm branch --- ggml-cuda.cu | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index ca49d73bf..75e0dddf9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -7252,7 +7252,8 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - if (ggml_is_contiguous(src0)) { +#if 0 + { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 half * src0_as_f16 = nullptr; size_t src0_as = 0; @@ -7306,9 +7307,40 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm if (src1_as != 0) { ggml_cuda_pool_free(src1_as_f16, src1_as); } - } else { - GGML_ASSERT(false && "not implemented"); } +#else + { + // convert src0 to fp32, multiply as fp32 + float * src0_as_f32 = nullptr; + size_t src0_as = 0; + if (src0->type != GGML_TYPE_F32) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); + GGML_ASSERT(to_fp32_cuda != nullptr); + const size_t ne = ne01*ne00; + src0_as_f32 = (float *) ggml_cuda_pool_malloc(ne * sizeof(float), &src0_as); + to_fp32_cuda(src0_ddq, src0_as_f32, ne, main_stream); + } + + const float * src0_ptr = src0->type == GGML_TYPE_F32 ? (const float *) src0_ddq : src0_as_f32; + + const float * src1_ptr = (const float *) src1_ddf; + + const float alpha = 1.0f; + const float beta = 0.0f; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); + CUBLAS_CHECK( + cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, src0_ptr, ne00, + src1_ptr, ne10, + &beta, dst_ddf, ne01)); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_as_f32, src0_as); + } + } +#endif } static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {