mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
ggml : fix some mul mat cases + add tests for src1 F16 (ggml/669)
* fixed mul-mat error for old GPUs * style fixes * add mul mat src1 f16 test cases, fix more cases ggml-ci --------- Co-authored-by: bssrdf <bssrdf@gmail.com> Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
ca38b8d334
commit
afc8c19291
@ -614,10 +614,14 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
|
||||
}
|
||||
|
||||
static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
||||
return true;
|
||||
switch (op->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
GGML_UNUSED(op);
|
||||
}
|
||||
|
||||
static struct ggml_backend_i cpu_backend_i = {
|
||||
|
89
ggml-cuda.cu
89
ggml-cuda.cu
@ -7485,6 +7485,8 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
|
||||
#ifdef GGML_CUDA_F16
|
||||
cuda_pool_alloc<half> src1_dfloat_a;
|
||||
@ -7577,6 +7579,7 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
const int compute_capability = g_device_caps[id].cc;
|
||||
|
||||
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
//printf("this branch\n");
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
cuda_pool_alloc<half> src0_as_f16;
|
||||
if (src0->type != GGML_TYPE_F16) {
|
||||
@ -7614,9 +7617,9 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
cuda_pool_alloc<float> src0_ddq_as_f32;
|
||||
cuda_pool_alloc<float> src1_ddq_as_f32;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||
@ -7624,7 +7627,15 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
src0_ddq_as_f32.alloc(row_diff*ne00);
|
||||
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
|
||||
}
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
|
||||
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||
src1_ddq_as_f32.alloc(src1_ncols*ne10);
|
||||
to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
|
||||
}
|
||||
|
||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
|
||||
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
|
||||
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
@ -7633,9 +7644,9 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha, src0_ddf_i, ne00,
|
||||
src1_ddf_i, ne10,
|
||||
&beta, dst_dd_i, ldc));
|
||||
&alpha, src0_ddf_i, ne00,
|
||||
src1_ddf1_i, ne10,
|
||||
&beta, dst_dd_i, ldc));
|
||||
}
|
||||
|
||||
(void) dst;
|
||||
@ -8035,6 +8046,7 @@ static void ggml_cuda_op_mul_mat(
|
||||
|
||||
GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
|
||||
|
||||
GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
|
||||
|
||||
@ -8481,9 +8493,9 @@ static __global__ void k_compute_batched_ptrs(
|
||||
int64_t i03 = i13 / r3;
|
||||
int64_t i02 = i12 / r2;
|
||||
|
||||
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
|
||||
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
|
||||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
|
||||
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
|
||||
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
|
||||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
@ -8492,28 +8504,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
|
||||
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int64_t nb01 = src0->nb[1];
|
||||
const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
|
||||
const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t ne13 = src1->ne[3];
|
||||
|
||||
const int64_t nb11 = src1->nb[1];
|
||||
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
|
||||
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
|
||||
|
||||
const int64_t ne1 = ggml_nelements(src1);
|
||||
const int64_t ne = ggml_nelements(dst);
|
||||
const int64_t ne_dst = ggml_nelements(dst);
|
||||
|
||||
ggml_cuda_set_device(g_main_device);
|
||||
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||
@ -8522,7 +8516,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
void * src0_ddq = src0_extra->data_device[g_main_device];
|
||||
half * src0_as_f16 = (half *) src0_ddq;
|
||||
half * src0_f16 = (half *) src0_ddq;
|
||||
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
|
||||
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
|
||||
@ -8531,11 +8525,15 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
// convert src1 to fp16
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
|
||||
cuda_pool_alloc<half> src1_as_f16(ne1);
|
||||
to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream);
|
||||
cuda_pool_alloc<half> src1_f16_alloc;
|
||||
if (src1->type != GGML_TYPE_F16) {
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||
const int64_t ne_src1 = ggml_nelements(src1);
|
||||
src1_f16_alloc.alloc(ne_src1);
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
|
||||
}
|
||||
half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
|
||||
|
||||
cuda_pool_alloc<half> dst_f16;
|
||||
char * dst_t;
|
||||
@ -8557,7 +8555,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
const void * beta = &beta_f16;
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
dst_t = (char *) dst_f16.alloc(ne);
|
||||
dst_t = (char *) dst_f16.alloc(ne_dst);
|
||||
|
||||
nbd2 /= sizeof(float) / sizeof(half);
|
||||
nbd3 /= sizeof(float) / sizeof(half);
|
||||
@ -8604,9 +8602,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
|
||||
(const char *) src1_as_f16.get(), CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
|
||||
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
|
||||
alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
|
||||
(const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB
|
||||
beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC
|
||||
ne12*ne13,
|
||||
cu_compute_type,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
@ -8619,12 +8617,13 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
|
||||
dim3 block_dims(ne13, ne12);
|
||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||
src0_as_f16, src1_as_f16.get(), dst_t,
|
||||
src0_f16, src1_f16, dst_t,
|
||||
ptrs_src.get(), ptrs_dst.get(),
|
||||
ne12, ne13,
|
||||
ne23,
|
||||
nb02, nb03,
|
||||
nb12, nb13,
|
||||
src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
|
||||
src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
|
||||
nbd2, nbd3,
|
||||
r2, r3);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
@ -8632,8 +8631,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
|
||||
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
|
||||
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
|
||||
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10,
|
||||
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
|
||||
ne23,
|
||||
cu_compute_type,
|
||||
@ -8643,7 +8642,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream);
|
||||
to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
|
||||
}
|
||||
}
|
||||
|
||||
@ -8682,13 +8681,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
||||
} else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// KQV single-batch
|
||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||
} else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||
} else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||
// KQ + KQV multi-batch
|
||||
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
||||
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
||||
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
|
||||
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) {
|
||||
#ifdef GGML_CUDA_FORCE_DMMV
|
||||
const bool use_mul_mat_vec_q = false;
|
||||
#else
|
||||
|
2
ggml.c
2
ggml.c
@ -9687,7 +9687,7 @@ static void ggml_compute_forward_mul_mat(
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
||||
assert(params->wsize >= ne11*ne12*ne13*row_size);
|
||||
assert(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
|
@ -350,13 +350,18 @@ struct test_case {
|
||||
fflush(stdout);
|
||||
|
||||
// check if backends support op
|
||||
bool supported = true;
|
||||
for (ggml_backend_t backend : {backend1, backend2}) {
|
||||
if (!ggml_backend_supports_op(backend, out)) {
|
||||
printf("not supported\n");
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
printf("not supported [%s] ", ggml_backend_name(backend));
|
||||
supported = false;
|
||||
}
|
||||
}
|
||||
if (!supported) {
|
||||
printf("\n");
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
// post-graph sentinel
|
||||
add_sentinel(ctx);
|
||||
@ -1505,8 +1510,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
}
|
||||
|
||||
for (ggml_type type_a : all_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
|
||||
// FIXME: CPU crashes on f16xf16
|
||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
|
||||
|
Loading…
Reference in New Issue
Block a user