cuda : fix dequantize kernel names (#4938)

This commit is contained in:
Georgi Gerganov 2024-01-15 13:27:00 +02:00
parent 2faaef3979
commit ddb008d845
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -6309,14 +6309,14 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
} }
template<typename dst_t> template<typename dst_t>
static void dequantize_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb32 = k / 32; const int nb32 = k / 32;
const int nb = (k + 255) / 256; const int nb = (k + 255) / 256;
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32); dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
} }
template<typename dst_t> template<typename dst_t>
static void dequantize_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb32 = k / 32; const int nb32 = k / 32;
const int nb = (k + 255) / 256; const int nb = (k + 255) / 256;
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32); dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
@@ -6370,9 +6370,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
int id; int id;
switch (type) { switch (type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
return dequantize_q4_0_cuda; return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
return dequantize_q4_1_cuda; return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>; return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
@@ -6407,9 +6407,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) { switch (type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
return dequantize_q4_0_cuda; return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
return dequantize_q4_1_cuda; return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>; return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1: