cuda : add F32 sgemm branch

This commit is contained in:
Georgi Gerganov 2023-10-25 14:00:21 +03:00
parent 52af782608
commit 16b60dd75c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -7252,7 +7252,8 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
if (ggml_is_contiguous(src0)) {
#if 0
{
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
half * src0_as_f16 = nullptr;
size_t src0_as = 0;
@ -7306,9 +7307,40 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
if (src1_as != 0) {
ggml_cuda_pool_free(src1_as_f16, src1_as);
}
} else {
GGML_ASSERT(false && "not implemented");
}
#else
{
// convert src0 to fp32, multiply as fp32
float * src0_as_f32 = nullptr;
size_t src0_as = 0;
if (src0->type != GGML_TYPE_F32) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
GGML_ASSERT(to_fp32_cuda != nullptr);
const size_t ne = ne01*ne00;
src0_as_f32 = (float *) ggml_cuda_pool_malloc(ne * sizeof(float), &src0_as);
to_fp32_cuda(src0_ddq, src0_as_f32, ne, main_stream);
}
const float * src0_ptr = src0->type == GGML_TYPE_F32 ? (const float *) src0_ddq : src0_as_f32;
const float * src1_ptr = (const float *) src1_ddf;
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
CUBLAS_CHECK(
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, src0_ptr, ne00,
src1_ptr, ne10,
&beta, dst_ddf, ne01));
if (src0_as != 0) {
ggml_cuda_pool_free(src0_as_f32, src0_as);
}
}
#endif
}
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {