mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
ggml : force F32 precision for ggml_mul_mat
This commit is contained in:
parent
0ef3ca2ac6
commit
4cc78d3873
38
ggml-cuda.cu
38
ggml-cuda.cu
@ -7579,8 +7579,7 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
|
||||
const int compute_capability = g_device_caps[id].cc;
|
||||
|
||||
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
//printf("this branch\n");
|
||||
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
cuda_pool_alloc<half> src0_as_f16;
|
||||
if (src0->type != GGML_TYPE_F16) {
|
||||
@ -7601,6 +7600,10 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
|
||||
}
|
||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
|
||||
|
||||
switch (dst->op_params[0]) {
|
||||
case GGML_PREC_DEFAULT:
|
||||
{
|
||||
cuda_pool_alloc<half> dst_f16(row_diff*src1_ncols);
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
@ -7618,6 +7621,23 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
} break;
|
||||
case GGML_PREC_F32:
|
||||
{
|
||||
const float alpha_f32 = 1.0f;
|
||||
const float beta_f32 = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f32, src0_ptr, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta_f32, dst_dd_i, CUDA_R_32F, ldc,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
} break;
|
||||
}
|
||||
} else {
|
||||
cuda_pool_alloc<float> src0_ddq_as_f32;
|
||||
cuda_pool_alloc<float> src1_ddq_as_f32;
|
||||
@ -9234,6 +9254,20 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
|
||||
}
|
||||
|
||||
void ggml_cuda_free_data(struct ggml_tensor * tensor) {
|
||||
// print current mem usage using cudaMemGetInfo
|
||||
// TODO: this is a hack - need better solution
|
||||
{
|
||||
size_t free;
|
||||
size_t total;
|
||||
CUDA_CHECK(cudaMemGetInfo(&free, &total));
|
||||
|
||||
static size_t used = 0;
|
||||
if (used < total - free) {
|
||||
printf("CUDA: used %zu MB, free %zu MB\n", (total - free)/1024/1024, free/1024/1024);
|
||||
used = total - free;
|
||||
}
|
||||
}
|
||||
|
||||
if (!tensor || !tensor->extra || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) ) {
|
||||
return;
|
||||
}
|
||||
|
6
ggml.c
6
ggml.c
@ -4077,6 +4077,12 @@ struct ggml_tensor * ggml_mul_mat(
|
||||
const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
// TMP: force f32 precision
|
||||
{
|
||||
const int32_t prec_i32 = GGML_PREC_F32;
|
||||
ggml_set_op_params_i32(result, 0, prec_i32);
|
||||
}
|
||||
|
||||
result->op = GGML_OP_MUL_MAT;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
|
Loading…
Reference in New Issue
Block a user