mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
ggml-cuda : perform cublas fp16 matrix multiplication as fp16 (#3370)
* ggml-cuda : perform cublas fp16 matrix multiplication as fp16 * try to fix rocm build * restrict fp16 mat mul to volta and up
This commit is contained in:
parent
e519621010
commit
da0400344b
104
ggml-cuda.cu
104
ggml-cuda.cu
@ -14,9 +14,11 @@
|
||||
// for rocblas_initialize()
|
||||
#include "rocblas/rocblas.h"
|
||||
#endif // __HIP_PLATFORM_AMD__
|
||||
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
||||
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N HIPBLAS_OP_N
|
||||
#define CUBLAS_OP_T HIPBLAS_OP_T
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
@ -235,8 +237,12 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
|
||||
return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
|
||||
typedef to_t_cuda_t<float> to_fp32_cuda_t;
|
||||
typedef to_t_cuda_t<half> to_fp16_cuda_t;
|
||||
|
||||
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
|
||||
typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
|
||||
typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
|
||||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
@ -1515,6 +1521,14 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
|
||||
v.y = x[ib + iqs + 1];
|
||||
}
|
||||
|
||||
static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const float * x = (const float *) vx;
|
||||
|
||||
// automatic half -> float type cast if dfloat == float
|
||||
v.x = x[ib + iqs + 0];
|
||||
v.y = x[ib + iqs + 1];
|
||||
}
|
||||
|
||||
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
|
||||
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
@ -1554,8 +1568,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
||||
reinterpret_cast<half&>(y[ib].ds.y) = sum;
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
@ -4826,6 +4840,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
|
||||
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||
dequantize_block<1, 1, convert_f32><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
@ -4835,6 +4854,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
return convert_fp32_to_fp16_cuda;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
@ -6016,8 +6044,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
GGML_ASSERT(src1_ddf_i != nullptr);
|
||||
GGML_ASSERT(dst_dd_i != nullptr);
|
||||
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
|
||||
@ -6026,16 +6052,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
float * src0_ddq_as_f32;
|
||||
size_t src0_as = 0;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
|
||||
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
|
||||
}
|
||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
|
||||
|
||||
int id;
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
|
||||
@ -6043,6 +6059,61 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
// ldc == nrows of the matrix that cuBLAS writes into
|
||||
int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
|
||||
|
||||
const int compute_capability = g_compute_capabilities[id];
|
||||
|
||||
if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) {
|
||||
// convert src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
half * src1_as_f16 = nullptr;
|
||||
size_t src1_as = 0;
|
||||
if (src1->type != GGML_TYPE_F16) {
|
||||
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
|
||||
GGML_ASSERT(to_fp16_cuda != nullptr);
|
||||
size_t ne = src1_ncols*ne10;
|
||||
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
|
||||
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
|
||||
}
|
||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
|
||||
|
||||
size_t dst_as = 0;
|
||||
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_dd_i, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta_f16, dst_f16, CUDA_R_16F, ldc,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
|
||||
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
|
||||
if (src1_as != 0) {
|
||||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
}
|
||||
}
|
||||
else {
|
||||
float * src0_ddq_as_f32 = nullptr;
|
||||
size_t src0_as = 0;
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
|
||||
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
|
||||
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
|
||||
}
|
||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
|
||||
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
@ -6051,9 +6122,10 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
||||
src1_ddf_i, ne10,
|
||||
&beta, dst_dd_i, ldc));
|
||||
|
||||
if (src0_as > 0) {
|
||||
if (src0_as != 0) {
|
||||
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
|
||||
}
|
||||
}
|
||||
|
||||
(void) dst;
|
||||
(void) src1_ddq_i;
|
||||
|
Loading…
Reference in New Issue
Block a user