mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
cann: support q4_0 model (#8822)
This commit is contained in:
parent
0d6fb52be0
commit
c02b0a8a4d
@ -627,7 +627,6 @@ GGML_CALL static void* ggml_backend_cann_buffer_get_base(
|
|||||||
GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
|
GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
|
||||||
const void* src,
|
const void* src,
|
||||||
void* dst) {
|
void* dst) {
|
||||||
GGML_ASSERT(tensor->op == GGML_OP_NONE);
|
|
||||||
|
|
||||||
int64_t n_elems = ggml_nelements(tensor);
|
int64_t n_elems = ggml_nelements(tensor);
|
||||||
int64_t groups = n_elems / QK4_0;
|
int64_t groups = n_elems / QK4_0;
|
||||||
@ -679,7 +678,6 @@ GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
|
|||||||
*/
|
*/
|
||||||
GGML_CALL static void ggml_backend_cann_transform_back_q4_0(
|
GGML_CALL static void ggml_backend_cann_transform_back_q4_0(
|
||||||
const ggml_tensor* tensor, void* src, void* dst) {
|
const ggml_tensor* tensor, void* src, void* dst) {
|
||||||
GGML_ASSERT(tensor->op == GGML_OP_NONE);
|
|
||||||
|
|
||||||
int64_t n_elems = ggml_nelements(tensor);
|
int64_t n_elems = ggml_nelements(tensor);
|
||||||
int64_t groups = n_elems / QK4_0;
|
int64_t groups = n_elems / QK4_0;
|
||||||
@ -1666,10 +1664,17 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
|
|||||||
}
|
}
|
||||||
case GGML_OP_MUL_MAT: {
|
case GGML_OP_MUL_MAT: {
|
||||||
switch (op->src[0]->type) {
|
switch (op->src[0]->type) {
|
||||||
// case GGML_TYPE_Q4_0:
|
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
|
// TODO: fix me
|
||||||
|
// Current groupsize should not be greater than k-1 in
|
||||||
|
// aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
|
||||||
|
if (op->src[0]->ne[0]-1 > QK8_0) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
@ -1694,6 +1699,7 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
|
|||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -37,6 +37,10 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
|
|||||||
return ACL_INT16;
|
return ACL_INT16;
|
||||||
case GGML_TYPE_I32:
|
case GGML_TYPE_I32:
|
||||||
return ACL_INT32;
|
return ACL_INT32;
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
return ACL_INT4;
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
return ACL_INT8;
|
||||||
default:
|
default:
|
||||||
return ACL_DT_UNDEFINED;
|
return ACL_DT_UNDEFINED;
|
||||||
}
|
}
|
||||||
@ -89,33 +93,6 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
|
|
||||||
size_t type_size, int64_t* ne, size_t* nb,
|
|
||||||
int64_t dims, aclFormat format,
|
|
||||||
size_t offset) {
|
|
||||||
int64_t tmp_ne[GGML_MAX_DIMS * 2];
|
|
||||||
int64_t tmp_stride[GGML_MAX_DIMS * 2];
|
|
||||||
|
|
||||||
memcpy(tmp_ne, ne, dims * sizeof(int64_t));
|
|
||||||
for (int i = 0; i < dims; i++) {
|
|
||||||
tmp_stride[i] = nb[i] / type_size;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::reverse(tmp_ne, tmp_ne + dims);
|
|
||||||
std::reverse(tmp_stride, tmp_stride + dims);
|
|
||||||
|
|
||||||
int64_t acl_storage_len = 0;
|
|
||||||
for (int i = 0; i < dims; i++) {
|
|
||||||
acl_storage_len += (ne[i] - 1) * nb[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
aclTensor* acl_tensor =
|
|
||||||
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
|
|
||||||
format, &acl_storage_len, 1, data_ptr);
|
|
||||||
|
|
||||||
return acl_tensor;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
|
int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
|
||||||
const ggml_tensor* src1,
|
const ggml_tensor* src1,
|
||||||
int64_t* bcast_src0_ne,
|
int64_t* bcast_src0_ne,
|
||||||
|
@ -23,6 +23,9 @@
|
|||||||
#ifndef CANN_ACL_TENSOR_H
|
#ifndef CANN_ACL_TENSOR_H
|
||||||
#define CANN_ACL_TENSOR_H
|
#define CANN_ACL_TENSOR_H
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
#include <aclnn/aclnn_base.h>
|
#include <aclnn/aclnn_base.h>
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
@ -65,7 +68,8 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
|
|||||||
size_t offset = 0);
|
size_t offset = 0);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Creates an ACL tensor from provided parameters.
|
* @brief Template for creating an ACL tensor from provided parameters. typename TYPE
|
||||||
|
* should be size_t or float.
|
||||||
*
|
*
|
||||||
* @details This function creates an ACL tensor using the provided data pointer,
|
* @details This function creates an ACL tensor using the provided data pointer,
|
||||||
* data type, dimensions, strides, format, offset, and additional parameters.
|
* data type, dimensions, strides, format, offset, and additional parameters.
|
||||||
@ -83,10 +87,34 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
|
|||||||
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
||||||
* @return Pointer to the created ACL tensor.
|
* @return Pointer to the created ACL tensor.
|
||||||
*/
|
*/
|
||||||
|
template<typename TYPE>
|
||||||
aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
|
aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
|
||||||
size_t type_size, int64_t* ne, size_t* nb,
|
TYPE type_size, int64_t* ne, TYPE* nb,
|
||||||
int64_t dims, aclFormat format = ACL_FORMAT_ND,
|
int64_t dims,
|
||||||
size_t offset = 0);
|
aclFormat format = ACL_FORMAT_ND,
|
||||||
|
size_t offset = 0) {
|
||||||
|
int64_t tmp_ne[GGML_MAX_DIMS * 2];
|
||||||
|
int64_t tmp_stride[GGML_MAX_DIMS * 2];
|
||||||
|
|
||||||
|
memcpy(tmp_ne, ne, dims * sizeof(int64_t));
|
||||||
|
for (int i = 0; i < dims; i++) {
|
||||||
|
tmp_stride[i] = nb[i] / type_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::reverse(tmp_ne, tmp_ne + dims);
|
||||||
|
std::reverse(tmp_stride, tmp_stride + dims);
|
||||||
|
|
||||||
|
int64_t acl_storage_len = 0;
|
||||||
|
for (int i = 0; i < dims; i++) {
|
||||||
|
acl_storage_len += (ne[i] - 1) * nb[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
aclTensor* acl_tensor =
|
||||||
|
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
|
||||||
|
format, &acl_storage_len, 1, data_ptr);
|
||||||
|
|
||||||
|
return acl_tensor;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Checks if tensors require broadcasting based on their shapes.
|
* @brief Checks if tensors require broadcasting based on their shapes.
|
||||||
|
@ -910,6 +910,13 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||||||
((ggml_tensor*)dst->extra)->ne);
|
((ggml_tensor*)dst->extra)->ne);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (dst->type == GGML_TYPE_Q4_0) {
|
||||||
|
aclrtlaunch_ascendc_quantize_f16_to_q4_0(
|
||||||
|
24, ctx.stream(), src->data, dst->data,
|
||||||
|
((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
|
||||||
|
((ggml_tensor*)dst->extra)->ne);
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (dst->type == GGML_TYPE_F16) {
|
if (dst->type == GGML_TYPE_F16) {
|
||||||
if (ggml_are_same_shape(src, dst)) {
|
if (ggml_are_same_shape(src, dst)) {
|
||||||
cann_copy(ctx, acl_src, acl_dst);
|
cann_copy(ctx, acl_src, acl_dst);
|
||||||
@ -971,6 +978,13 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||||||
((ggml_tensor*)dst->extra)->ne);
|
((ggml_tensor*)dst->extra)->ne);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (dst->type == GGML_TYPE_Q4_0) {
|
||||||
|
aclrtlaunch_ascendc_quantize_f32_to_q4_0(
|
||||||
|
24, ctx.stream(), src->data, dst->data,
|
||||||
|
((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
|
||||||
|
((ggml_tensor*)dst->extra)->ne);
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (dst->type == GGML_TYPE_F32) {
|
if (dst->type == GGML_TYPE_F32) {
|
||||||
if (ggml_are_same_shape(src, dst)) {
|
if (ggml_are_same_shape(src, dst)) {
|
||||||
cann_copy(ctx, acl_src, acl_dst);
|
cann_copy(ctx, acl_src, acl_dst);
|
||||||
@ -2463,21 +2477,33 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
|
|||||||
* @param dst The destination tensor where the result of the matrix
|
* @param dst The destination tensor where the result of the matrix
|
||||||
* multiplication will be stored.
|
* multiplication will be stored.
|
||||||
*/
|
*/
|
||||||
static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx,
|
static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
|
||||||
ggml_tensor* dst) {
|
ggml_tensor* dst,
|
||||||
|
const enum ggml_type type) {
|
||||||
ggml_tensor* src0 = dst->src[0]; // weight
|
ggml_tensor* src0 = dst->src[0]; // weight
|
||||||
ggml_tensor* src1 = dst->src[1]; // input
|
ggml_tensor* src1 = dst->src[1]; // input
|
||||||
|
|
||||||
// The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
|
// The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
|
||||||
// is regarded as batch. weight need transpose.
|
// is regarded as batch. weight need transpose.
|
||||||
int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
|
int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
|
||||||
size_t weight_elem_size = sizeof(uint8_t);
|
float weight_elem_size;
|
||||||
size_t weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
|
if (type == GGML_TYPE_Q4_0) {
|
||||||
|
weight_elem_size = float(sizeof(uint8_t)) / 2;
|
||||||
|
}
|
||||||
|
else if (type == GGML_TYPE_Q8_0) {
|
||||||
|
weight_elem_size = float(sizeof(uint8_t));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT");
|
||||||
|
}
|
||||||
|
float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
|
||||||
|
|
||||||
// size of one matrix is element_size * height * width.
|
// size of one matrix is element_size * height * width.
|
||||||
size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
|
size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
|
||||||
size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
|
size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
|
||||||
|
|
||||||
// scale stored at the end of weight. Also need transpose.
|
// scale stored at the end of weight. Also need transpose.
|
||||||
|
GGML_ASSERT(QK4_0 == QK8_0);
|
||||||
int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
|
int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
|
||||||
size_t scale_elem_size = sizeof(uint16_t);
|
size_t scale_elem_size = sizeof(uint16_t);
|
||||||
size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
|
size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
|
||||||
@ -2541,8 +2567,9 @@ static void ggml_cann_mul_mat_q8_0(ggml_backend_cann_context& ctx,
|
|||||||
(char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
|
(char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
|
||||||
input_elem_size, input_ne, input_nb, 2);
|
input_elem_size, input_ne, input_nb, 2);
|
||||||
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
|
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
|
||||||
(char*)src0->data + batch0 * weight_stride, ACL_INT8,
|
(char*)src0->data + batch0 * weight_stride,
|
||||||
weight_elem_size, weight_ne, weight_nb, 2);
|
ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
|
||||||
|
weight_nb, 2);
|
||||||
aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
|
aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
|
||||||
scale_offset + batch0 * scale_stride, ACL_FLOAT16,
|
scale_offset + batch0 * scale_stride, ACL_FLOAT16,
|
||||||
scale_elem_size, scale_ne, scale_nb, 2);
|
scale_elem_size, scale_ne, scale_nb, 2);
|
||||||
@ -2596,11 +2623,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
ggml_cann_mat_mul_fp(ctx, dst);
|
ggml_cann_mat_mul_fp(ctx, dst);
|
||||||
break;
|
break;
|
||||||
// case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
// ggml_cann_mul_mat_q4_0(ctx, dst);
|
|
||||||
// break;
|
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
ggml_cann_mul_mat_q8_0(ctx, dst);
|
ggml_cann_mul_mat_quant(ctx, dst, type);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
|
@ -9,6 +9,7 @@ file(GLOB SRC_FILES
|
|||||||
get_row_q8_0.cpp
|
get_row_q8_0.cpp
|
||||||
quantize_f32_q8_0.cpp
|
quantize_f32_q8_0.cpp
|
||||||
quantize_f16_q8_0.cpp
|
quantize_f16_q8_0.cpp
|
||||||
|
quantize_float_to_q4_0.cpp
|
||||||
dup.cpp
|
dup.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -8,6 +8,8 @@
|
|||||||
|
|
||||||
#include "aclrtlaunch_ascendc_quantize_f32_q8_0.h"
|
#include "aclrtlaunch_ascendc_quantize_f32_q8_0.h"
|
||||||
#include "aclrtlaunch_ascendc_quantize_f16_q8_0.h"
|
#include "aclrtlaunch_ascendc_quantize_f16_q8_0.h"
|
||||||
|
#include "aclrtlaunch_ascendc_quantize_f16_to_q4_0.h"
|
||||||
|
#include "aclrtlaunch_ascendc_quantize_f32_to_q4_0.h"
|
||||||
|
|
||||||
#include "aclrtlaunch_ascendc_dup_by_rows_fp16.h"
|
#include "aclrtlaunch_ascendc_dup_by_rows_fp16.h"
|
||||||
#include "aclrtlaunch_ascendc_dup_by_rows_fp32.h"
|
#include "aclrtlaunch_ascendc_dup_by_rows_fp32.h"
|
||||||
|
273
ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp
Normal file
273
ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
#include "kernel_operator.h"
|
||||||
|
|
||||||
|
using namespace AscendC;
|
||||||
|
|
||||||
|
#define BUFFER_NUM 2
|
||||||
|
#define Group_Size 32
|
||||||
|
|
||||||
|
template <typename SRC_T>
|
||||||
|
class QUANTIZE_FLOAT_TO_Q4_0 {
|
||||||
|
public:
|
||||||
|
__aicore__ inline QUANTIZE_FLOAT_TO_Q4_0() {}
|
||||||
|
__aicore__ inline void init(GM_ADDR input, GM_ADDR output,
|
||||||
|
int64_t *input_ne_ub, size_t *input_nb_ub,
|
||||||
|
int64_t *output_ne_ub) {
|
||||||
|
int64_t op_block_num = GetBlockNum();
|
||||||
|
int64_t op_block_idx = GetBlockIdx();
|
||||||
|
|
||||||
|
// input stride of data elements
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
input_ne[i] = input_ne_ub[i];
|
||||||
|
input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
|
||||||
|
output_ne[i] = output_ne_ub[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// output stride of data elements
|
||||||
|
output_stride[0] = 1;
|
||||||
|
for (int i = 1; i < 4; i++) {
|
||||||
|
output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// scale saved one by one after data:. [group1_scale, group2_scale, ...]
|
||||||
|
scale_ne = input_ne;
|
||||||
|
scale_stride[0] = 1;
|
||||||
|
scale_stride[1] = input_ne[0] / Group_Size;
|
||||||
|
for (int i = 2; i < 4; i++) {
|
||||||
|
scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// split input tensor by rows.
|
||||||
|
uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
|
||||||
|
dr = nr / op_block_num;
|
||||||
|
|
||||||
|
uint64_t tails = nr % op_block_num;
|
||||||
|
if (op_block_idx < tails) {
|
||||||
|
dr += 1;
|
||||||
|
ir = dr * op_block_idx;
|
||||||
|
} else {
|
||||||
|
ir = dr * op_block_idx + tails;
|
||||||
|
}
|
||||||
|
|
||||||
|
group_size_in_row = scale_stride[1];
|
||||||
|
int64_t scale_offset = output_ne[0] * output_ne[1] * output_ne[2] *
|
||||||
|
output_ne[3] * sizeof(uint8_t) / 2;
|
||||||
|
|
||||||
|
input_gm.SetGlobalBuffer((__gm__ SRC_T *)input);
|
||||||
|
output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
|
||||||
|
scale_gm.SetGlobalBuffer((__gm__ half *)(output + scale_offset + ir *
|
||||||
|
group_size_in_row *
|
||||||
|
sizeof(half)));
|
||||||
|
|
||||||
|
pipe.InitBuffer(input_queue, BUFFER_NUM, Group_Size * sizeof(SRC_T));
|
||||||
|
pipe.InitBuffer(output_queue, BUFFER_NUM,
|
||||||
|
Group_Size * sizeof(int8_t) / 2);
|
||||||
|
pipe.InitBuffer(cast_queue , BUFFER_NUM, Group_Size * sizeof(float));
|
||||||
|
pipe.InitBuffer(work_queue, BUFFER_NUM, Group_Size*sizeof(float));
|
||||||
|
pipe.InitBuffer(max_queue, BUFFER_NUM, Group_Size*sizeof(float));
|
||||||
|
pipe.InitBuffer(min_queue, BUFFER_NUM, Group_Size*sizeof(float));
|
||||||
|
pipe.InitBuffer(scale_queue, BUFFER_NUM, 16*sizeof(half));
|
||||||
|
pipe.InitBuffer(int8_queue, BUFFER_NUM, Group_Size * sizeof(int8_t));
|
||||||
|
pipe.InitBuffer(half_queue, BUFFER_NUM, Group_Size * sizeof(half));
|
||||||
|
}
|
||||||
|
|
||||||
|
__aicore__ inline void copy_in(uint32_t offset) {
|
||||||
|
LocalTensor<SRC_T> input_local = input_queue.AllocTensor<SRC_T>();
|
||||||
|
DataCopy(input_local, input_gm[offset], Group_Size);
|
||||||
|
input_queue.EnQue(input_local);
|
||||||
|
}
|
||||||
|
|
||||||
|
__aicore__ inline void copy_out(uint32_t offset) {
|
||||||
|
// reinterpretcast Group_Size(32) * int4b_t to Group_Size / 2 * int8_t,
|
||||||
|
// and using DataCopyPad to avoid 32 bits align.
|
||||||
|
LocalTensor<int4b_t> output_local = output_queue.DeQue<int4b_t>();
|
||||||
|
LocalTensor<int8_t> output_int8_local =
|
||||||
|
output_local.ReinterpretCast<int8_t>();
|
||||||
|
|
||||||
|
DataCopyExtParams dataCopyParams;
|
||||||
|
dataCopyParams.blockCount = 1;
|
||||||
|
dataCopyParams.blockLen = Group_Size / 2 * sizeof(int8_t);
|
||||||
|
DataCopyPad(output_gm[offset], output_int8_local, dataCopyParams);
|
||||||
|
|
||||||
|
output_queue.FreeTensor(output_local);
|
||||||
|
}
|
||||||
|
|
||||||
|
__aicore__ inline void input_to_cast(LocalTensor<float> cast_local,
|
||||||
|
LocalTensor<float> input_local) {
|
||||||
|
DataCopy(cast_local, input_local, Group_Size);
|
||||||
|
}
|
||||||
|
|
||||||
|
__aicore__ inline void input_to_cast(LocalTensor<float> cast_local,
|
||||||
|
LocalTensor<half> input_local) {
|
||||||
|
Cast(cast_local, input_local, RoundMode::CAST_NONE, Group_Size);
|
||||||
|
}
|
||||||
|
|
||||||
|
__aicore__ inline half calculate_group(int64_t row, int64_t group) {
|
||||||
|
const int64_t i3 = row / (input_ne[1] * input_ne[2]);
|
||||||
|
const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
|
||||||
|
const int64_t i1 =
|
||||||
|
row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
|
||||||
|
|
||||||
|
const int64_t input_offset = i1 * input_stride[1] +
|
||||||
|
i2 * input_stride[2] +
|
||||||
|
i3 * input_stride[3] + Group_Size * group;
|
||||||
|
|
||||||
|
// output_offset is stride for output_gm which datatype is int8_t and
|
||||||
|
// divided by 2 is needed for int4b_t.
|
||||||
|
const int64_t output_offset = (i1 * output_stride[1] +
|
||||||
|
i2 * output_stride[2] +
|
||||||
|
i3 * output_stride[3] +
|
||||||
|
Group_Size * group) / 2;
|
||||||
|
copy_in(input_offset);
|
||||||
|
|
||||||
|
LocalTensor<SRC_T> input_local = input_queue.DeQue<SRC_T>();
|
||||||
|
LocalTensor<int4b_t> output_local = output_queue.AllocTensor<int4b_t>();
|
||||||
|
LocalTensor<float> cast_local = cast_queue.AllocTensor<float>();
|
||||||
|
LocalTensor<float> work_local = work_queue.AllocTensor<float>();
|
||||||
|
LocalTensor<float> max_local = max_queue.AllocTensor<float>();
|
||||||
|
LocalTensor<float> min_local = min_queue.AllocTensor<float>();
|
||||||
|
LocalTensor<int8_t> int8_local = int8_queue.AllocTensor<int8_t>();
|
||||||
|
LocalTensor<half> half_local = half_queue.AllocTensor<half>();
|
||||||
|
|
||||||
|
input_to_cast(cast_local, input_local);
|
||||||
|
|
||||||
|
ReduceMax(max_local, cast_local, work_local, Group_Size);
|
||||||
|
ReduceMin(min_local, cast_local, work_local, Group_Size);
|
||||||
|
const float max_value = max_local.GetValue(0);
|
||||||
|
const float min_value = min_local.GetValue(0);
|
||||||
|
float d = max_value;
|
||||||
|
if (min_value < 0 && (-1 * min_value) > max_value) {
|
||||||
|
d = min_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
d = d / (-8);
|
||||||
|
if (d != 0) {
|
||||||
|
Muls(cast_local, cast_local, 1.0f / d, Group_Size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// range: [-8,8] -> [0.5,16.5] -> [0,16] -> [0,15] -> [-8,7]
|
||||||
|
float scalar = 8.5f;
|
||||||
|
Adds(cast_local, cast_local, scalar, Group_Size);
|
||||||
|
Cast(cast_local, cast_local, RoundMode::CAST_FLOOR, Group_Size);
|
||||||
|
scalar = 15.0f;
|
||||||
|
Mins(cast_local, cast_local, scalar, Group_Size);
|
||||||
|
scalar = -8.0f;
|
||||||
|
Adds(cast_local, cast_local, scalar, Group_Size);
|
||||||
|
|
||||||
|
// float->half->int4b
|
||||||
|
Cast(half_local, cast_local, RoundMode::CAST_NONE, Group_Size);
|
||||||
|
Cast(output_local, half_local, RoundMode::CAST_NONE, Group_Size);
|
||||||
|
|
||||||
|
output_queue.EnQue(output_local);
|
||||||
|
copy_out(output_offset);
|
||||||
|
|
||||||
|
input_queue.FreeTensor(input_local);
|
||||||
|
work_queue.FreeTensor(work_local);
|
||||||
|
max_queue.FreeTensor(max_local);
|
||||||
|
min_queue.FreeTensor(min_local);
|
||||||
|
int8_queue.FreeTensor(int8_local);
|
||||||
|
half_queue.FreeTensor(half_local);
|
||||||
|
cast_queue.FreeTensor(cast_local);
|
||||||
|
return (half)d;
|
||||||
|
}
|
||||||
|
|
||||||
|
__aicore__ inline void calculate() {
|
||||||
|
LocalTensor<half> scale_local = scale_queue.AllocTensor<half>();
|
||||||
|
uint32_t scale_local_offset = 0;
|
||||||
|
uint32_t scale_global_offset = 0;
|
||||||
|
for (int64_t i = ir; i < ir + dr; i++) {
|
||||||
|
for (int64_t j = 0; j < group_size_in_row; j++) {
|
||||||
|
half scale = calculate_group(i, j);
|
||||||
|
scale_local.SetValue(scale_local_offset++, scale);
|
||||||
|
if (scale_local_offset == 16) {
|
||||||
|
scale_local_offset = 0;
|
||||||
|
// TODO: OPTIMIZE ME
|
||||||
|
pipe_barrier(PIPE_ALL);
|
||||||
|
DataCopy(scale_gm[scale_global_offset], scale_local, 16);
|
||||||
|
pipe_barrier(PIPE_ALL);
|
||||||
|
scale_global_offset += 16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (scale_local_offset != 0) {
|
||||||
|
pipe_barrier(PIPE_ALL);
|
||||||
|
DataCopyExtParams dataCopyParams;
|
||||||
|
dataCopyParams.blockCount = 1;
|
||||||
|
dataCopyParams.blockLen = scale_local_offset * sizeof(half);
|
||||||
|
DataCopyPad(scale_gm[scale_global_offset], scale_local,
|
||||||
|
dataCopyParams);
|
||||||
|
pipe_barrier(PIPE_ALL);
|
||||||
|
}
|
||||||
|
scale_queue.FreeTensor(scale_local);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64_t input_ne[4];
|
||||||
|
size_t input_stride[4];
|
||||||
|
|
||||||
|
int64_t *scale_ne;
|
||||||
|
size_t scale_stride[4];
|
||||||
|
|
||||||
|
int64_t output_ne[4];
|
||||||
|
size_t output_stride[4];
|
||||||
|
|
||||||
|
int64_t group_size_in_row;
|
||||||
|
|
||||||
|
int64_t ir;
|
||||||
|
int64_t dr;
|
||||||
|
|
||||||
|
TPipe pipe;
|
||||||
|
GlobalTensor<SRC_T> input_gm;
|
||||||
|
GlobalTensor<half> scale_gm;
|
||||||
|
GlobalTensor<int8_t> output_gm;
|
||||||
|
TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
|
||||||
|
TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
|
||||||
|
TQue<QuePosition::VECIN, BUFFER_NUM> work_queue;
|
||||||
|
TQue<QuePosition::VECOUT, BUFFER_NUM> max_queue;
|
||||||
|
TQue<QuePosition::VECOUT, BUFFER_NUM> min_queue;
|
||||||
|
TQue<QuePosition::VECOUT, BUFFER_NUM> scale_queue;
|
||||||
|
TQue<QuePosition::VECOUT, BUFFER_NUM> cast_queue;
|
||||||
|
TQue<QuePosition::VECOUT, BUFFER_NUM> int8_queue;
|
||||||
|
TQue<QuePosition::VECOUT, BUFFER_NUM> half_queue;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
|
||||||
|
auto gm_ptr = (__gm__ uint8_t *)gm;
|
||||||
|
auto ub_ptr = (uint8_t *)(ub);
|
||||||
|
for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
|
||||||
|
*ub_ptr = *gm_ptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ __aicore__ void ascendc_quantize_f16_to_q4_0(
|
||||||
|
GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
|
||||||
|
GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
|
||||||
|
int64_t input_ne_ub[4];
|
||||||
|
size_t input_nb_ub[4];
|
||||||
|
int64_t output_ne_ub[4];
|
||||||
|
|
||||||
|
copy_to_ub(input_ne_gm, input_ne_ub, 32);
|
||||||
|
copy_to_ub(input_nb_gm, input_nb_ub, 32);
|
||||||
|
copy_to_ub(output_ne_gm, output_ne_ub, 32);
|
||||||
|
|
||||||
|
QUANTIZE_FLOAT_TO_Q4_0<half> op;
|
||||||
|
op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
|
||||||
|
op.calculate();
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0(
|
||||||
|
GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
|
||||||
|
GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
|
||||||
|
int64_t input_ne_ub[4];
|
||||||
|
size_t input_nb_ub[4];
|
||||||
|
int64_t output_ne_ub[4];
|
||||||
|
|
||||||
|
copy_to_ub(input_ne_gm, input_ne_ub, 32);
|
||||||
|
copy_to_ub(input_nb_gm, input_nb_ub, 32);
|
||||||
|
copy_to_ub(output_ne_gm, output_ne_ub, 32);
|
||||||
|
|
||||||
|
QUANTIZE_FLOAT_TO_Q4_0<float> op;
|
||||||
|
op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
|
||||||
|
op.calculate();
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user