mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
cuda : add F32 sgemm branch
This commit is contained in:
parent
52af782608
commit
16b60dd75c
38
ggml-cuda.cu
38
ggml-cuda.cu
@ -7252,7 +7252,8 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
|
|||||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||||
|
|
||||||
if (ggml_is_contiguous(src0)) {
|
#if 0
|
||||||
|
{
|
||||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||||
half * src0_as_f16 = nullptr;
|
half * src0_as_f16 = nullptr;
|
||||||
size_t src0_as = 0;
|
size_t src0_as = 0;
|
||||||
@ -7306,9 +7307,40 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
|
|||||||
if (src1_as != 0) {
|
if (src1_as != 0) {
|
||||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
GGML_ASSERT(false && "not implemented");
|
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
{
|
||||||
|
// convert src0 to fp32, multiply as fp32
|
||||||
|
float * src0_as_f32 = nullptr;
|
||||||
|
size_t src0_as = 0;
|
||||||
|
if (src0->type != GGML_TYPE_F32) {
|
||||||
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||||
|
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||||
|
const size_t ne = ne01*ne00;
|
||||||
|
src0_as_f32 = (float *) ggml_cuda_pool_malloc(ne * sizeof(float), &src0_as);
|
||||||
|
to_fp32_cuda(src0_ddq, src0_as_f32, ne, main_stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * src0_ptr = src0->type == GGML_TYPE_F32 ? (const float *) src0_ddq : src0_as_f32;
|
||||||
|
|
||||||
|
const float * src1_ptr = (const float *) src1_ddf;
|
||||||
|
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const float beta = 0.0f;
|
||||||
|
|
||||||
|
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
|
||||||
|
CUBLAS_CHECK(
|
||||||
|
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
|
ne01, ne11, ne10,
|
||||||
|
&alpha, src0_ptr, ne00,
|
||||||
|
src1_ptr, ne10,
|
||||||
|
&beta, dst_ddf, ne01));
|
||||||
|
|
||||||
|
if (src0_as != 0) {
|
||||||
|
ggml_cuda_pool_free(src0_as_f32, src0_as);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
|
Loading…
Reference in New Issue
Block a user