cuda : fix dequantize kernel names (#4938)
parent 2faaef3979
commit ddb008d845

ggml-cuda.cu (12 lines changed)
@@ -6309,14 +6309,14 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
 }
 
 template<typename dst_t>
-static void dequantize_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }
 
 template<typename dst_t>
-static void dequantize_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
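The renamed launchers keep the same launch geometry: k values are split into nb32 = k/32 quant blocks, and the kernel is launched with one 32-thread CUDA block per 256 values ((k + 255)/256 blocks, i.e. 8 quant blocks per CUDA block). Below is a minimal, self-contained sketch of that geometry; block_q4, the nibble packing, and all names here are simplified stand-ins, not the actual ggml-cuda structures or kernels.

#include <cuda_runtime.h>

#define QK 32  // values per quant block, mirroring QK4_0

// Simplified stand-in for the real q4_0 block (which stores a half-precision scale).
struct block_q4 {
    float d;                   // scale
    unsigned char qs[QK / 2];  // 32 x 4-bit values packed into 16 bytes
};

// One 32-thread block covers 8 quant blocks (256 values): 4 threads per quant
// block, each thread unpacking 4 bytes (8 values).
__global__ void dequantize_block_sketch(const block_q4 * x, float * y, int nb32) {
    const int i = blockIdx.x * 8 + threadIdx.x / 4;  // quant block index
    if (i >= nb32) {
        return;  // guard the tail when k is not a multiple of 256
    }
    const int lane = threadIdx.x % 4;
    const block_q4 b = x[i];
    for (int j = lane * 4; j < lane * 4 + 4; ++j) {
        const int q0 = (b.qs[j] & 0x0F) - 8;  // low nibble -> first half of the block
        const int q1 = (b.qs[j] >>   4) - 8;  // high nibble -> second half
        y[i * QK + j]          = q0 * b.d;
        y[i * QK + j + QK / 2] = q1 * b.d;
    }
}

// Host-side launcher with the same grid computation as dequantize_row_q4_0_cuda above.
static void dequantize_sketch_cuda(const block_q4 * x, float * y, int k, cudaStream_t stream) {
    const int nb32 = k / 32;           // number of 32-value quant blocks
    const int nb   = (k + 255) / 256;  // one CUDA block per 256 values
    dequantize_block_sketch<<<nb, 32, 0, stream>>>(x, y, nb32);
}

Usage mirrors the diff: dequantize_sketch_cuda(x, y, k, stream) enqueues ceil(k/256) blocks of 32 threads on the given stream.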
@@ -6370,9 +6370,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     int id;
     switch (type) {
         case GGML_TYPE_Q4_0:
-            return dequantize_q4_0_cuda;
+            return dequantize_row_q4_0_cuda;
         case GGML_TYPE_Q4_1:
-            return dequantize_q4_1_cuda;
+            return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
@@ -6407,9 +6407,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
-            return dequantize_q4_0_cuda;
+            return dequantize_row_q4_0_cuda;
         case GGML_TYPE_Q4_1:
-            return dequantize_q4_1_cuda;
+            return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
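Both ggml_get_to_fp16_cuda and ggml_get_to_fp32_cuda follow the same pattern: switch on the ggml_type and return a host launcher as a function pointer, which is why the rename has to be applied in both tables. A minimal sketch of that dispatch pattern, with hypothetical names and a converter signature assumed from the launchers in the first hunk:

#include <cuda_runtime.h>

// Assumed converter signature, mirroring the launchers above:
// (quantized source, float destination, element count, stream).
typedef void (*to_fp32_sketch_t)(const void * vx, float * y, int k, cudaStream_t stream);

// Hypothetical stand-ins for dequantize_row_q4_0_cuda / dequantize_row_q4_1_cuda.
static void dequantize_row_q4_0_sketch(const void *, float *, int, cudaStream_t) {}
static void dequantize_row_q4_1_sketch(const void *, float *, int, cudaStream_t) {}

enum sketch_type { SKETCH_Q4_0, SKETCH_Q4_1 };

// Dispatch table in the style of ggml_get_to_fp32_cuda: one case per type,
// returning the matching launcher (nullptr for unsupported types).
static to_fp32_sketch_t get_to_fp32_cuda_sketch(sketch_type type) {
    switch (type) {
        case SKETCH_Q4_0: return dequantize_row_q4_0_sketch;
        case SKETCH_Q4_1: return dequantize_row_q4_1_sketch;
        default:          return nullptr;
    }
}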