From 7066b4cce2898993e943ad6af5d8f1de5840c8e9 Mon Sep 17 00:00:00 2001
From: Chenguang Li <87689256+noemotiovon@users.noreply.github.com>
Date: Tue, 26 Nov 2024 17:31:05 +0800
Subject: [PATCH] CANN: RoPE and CONCAT operator optimization (#10488)

Co-authored-by: noemotiovon
---
 ggml/src/ggml-cann/aclnn_ops.cpp | 225 +++++++------------------------
 ggml/src/ggml-cann/ggml-cann.cpp |  68 ++++++++--
 2 files changed, 106 insertions(+), 187 deletions(-)

diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 1f4ee986c..6113b59f4 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -21,6 +21,7 @@
  */
 
 #include "aclnn_ops.h"
+#include "ggml-impl.h"
 
 #include
 #include
@@ -241,10 +242,14 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
-    int64_t concat_dim = 1;
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+    int32_t acl_dim = 3 - dim;
+
     aclTensor* tensors[] = {acl_src0, acl_src1};
     aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
-    aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
+    aclnn_concat(ctx, tensorList, acl_dst, acl_dim);
 
     ACL_CHECK(aclDestroyTensorList(tensorList));
     ACL_CHECK(aclDestroyTensor(acl_dst));
@@ -1437,10 +1442,6 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];  // kernel
     ggml_tensor* src1 = dst->src[1];  // input
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     // aclnnIm2col only works on 2D. set s1, p1, d1 to 1 to perform 2D
@@ -1462,9 +1463,6 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     const int64_t OH = is_2D ? ne2 : 1;
     const int64_t OW = ne1;
 
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
-
     // memory allocated increased to 3x when is_2D == false
     const int64_t n_bytes_factor = is_2D ? 1 : 3;
@@ -2859,15 +2857,27 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
 }
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+aclnnStatus aclnnRotaryPositionEmbeddingGetWorkspaceSize(
+    const aclTensor* x, const aclTensor* cos, const aclTensor* sin,
+    int64_t mode, const aclTensor* yOut, uint64_t* workspaceSize,
+    aclOpExecutor** executor);
+aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
+                                         uint64_t workspaceSize,
+                                         aclOpExecutor* executor,
+                                         aclrtStream stream);
+#ifdef __cplusplus
+}
+#endif
+
 void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: use ascendc
     // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src2 = dst->src[2];  // freq_factors
 
-    // TODO: with freq_factors
-    GGML_ASSERT(src2 == NULL);
-
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
     // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -2885,13 +2895,19 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     memcpy(&beta_fast, (int32_t*)dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t*)dst->op_params + 10, sizeof(float));
 
-    GGML_ASSERT(n_dims <= ne0);
+    // TODO: with freq_factors
+    GGML_ASSERT(src2 == NULL);
+    // TODO: attn_factor != 1
+    GGML_ASSERT(attn_factor == 1);
+    // TODO: n_dims <= ne0
+    GGML_ASSERT(n_dims == ne0);
     GGML_ASSERT(n_dims % 2 == 0);
-    // TODO: ext_factor != 0
     GGML_ASSERT(ext_factor == 0);
     // TODO: freq_scale != 1
     GGML_ASSERT(freq_scale == 1);
+    // TODO: type == GGML_TYPE_F16
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -2924,177 +2940,30 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
                      theta_scale, is_neox);
 
-    // roll input
-    void* input_roll_buffer;
-    aclTensor* acl_minus_one_tensor;
-    void* minus_one_scale_buffer = nullptr;
-    ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
-    ggml_cann_pool_alloc minus_one_scale_allocator(
-        ctx.pool(), sizeof(float_t) * src0->ne[0]);
-    if (!is_neox) {
-        // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
-        input_roll_buffer = roll_allocator.get();
-        int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2),
-                                    src0->ne[2], src0->ne[3]};
-        size_t input_roll_nb[GGML_MAX_DIMS];
-        input_roll_nb[0] = ggml_type_size(src0->type);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1];
-        }
-        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
-            input_roll_buffer, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
-            GGML_MAX_DIMS);
-        aclTensor* acl_input_tensor = ggml_cann_create_tensor(
-            src0->data, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
-            GGML_MAX_DIMS);
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
 
-        int64_t shifts[] = {1};
-        int64_t dims[] = {3};
-        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+    void* workspaceAddr = nullptr;
 
-        // init [-1, 1, -1, 1, ...]
-        minus_one_scale_buffer = minus_one_scale_allocator.get();
-
-        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
-        size_t minus_one_nb[GGML_MAX_DIMS];
-        minus_one_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
-        }
-        acl_minus_one_tensor = aclnn_ones(
-            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
-            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
-        int64_t dim = 3;
-        int64_t* index = new int64_t[src0->ne[0]];
-        for (int i = 0; i < src0->ne[0]; i++) {
-            index[i] = i / 2 * 2;
-        }
-        int64_t index_num = src0->ne[0];
-        float value = -1;
-        aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index,
-                                index_num, value);
-    } else {
-        // roll input: [q0,q1,q2,...] ->
-        // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]
-        input_roll_buffer = roll_allocator.get();
-        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
-            input_roll_buffer, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
-        aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0);
-
-        int64_t shifts[] = {src0->ne[0] / 2};
-        int64_t dims[] = {3};
-        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
-
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
-
-        // init [-1, -1, -1, 1, 1,1,...]
-        minus_one_scale_buffer = minus_one_scale_allocator.get();
-
-        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
-        size_t minus_one_nb[GGML_MAX_DIMS];
-        minus_one_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
-        }
-        acl_minus_one_tensor = aclnn_ones(
-            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
-            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
-        // -1 * first half
-        int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
-        size_t first_half_nb[GGML_MAX_DIMS];
-        first_half_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
-        }
-        aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
-            minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
-            first_half_nb, GGML_MAX_DIMS);
-        bool inplace = true;
-        float scale = -1;
-        aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
-        ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
+    int acl_mode = mode;
+    if (mode == 0) {
+        acl_mode = 1;
     }
 
-    // TODO: n_dims < ne0
-    GGML_ASSERT(n_dims == src0->ne[0]);
-
-    // input * scale
-    ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(),
-                                                  ggml_nbytes(src0));
-    void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get();
-    size_t input_nb[GGML_MAX_DIMS];
-    input_nb[0] = ggml_type_size(src0->type);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        input_nb[i] = input_nb[i - 1] * src0->ne[i - 1];
-    }
-    aclTensor* acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor(
-        input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type),
-        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
-    aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor(
-        input_roll_buffer, ggml_cann_type_mapping(src0->type),
-        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
-
-    aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor,
-              acl_input_roll_mul_scale_tensor);
-
-    // output
-    aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+    aclTensor* acl_x = ggml_cann_create_tensor(src0);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
-    void* output_fp32_buffer;
-    if (src0->type == GGML_TYPE_F32) {
-        aclnn_inplace_mul(ctx, acl_src0, acl_cos_reshape_tensor);
-        aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor,
-                          acl_sin_reshape_tensor);
-        aclnn_add(ctx, acl_src0, acl_input_roll_mul_scale_tensor, acl_dst);
-        // TODO: ne0 != n_dims in mode2
-    } else if (src0->type == GGML_TYPE_F16) {
-        size_t input_fp32_nb[GGML_MAX_DIMS];
-        input_fp32_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
-        }
-        ggml_cann_pool_alloc fp32_allocator1(
-            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
-        void* input_fp32_buffer1 = fp32_allocator1.get();
-        aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
-            input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
-            input_fp32_nb, GGML_MAX_DIMS);
-        ggml_cann_pool_alloc fp32_allocator2(
-            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
-        void* input_fp32_buffer2 = fp32_allocator2.get();
-        aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
-            input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
-            input_fp32_nb, GGML_MAX_DIMS);
-
-        ggml_cann_pool_alloc fp32_allocator(
-            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
-        output_fp32_buffer = fp32_allocator.get();
-        aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
-            output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
-            input_fp32_nb, GGML_MAX_DIMS);
-        aclnn_mul(ctx, acl_src0, acl_cos_reshape_tensor, input_fp32_tensor1);
-        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
-                  input_fp32_tensor2);
-        aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2,
-                  output_fp32_tensor);
-        aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
-
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
-        ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
+    ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize(
+        acl_x, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
     }
 
-    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
+    ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize,
+                                           executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_x));
     ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_src0));
+    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index d96f65936..2ef5b590a 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1669,12 +1669,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         }
         case GGML_OP_MUL_MAT: {
             switch (op->src[0]->type) {
+                case GGML_TYPE_Q8_0:
+                    // Current groupsize should not be greater than k-1 in
+                    // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize
+                    if (op->src[0]->ne[0] <= QK8_0) {
+                        return false;
+                    }
                 case GGML_TYPE_F16:
                 case GGML_TYPE_F32:
-                case GGML_TYPE_Q8_0:
-                    // TODO: fix me
-                    // Current groupsize should not be greater than k-1 in
-                    // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
                 case GGML_TYPE_Q4_0:
                     return true;
                 default:
                     return false;
             }
@@ -1706,9 +1708,61 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                     return false;
                 }
             }
+        case GGML_OP_CONT: {
+            // TODO: support GGML_TYPE_BF16
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+        case GGML_OP_ROPE: {
+            // TODO: with ops-test v == 1
+            float * freq_scale = (float*)((int32_t*)op->op_params + 6);
+            float * ext_factor = (float*)((int32_t*)op->op_params + 7);
+            float * attn_factor = (float*)((int32_t*)op->op_params + 8);
+            // TODO: with freq_factors
+            if (op->src[2] != NULL) {
+                return false;
+            }
+            // TODO: n_dims <= ne0
+            if (op->src[0]->ne[0] != op->op_params[1]) {
+                return false;
+            }
+            // TODO: ext_factor != 0
+            if (*ext_factor != 0) {
+                return false;
+            }
+            // TODO: freq_scale != 1
+            if (*freq_scale != 1) {
+                return false;
+            }
+            // TODO: attn_factor != 1
+            if (*attn_factor != 1) {
+                return false;
+            }
+            // TODO: type == GGML_TYPE_F16
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+        case GGML_OP_UPSCALE: {
+            // aclnnUpsampleNearest2dGetWorkspaceSize not support
+            // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
+            if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
+                return false;
+            }
+            return true;
+        }
+        case GGML_OP_IM2COL:
+        case GGML_OP_CONCAT:
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
-        case GGML_OP_CONCAT:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -1722,17 +1776,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_CLAMP:
-        case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
-        case GGML_OP_ROPE:
-        case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
        case GGML_OP_ACC:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING: