[CANN] Support OP MUL_MAT_ID #13042

Open
wants to merge 2 commits into master
150 changes: 150 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.cpp
@@ -78,6 +78,11 @@

#include "../ggml-common.h"

struct mmid_row_mapping {
    int32_t i1;  // expert slot within a token's K routed experts (0..K-1)
    int32_t i2;  // token index within the batch (0..N-1)
};

void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
aclTensor ** acl_src1, aclTensor ** acl_dst) {
GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
@@ -2587,3 +2592,148 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){

ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
}

void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
//dst [M, K, N, 1]
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
ggml_tensor * ids = dst->src[2]; //ids [K, N]

GGML_TENSOR_BINARY_OP_LOCALS

int64_t n_as = ne02; // A
int64_t n_ids = ids->ne[0]; // K

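// The expert indices drive host-side control flow below, so copy the whole
// ids tensor to the CPU and wait for the transfer before reading it.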
std::vector<char> ids_host(ggml_nbytes(ids));
char * ids_dev = (char *) ids->data;
ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids_dev, ggml_nbytes(ids),
ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));

char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;

ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;

// src0_row [D, M, 1, 1]
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[3] = nb02;

// src1_row [D, 1, 1, 1]
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;

// dst_row [M, 1, 1, 1]
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;

// ne12 == ids->ne[1] == N
if (ne12 == 1) {
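// Fast path: one (expert, token) pair per matmul, reading and writing the
// original buffers in place; no staging copies are needed.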
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

GGML_ASSERT(i02 >= 0 && i02 < n_as);

// If B = 1 (broadcast), always use 0; otherwise, use id.
int64_t i11 = (ne11 == 1 ? 0 : id);
int64_t i12 = iid1;

int64_t i1 = id;
int64_t i2 = i12;

src0_row.data = src0_original + i02*nb02;
src1_row.data = src1_original + i11*nb11 + i12*nb12;
dst_row.data = dst_original + i1*nb1 + i2*nb2;
dst_row.src[0] = &src0_row;
dst_row.src[1] = &src1_row;
ggml_cann_mul_mat(ctx, &dst_row);
}
}
} else {
ggml_cann_pool_alloc src1_cont_allocator(
ctx.pool(), sizeof(float) * ggml_nelements(src1));
ggml_cann_pool_alloc dst_cont_allocator(
ctx.pool(), sizeof(float) * ggml_nelements(dst));

void* src1_cont_buf = src1_cont_allocator.get();
void* dst_cont_buf = dst_cont_allocator.get();

src1_row.data = src1_cont_buf;
dst_row.data = dst_cont_buf;

for (int64_t i02 = 0; i02 < n_as; i02++) {
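// Gather pass: find every src1 row routed to expert i02 and pack it into
// the contiguous staging buffer.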
std::vector<mmid_row_mapping> row_mappings;
int64_t num_src1_rows = 0;

for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
int32_t row_id_i = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);

if (row_id_i == i02) {
mmid_row_mapping mapping;
mapping.i1 = static_cast<int32_t>(id);
mapping.i2 = static_cast<int32_t>(iid1);
row_mappings.push_back(mapping);
num_src1_rows++;

int64_t read_b = (ne11 == 1 ? 0 : id);
char* src_ptr = src1_original
+ read_b * nb11
+ mapping.i2 * nb12;
char* dst_ptr = (char*)src1_cont_buf + (num_src1_rows - 1) * nb11;
ACL_CHECK(aclrtMemcpyAsync(dst_ptr, nb11, src_ptr, nb11,
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
}
}
}

if (num_src1_rows == 0) {
continue;
}

// src0_row [D, M, 1, 1]
src0_row.data = src0_original + i02 * nb02;

// src1_row [D, num_src1_rows, 1, 1]: the rows of src1 routed to expert i02
src1_row.ne[1] = num_src1_rows;
src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows * nb11;
src1_row.nb[3] = num_src1_rows * nb11;

// dst_row [M, num_src1_rows, 1, 1]: one output row per gathered src1 row
dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows * nb1;
dst_row.nb[3] = num_src1_rows * nb1;

dst_row.src[0] = &src0_row;
dst_row.src[1] = &src1_row;

ggml_cann_mul_mat(ctx, &dst_row);

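// Scatter pass: copy each packed result row back to its (i1, i2) slot in dst.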
for (int64_t i = 0; i < num_src1_rows; ++i) {
int64_t i1 = row_mappings[i].i1;
int64_t i2 = row_mappings[i].i2;

char* src_ptr = (char*)dst_cont_buf + i * nb1;
char* dst_ptr = dst_original + i1 * nb1 + i2 * nb2;

ACL_CHECK(aclrtMemcpyAsync(dst_ptr, nb1, src_ptr, nb1,
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
}
}
}
}
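For orientation, the `ne12 > 1` branch above follows a gather-multiply-scatter pattern: all activation rows routed to a given expert are packed into a contiguous staging buffer, a single matrix multiplication covers the whole pack, and the results are scattered back to their final positions in `dst`. The standalone CPU sketch below reproduces that pattern under simplifying assumptions: `mul_mat_id_cpu` and `naive_matmul` are hypothetical stand-ins for `ggml_cann_mul_mat`, the plain `memcpy` calls stand in for `aclrtMemcpyAsync`, and all tensors are assumed contiguous, row-major, F32, with B = K (no broadcast).

#include <cstdint>
#include <cstring>
#include <vector>

struct row_mapping { int32_t i1, i2; };  // expert slot k, token index n

// Hypothetical stand-in for ggml_cann_mul_mat on one expert:
// W is [M x D], X is [num_rows x D], Y is [num_rows x M], all row-major.
static void naive_matmul(const float* W, const float* X, float* Y,
                         int64_t D, int64_t M, int64_t num_rows) {
    for (int64_t r = 0; r < num_rows; ++r) {
        for (int64_t m = 0; m < M; ++m) {
            float acc = 0.0f;
            for (int64_t d = 0; d < D; ++d) acc += W[m * D + d] * X[r * D + d];
            Y[r * M + m] = acc;
        }
    }
}

// Gather-multiply-scatter over A experts, mirroring the ne12 > 1 branch.
// src0: A expert matrices of [M x D]; src1: [N x K x D]; ids: [N x K];
// dst: [N x K x M].
void mul_mat_id_cpu(const float* src0, const float* src1, const int32_t* ids,
                    float* dst, int64_t D, int64_t M, int64_t A,
                    int64_t K, int64_t N) {
    std::vector<float> gathered(N * K * D);  // staging buffer for routed rows
    std::vector<float> results(N * K * M);
    for (int64_t e = 0; e < A; ++e) {
        std::vector<row_mapping> map;
        for (int64_t n = 0; n < N; ++n) {
            for (int64_t k = 0; k < K; ++k) {
                if (ids[n * K + k] != e) continue;
                // gather: pack this token's row next to the others for expert e
                std::memcpy(&gathered[map.size() * D],
                            &src1[(n * K + k) * D], D * sizeof(float));
                map.push_back({ (int32_t) k, (int32_t) n });
            }
        }
        if (map.empty()) continue;
        // one matmul covers every row routed to expert e
        naive_matmul(&src0[e * M * D], gathered.data(), results.data(),
                     D, M, (int64_t) map.size());
        // scatter: copy each result row to its (k, n) slot in dst
        for (size_t i = 0; i < map.size(); ++i) {
            std::memcpy(&dst[(map[i].i2 * K + map[i].i1) * M],
                        &results[i * M], M * sizeof(float));
        }
    }
}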
27 changes: 27 additions & 0 deletions ggml/src/ggml-cann/aclnn_ops.h
@@ -978,6 +978,33 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
}
}

/**
* @brief Performs sparse expert-based matrix multiplication using the CANN backend.
*
* @details This function implements a MoE-style batched matrix multiplication, where each input token
* is routed to one or more experts, and each expert corresponds to a specific [D, M] weight matrix
* in the source tensor `src0`. The routing indices are provided via the `ids` tensor.
*
* For each token (from `src1`), the function selects the corresponding expert(s) as specified by `ids`,
* performs the matrix multiplication with the selected expert's weight submatrix (from `src0`),
* and stores the results in `dst`. This operation is optimized and executed on the CANN backend.
*
* Dimensions:
* - src0: [D, M, A, 1], where A is the number of experts
* - src1: [D, B, N, 1], where N is the batch size and B is either K (one activation row per routed expert) or 1 (a single row broadcast across all K experts)
* - ids : [K, N], where K is the number of experts each token is routed to
* - dst : [M, K, N, 1], output tensor storing the result of expert × token multiplication
*
* The function handles two main modes:
* - If `ne12 == 1`, each routed (expert, token) pair is handled by a simple per-token loop.
* - If `ne12 > 1`, rows are grouped by expert, multiplied in one call per expert, and the results are copied back into place.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor where the expert-weighted token outputs are stored.
* Expected to be of shape [M, K, N, 1].
*/
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
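/*
 * Worked example (hypothetical shapes, for orientation only): a Mixtral-style
 * top-2 MoE layer with hidden size D = 4096, expert width M = 14336,
 * A = 8 experts, K = 2 routed experts per token, and N = 32 tokens gives:
 *   src0: [4096, 14336, 8, 1]  (8 expert weight matrices)
 *   src1: [4096, 1, 32, 1]     (one activation row per token; B = 1, broadcast)
 *   ids : [2, 32]              (2 expert indices per token)
 *   dst : [14336, 2, 32, 1]    (one output row per routed expert per token)
 */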

/**
* @brief Applies an element-wise operation to two input tensors using the CANN
* backend.
11 changes: 9 additions & 2 deletions ggml/src/ggml-cann/ggml-cann.cpp
@@ -1672,7 +1672,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
ggml_cann_mul_mat(ctx, dst);
break;
case GGML_OP_MUL_MAT_ID:
return false;
ggml_cann_mul_mat_id(ctx, dst);
break;
case GGML_OP_SCALE:
ggml_cann_scale(ctx, dst);
break;
@@ -2030,7 +2031,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
}
case GGML_OP_MUL_MAT_ID:
return false;
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
return true;
default:
return false;
}
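// Only F16/F32 expert weights are advertised here; for other src0 types the
// scheduler falls back to another backend (typically the CPU).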
// embedding
case GGML_OP_GET_ROWS: {
switch (op->src[0]->type) {