cuda : prints wip

This commit is contained in:
Georgi Gerganov 2023-10-25 10:26:58 +03:00
parent 6961c4bd0b
commit 59d1232ea7
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6304,6 +6304,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
//printf("F16: row_diff: %ld, src1_ncols: %ld, ne10: %ld, ne00: %ld, ldc: %d\n", row_diff, src1_ncols, ne10, ne00, ldc);
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
@ -7250,6 +7251,12 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
} else {
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
//printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
//printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
}
}