Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include <aclnnop/aclnn_mean.h>
#include <aclnnop/aclnn_mm.h>
#include <aclnnop/aclnn_mul.h>
#include <aclnnop/aclnn_mv.h>
#include <aclnnop/aclnn_permute.h>
#include <aclnnop/aclnn_pow.h>
#include <aclnnop/aclnn_pow_tensor_tensor.h>
Expand Down Expand Up @@ -429,6 +430,94 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
nullptr);
}

void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    // Gated linear attention, recurrent form. For every (batch, head) pair the
    // D x D state S is updated once per token l:
    //     S = S * broadcast(g) + k * v^T      (element-wise gate + outer product)
    //     o = scale * S^T * q                 (per-token output, length D)
    // dst holds the B*L*H*D per-token outputs followed by the B*H*D*D final
    // state, matching the GGML_OP_GATED_LINEAR_ATTN layout.
    ggml_tensor * k = dst->src[0];  // keys
    ggml_tensor * v = dst->src[1];  // values
    ggml_tensor * q = dst->src[2];  // queries
    ggml_tensor * g = dst->src[3];  // gates
    ggml_tensor * s = dst->src[4];  // incoming recurrent state

    int64_t B = s->ne[1];    // sequences in the batch
    int64_t T = k->ne[2];    // total tokens across all sequences
    int64_t H = k->ne[1];    // attention heads
    int64_t C = dst->ne[0];  // channels = H * D
    int64_t D = C / H;       // head dimension
    int64_t L = T / B;       // tokens per sequence

    // Per-slice view shapes and byte strides, all in units of k's element size.
    int64_t ne_qkg[2] = {1, D};  // k/g viewed as a D x 1 column
    int64_t ne_s[2]   = {D, D};  // state is a D x D matrix
    int64_t ne_vo[2]  = {D, 1};  // v viewed as a 1 x D row
    int64_t ne_q[1]   = {D};     // q and the output o are plain vectors
    size_t nb_base   = ggml_type_size(k->type);
    size_t nb_qkg[2] = {nb_base, nb_base};
    size_t nb_s[2]   = {nb_base, D * nb_base};
    size_t nb_vo[2]  = {nb_base, D * nb_base};
    size_t nb_q[1]   = {nb_base};

    float scale;
    memcpy(&scale, dst->op_params, sizeof(float));

    // The scratch buffers and the Repeat descriptors are identical for every
    // (b, h, l) iteration, so allocate/create them once instead of inside the
    // innermost loop. The buffer size follows the element type to stay
    // consistent with the nb_s strides used for the views below (previously it
    // was hard-coded to sizeof(float)).
    size_t buf_size = (size_t) D * (size_t) D * nb_base;
    ggml_cann_pool_alloc kv_buf_alloc(ctx.pool(), buf_size);    // holds k*v^T, then S^T
    ggml_cann_pool_alloc gate_buf_alloc(ctx.pool(), buf_size);  // holds repeated v, then repeated g
    void* kv_buf_ptr   = kv_buf_alloc.get();
    void* gate_buf_ptr = gate_buf_alloc.get();

    int64_t k_rep[2] = {1, D};  // broadcast a D x 1 column across D columns
    int64_t v_rep[2] = {D, 1};  // broadcast a 1 x D row across D rows
    aclIntArray* acl_k_rep = aclCreateIntArray(k_rep, 2);
    aclIntArray* acl_v_rep = aclCreateIntArray(v_rep, 2);

    for (int64_t b = 0; b < B; b++) {
        for (int64_t h = 0; h < H; h++) {
            size_t s_offset = (size_t)(b * (H * D * D) + h * (D * D)) * nb_base;
            // D x D views: the incoming state, and the state slot inside dst
            // (the state output lives after the B*L*H*D attention outputs).
            aclTensor* acl_s     = ggml_cann_create_tensor(s, ne_s, nb_s, 2, ACL_FORMAT_ND, s_offset);
            aclTensor* acl_s_new = ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND,
                                                           (size_t)(B * L * H * D) * nb_base + s_offset);
            // Seed the running state with the incoming state, then update it
            // in place token by token.
            cann_copy(ctx, acl_s, acl_s_new);
            for (int64_t l = 0; l < L; l++) {
                size_t qkvgo_offset = (size_t)(b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
                // D x 1 columns
                aclTensor* acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
                aclTensor* acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
                // D vector
                aclTensor* acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
                // 1 x D row
                aclTensor* acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
                // D vector (output slot in dst)
                aclTensor* acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
                // D x D views over the pre-allocated scratch buffers
                aclTensor* acl_buf_k = ggml_cann_create_tensor(kv_buf_ptr, ggml_cann_type_mapping(k->type),
                                                               ggml_type_size(k->type), ne_s, nb_s, 2);
                aclTensor* acl_buf_v = ggml_cann_create_tensor(gate_buf_ptr, ggml_cann_type_mapping(k->type),
                                                               ggml_type_size(k->type), ne_s, nb_s, 2);

                // Outer product k * v^T: broadcast both operands to D x D and
                // multiply element-wise; the result lands in acl_buf_k.
                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_k, acl_k_rep, acl_buf_k);
                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_v, acl_v_rep, acl_buf_v);
                aclnn_mul(ctx, acl_buf_k, acl_buf_v, nullptr);

                // Gate the state: S *= broadcast(g) (acl_buf_v is reused to
                // hold the repeated gate), then accumulate the k*v^T term.
                GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_g, acl_k_rep, acl_buf_v);
                aclnn_mul(ctx, acl_s_new, acl_buf_v, nullptr);
                aclnn_add(ctx, acl_s_new, acl_buf_k, nullptr);

                // Output: o = scale * S^T * q. Transpose S into acl_buf_k,
                // matrix-vector multiply, then scale in place.
                int64_t newdim[2] = {1, 0};
                aclnn_permute(ctx, acl_s_new, acl_buf_k, newdim, 2);
                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_buf_k, acl_q, acl_o, 1);
                aclnn_muls(ctx, acl_o, scale, nullptr, true);

                ggml_cann_release_resources(ctx, acl_q, acl_k, acl_v, acl_o, acl_g, acl_buf_k, acl_buf_v);
            }
            ggml_cann_release_resources(ctx, acl_s, acl_s_new);
        }
    }
    ggml_cann_release_resources(ctx, acl_k_rep, acl_v_rep);
}


void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];

Expand Down
16 changes: 16 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,18 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
*/
void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst);

/**
 * @brief Computes Gated Linear Attention for a ggml tensor using the CANN
 * backend.
 *
 * @details Implements the recurrent form of gated linear attention. The five
 * sources of @p dst are k, v, q, g and the initial per-head state, in that
 * order. For each batch and head, the D x D state is updated once per token
 * as S = S * g + k * v^T, and the token output is scale * S^T * q, where the
 * scale factor is read from dst->op_params.
 *
 * @param ctx The CANN context used for operations.
 * @param dst The destination tensor; receives the per-token outputs followed
 *            by the final recurrent state.
 */
void ggml_cann_gated_linear_attn(ggml_backend_cann_context& ctx, ggml_tensor* dst);

/**
* @brief Computes the Group Normalization for a ggml tensor using the CANN
* backend.
Expand Down Expand Up @@ -674,6 +686,10 @@ void aclnn_cos(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor *
*/
void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);

// NOTE(review): declaring `static` functions in a header is questionable —
// every translation unit that includes this header gets its own
// internal-linkage declaration, and any TU other than the one defining these
// helpers ends up with a declared-but-undefined static function (typically a
// compiler warning). Presumably these are defined static in aclnn_ops.cpp;
// consider moving the declarations into that .cpp instead — TODO confirm.
static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst);
static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
aclTensor* acl_dst, int64_t* new_dim, uint64_t dims);

/**
* @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
* output tensor.
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-cann/ggml-cann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1889,6 +1889,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst);
break;
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cann_gated_linear_attn(ctx, dst);
break;
default:
return false;
}
Expand Down Expand Up @@ -2561,6 +2564,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
case GGML_OP_MEAN:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_GATED_LINEAR_ATTN:
return true;
case GGML_OP_OUT_PROD:
{
Expand Down