mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
cuda : add F32 sgemm branch
This commit is contained in:
parent
52af782608
commit
16b60dd75c
38
ggml-cuda.cu
38
ggml-cuda.cu
@ -7252,7 +7252,8 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
if (ggml_is_contiguous(src0)) {
|
||||
#if 0
|
||||
{
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
half * src0_as_f16 = nullptr;
|
||||
size_t src0_as = 0;
|
||||
@ -7306,9 +7307,40 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
|
||||
if (src1_as != 0) {
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
#else
|
||||
{
|
||||
// convert src0 to fp32, multiply as fp32
|
||||
float * src0_as_f32 = nullptr;
|
||||
size_t src0_as = 0;
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||
const size_t ne = ne01*ne00;
|
||||
src0_as_f32 = (float *) ggml_cuda_pool_malloc(ne * sizeof(float), &src0_as);
|
||||
to_fp32_cuda(src0_ddq, src0_as_f32, ne, main_stream);
|
||||
}
|
||||
|
||||
const float * src0_ptr = src0->type == GGML_TYPE_F32 ? (const float *) src0_ddq : src0_as_f32;
|
||||
|
||||
const float * src1_ptr = (const float *) src1_ddf;
|
||||
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha, src0_ptr, ne00,
|
||||
src1_ptr, ne10,
|
||||
&beta, dst_ddf, ne01));
|
||||
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free(src0_as_f32, src0_as);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
|
Loading…
Reference in New Issue
Block a user