diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 48f4b7db691..708417feadd 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -57,6 +57,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -429,6 +430,94 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
                             nullptr);
 }
 
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor * k = dst->src[0];
+    ggml_tensor * v = dst->src[1];
+    ggml_tensor * q = dst->src[2];
+    ggml_tensor * g = dst->src[3];
+    ggml_tensor * s = dst->src[4];
+
+    int64_t B = s->ne[1];    // batch size (number of sequences)
+    int64_t T = k->ne[2];    // tokens across the whole batch
+    int64_t H = k->ne[1];    // number of heads
+    int64_t C = dst->ne[0];  // output channels (H * D)
+    int64_t D = C / H;       // head size
+    int64_t L = T / B;       // tokens per sequence
+
+    // k and g are viewed as D x 1 columns, v and o as 1 x D rows, and the
+    // per-head state as a D x D matrix
+    int64_t ne_qkg[2] = {1, D};
+    int64_t ne_s[2]   = {D, D};
+    int64_t ne_vo[2]  = {D, 1};
+    int64_t ne_q[1]   = {D};
+    size_t nb_base   = ggml_type_size(k->type);
+    size_t nb_qkg[2] = {nb_base, nb_base};
+    size_t nb_s[2]   = {nb_base, D * nb_base};
+    size_t nb_vo[2]  = {nb_base, D * nb_base};
+    size_t nb_q[1]   = {nb_base};
+
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    for (int64_t b = 0; b < B; b++) {
+        for (int64_t h = 0; h < H; h++) {
+            size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
+            // D x D: copy the incoming state into dst (after the output
+            // tokens), then update it in place, one token at a time
+            aclTensor* acl_s     = ggml_cann_create_tensor(s, ne_s, nb_s, 2, ACL_FORMAT_ND, s_offset);
+            aclTensor* acl_s_new = ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
+            cann_copy(ctx, acl_s, acl_s_new);
+            for (int64_t l = 0; l < L; l++) {
+                size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * D) * nb_base;
+                // D x 1
+                aclTensor* acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                aclTensor* acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                // D
+                aclTensor* acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+                // 1 x D
+                aclTensor* acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
+                // D
+                aclTensor* acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+
+                // broadcast k and v to D x D, then multiply elementwise to
+                // form the outer product k * v^T
+                size_t buf_size = D * D * nb_base;
+                ggml_cann_pool_alloc state_buf1(ctx.pool(), buf_size);
+                void* buf1_ptr = state_buf1.get();
+                aclTensor* acl_buf_k = ggml_cann_create_tensor(buf1_ptr, ggml_cann_type_mapping(k->type), ggml_type_size(k->type), ne_s, nb_s, 2);
+                ggml_cann_pool_alloc state_buf2(ctx.pool(), buf_size);
+                void* buf2_ptr = state_buf2.get();
+                aclTensor* acl_buf_v = ggml_cann_create_tensor(buf2_ptr, ggml_cann_type_mapping(k->type), ggml_type_size(k->type), ne_s, nb_s, 2);
+                int64_t k_rep[2] = {1, D};
+                int64_t v_rep[2] = {D, 1};
+                aclIntArray* acl_k_rep = aclCreateIntArray(k_rep, 2);
+                aclIntArray* acl_v_rep = aclCreateIntArray(v_rep, 2);
+                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_k, acl_k_rep, acl_buf_k);
+                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_v, acl_v_rep, acl_buf_v);
+                // in-place multiply: acl_buf_k now holds k * v^T
+                aclnn_mul(ctx, acl_buf_k, acl_buf_v, nullptr);
+
+                // gate the state (acl_buf_v is reused for the repeated g):
+                // s = s * g + k * v^T
+                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_g, acl_k_rep, acl_buf_v);
+                aclnn_mul(ctx, acl_s_new, acl_buf_v, nullptr);
+                aclnn_add(ctx, acl_s_new, acl_buf_k, nullptr);
+
+                // output: o = scale * (s^T @ q); the transposed state is
+                // staged in acl_buf_k
+                int64_t new_dim[2] = {1, 0};
+                aclnn_permute(ctx, acl_s_new, acl_buf_k, new_dim, 2);
+                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_buf_k, acl_q, acl_o, 1);
+                aclnn_muls(ctx, acl_o, scale, nullptr, true);
+
+                ggml_cann_release_resources(ctx, acl_q, acl_k, acl_v, acl_o, acl_g, acl_buf_k, acl_buf_v, acl_k_rep, acl_v_rep);
+            }
+            ggml_cann_release_resources(ctx, acl_s, acl_s_new);
+        }
+    }
+}
+
 
 void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src = dst->src[0];
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 1ebbc769c71..3e1794912d9 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -251,6 +251,18 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
  */
 void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
+/**
+ * @brief Computes Gated Linear Attention for a ggml tensor using the CANN
+ * backend.
+ *
+ * @details For each batch and head, the D x D recurrent state is updated once
+ * per token as s = s * g + k * v^T; the token output is o = scale * s^T @ q.
+ * The updated state is written to dst after the attention output.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor for the attention output and updated state.
+ */
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /**
  * @brief Computes the Group Normalization for a ggml tensor using the CANN
  * backend.
@@ -674,6 +686,10 @@ void aclnn_cos(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor *
  */
 void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);
 
+static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst);
+static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                          aclTensor* acl_dst, int64_t* new_dim, uint64_t dims);
+
 /**
  * @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
  * output tensor.
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index 544c1e2a501..e7e48aa091d 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1889,6 +1889,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
         case GGML_OP_OUT_PROD:
             ggml_cann_out_prod(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cann_gated_linear_attn(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2561,6 +2564,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
         case GGML_OP_MEAN:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
+        case GGML_OP_GATED_LINEAR_ATTN:
            return true;
        case GGML_OP_OUT_PROD:
            {
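
For reference, the recurrence this kernel evaluates per batch and head: the D x D state is decayed elementwise by the gate and accumulated with the key/value outer product, and each token's output is the scaled product of the transposed state with the query. Below is a minimal scalar C++ sketch of one time step; gla_step is a hypothetical helper for illustration only (not part of ggml or this patch), and it assumes contiguous row-major F32 buffers.

#include <cstdint>

// Hypothetical scalar reference for one GLA time step of a single head.
// State update: s[i][j] = s[i][j] * g[i] + k[i] * v[j]   (s = s * g + k v^T)
// Output:       o[j]    = scale * sum_i s[i][j] * q[i]   (o = scale * s^T q)
static void gla_step(int64_t D, float * s, const float * k, const float * v,
                     const float * q, const float * g, float scale, float * o) {
    for (int64_t i = 0; i < D; i++) {
        for (int64_t j = 0; j < D; j++) {
            s[i * D + j] = s[i * D + j] * g[i] + k[i] * v[j];
        }
    }
    for (int64_t j = 0; j < D; j++) {
        float acc = 0.0f;
        for (int64_t i = 0; i < D; i++) {
            acc += s[i * D + j] * q[i];
        }
        o[j] = scale * acc;
    }
}

In the CANN implementation above, the two Repeat calls plus the elementwise aclnn_mul realize the k * v^T outer product and the gate broadcast, and the aclnn_permute followed by the Mv matrix-vector call realizes the s^T @ q product.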