diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 67c0223c010a1..a977578647254 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -78,6 +78,11 @@
 #include "../ggml-common.h"
 
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
 void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
                  aclTensor ** acl_src1, aclTensor ** acl_dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
@@ -2587,3 +2592,148 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
     ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
 }
+
+void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    // dst  [M, K, N, 1]
+    ggml_tensor * src0 = dst->src[0]; // src0 [D, M, A, 1]
+    ggml_tensor * src1 = dst->src[1]; // src1 [D, B, N, 1], B = K or B = 1
+    ggml_tensor * ids  = dst->src[2]; // ids  [K, N]
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    int64_t n_as  = ne02;       // A
+    int64_t n_ids = ids->ne[0]; // K
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    char * ids_dev = (char *) ids->data;
+    ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids_dev, ggml_nbytes(ids),
+                               ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
+    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+
+    char * src0_original = (char *) src0->data;
+    char * src1_original = (char *) src1->data;
+    char * dst_original  = (char *) dst->data;
+
+    ggml_tensor src0_row = *src0;
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row  = *dst;
+
+    // src0_row [D, M, 1, 1]
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = nb02;
+
+    // src1_row [D, 1, 1, 1]
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    // dst_row [M, 1, 1, 1]
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+
+    // ne12 == ids->ne[1] == N
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                // If B = 1 (broadcast), always use 0; otherwise, use id.
+                int64_t i11 = (ne11 == 1 ? 0 : id);
+                int64_t i12 = iid1;
+
+                int64_t i1 = id;
+                int64_t i2 = i12;
+
+                src0_row.data  = src0_original + i02*nb02;
+                src1_row.data  = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data   = dst_original + i1*nb1 + i2*nb2;
+                dst_row.src[0] = &src0_row;
+                dst_row.src[1] = &src1_row;
+                ggml_cann_mul_mat(ctx, &dst_row);
+            }
+        }
+    } else {
+        ggml_cann_pool_alloc src1_cont_allocator(
+            ctx.pool(), sizeof(float) * ggml_nelements(src1));
+        ggml_cann_pool_alloc dst_cont_allocator(
+            ctx.pool(), sizeof(float) * ggml_nelements(dst));
+
+        void* src1_cont_buf = src1_cont_allocator.get();
+        void* dst_cont_buf  = dst_cont_allocator.get();
+
+        src1_row.data = src1_cont_buf;
+        dst_row.data  = dst_cont_buf;
+
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
+            std::vector<mmid_row_mapping> row_mappings;
+            int64_t num_src1_rows = 0;
+
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    int32_t row_id_i = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+
+                    if (row_id_i == i02) {
+                        mmid_row_mapping mapping;
+                        mapping.i1 = static_cast<int32_t>(id);
+                        mapping.i2 = static_cast<int32_t>(iid1);
+                        row_mappings.push_back(mapping);
+                        num_src1_rows++;
+
+                        int64_t read_b = (ne11 == 1 ? 0 : id);
+                        char* src_ptr = src1_original
+                                      + read_b * nb11
+                                      + mapping.i2 * nb12;
+                        char* dst_ptr = (char*)src1_cont_buf + (num_src1_rows - 1) * nb11;
+                        ACL_CHECK(aclrtMemcpyAsync(dst_ptr, nb11, src_ptr, nb11,
+                                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                    }
+                }
+            }
+
+            if (num_src1_rows == 0) {
+                continue;
+            }
+
+            // src0_row [D, M, 1, 1]
+            src0_row.data = src0_original + i02 * nb02;
+
+            // src1_row [D, num_src1_rows, 1, 1]: the rows of ids equal to i02, copied contiguously
+            src1_row.ne[1] = num_src1_rows;
+            src1_row.nb[1] = nb11;
+            src1_row.nb[2] = num_src1_rows * nb11;
+            src1_row.nb[3] = num_src1_rows * nb11;
+
+            // dst_row [M, num_src1_rows, 1, 1]
+            dst_row.ne[1] = num_src1_rows;
+            dst_row.nb[1] = nb1;
+            dst_row.nb[2] = num_src1_rows * nb1;
+            dst_row.nb[3] = num_src1_rows * nb1;
+
+            dst_row.src[0] = &src0_row;
+            dst_row.src[1] = &src1_row;
+
+            ggml_cann_mul_mat(ctx, &dst_row);
+
+            // Scatter the contiguous results back to their (slot, token) positions in dst.
+            for (int64_t i = 0; i < num_src1_rows; ++i) {
+                int64_t i1 = row_mappings[i].i1;
+                int64_t i2 = row_mappings[i].i2;
+
+                char* src_ptr = (char*)dst_cont_buf + i * nb1;
+                char* dst_ptr = dst_original + i1 * nb1 + i2 * nb2;
+
+                ACL_CHECK(aclrtMemcpyAsync(dst_ptr, nb1, src_ptr, nb1,
+                                           ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+            }
+        }
+    }
+    return;
+}
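
For illustration, here is a minimal host-only sketch of the per-expert grouping pass used in the ne12 > 1 branch above: every (slot, token) pair whose routing id equals the current expert is collected, so that expert's rows can be multiplied as one batch. All sizes and the ids table below are made up for this example; it is not the backend's real data.

// Standalone sketch: group (slot, token) pairs per expert, as the mapping
// pass above does. Values are illustrative assumptions only.
#include <cstdint>
#include <cstdio>
#include <vector>

struct row_mapping { int32_t i1, i2; };

int main() {
    const int64_t n_as     = 4; // A: number of experts
    const int64_t n_ids    = 2; // K: experts per token
    const int64_t n_tokens = 3; // N: batch size

    // ids[token][slot]: expert index chosen for each (token, slot) pair
    const int32_t ids[3][2] = { {0, 2}, {1, 2}, {0, 3} };

    for (int64_t i02 = 0; i02 < n_as; i02++) {
        std::vector<row_mapping> rows;
        for (int64_t iid1 = 0; iid1 < n_tokens; iid1++) {
            for (int64_t id = 0; id < n_ids; id++) {
                if (ids[iid1][id] == i02) {
                    // i1 = slot within the token, i2 = token index
                    rows.push_back({ (int32_t) id, (int32_t) iid1 });
                }
            }
        }
        printf("expert %lld handles %zu (slot, token) pairs\n",
               (long long) i02, rows.size());
    }
    return 0;
}

Across all experts the pair count always sums to K * N, which is why one contiguous staging buffer sized for all of src1 suffices in the batched branch.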
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 462351542e546..a31d341026a8e 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -978,6 +978,33 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
     }
 }
 
+/**
+ * @brief Performs sparse expert-based matrix multiplication using the CANN backend.
+ *
+ * @details This function implements a MoE-style batched matrix multiplication, where each input token
+ * is routed to one or more experts, and each expert corresponds to a specific [D, M] weight matrix
+ * in the source tensor `src0`. The routing indices are provided via the `ids` tensor.
+ *
+ * For each token (from `src1`), the function selects the corresponding expert(s) as specified by `ids`,
+ * performs the matrix multiplication with the selected expert's weight submatrix (from `src0`),
+ * and stores the results in `dst`. This operation is optimized and executed on the CANN backend.
+ *
+ * Dimensions:
+ *  - src0: [D, M, A, 1], where A is the number of experts
+ *  - src1: [D, B, N, 1], where N is the batch size and B is the slot count per sample
+ *  - ids : [K, N], where K is the number of experts each token is routed to
+ *  - dst : [M, K, N, 1], output tensor storing the result of expert × token multiplication
+ *
+ * The function handles two main modes:
+ *  - If `ne12 == 1`, a simple per-token loop is used.
+ *  - If `ne12 > 1`, rows are grouped per expert and multiplied in batches for efficiency.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the expert-weighted token outputs are stored.
+ *            Expected to be of shape [M, K, N, 1].
+ */
+void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /**
  * @brief Applies an element-wise operation to two input tensors using the CANN
  *        backend.
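
The shape contract above leans on ggml's nb byte-stride convention. The following standalone sketch (all sizes are illustrative assumptions, not real tensor metadata) shows how a single src1 row at (i11, i12) is addressed, including the B = 1 broadcast rule used by the implementation.

// Sketch of nb-stride row addressing: row (i11, i12) of a [D, B, N, 1]
// tensor starts at byte offset i11*nb11 + i12*nb12.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne10 = 64;            // D: elements per row
    const int64_t ne11 = 2;             // B: slots per sample
    const int64_t nb10 = sizeof(float); // stride between elements
    const int64_t nb11 = ne10 * nb10;   // stride between slots
    const int64_t nb12 = ne11 * nb11;   // stride between tokens

    // Address slot `id` of token `iid1`, honoring the B = 1 broadcast rule.
    const int64_t id = 1, iid1 = 3;
    const int64_t i11    = (ne11 == 1 ? 0 : id);
    const int64_t offset = i11 * nb11 + iid1 * nb12;

    printf("row (%lld, %lld) starts at byte %lld\n",
           (long long) i11, (long long) iid1, (long long) offset);
    return 0;
}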
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index e2617b06e9c39..0cb7bbf17cca5 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1672,7 +1672,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             ggml_cann_mul_mat(ctx, dst);
             break;
         case GGML_OP_MUL_MAT_ID:
-            return false;
+            ggml_cann_mul_mat_id(ctx, dst);
+            break;
         case GGML_OP_SCALE:
             ggml_cann_scale(ctx, dst);
             break;
@@ -2030,7 +2031,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
         }
         case GGML_OP_MUL_MAT_ID:
-            return false;
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F16:
+                case GGML_TYPE_F32:
+                    return true;
+                default:
+                    return false;
+            }
         // embedding
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
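
As a rough standalone illustration of the gate added to ggml_backend_cann_supports_op: only F16/F32 expert weights are claimed here, and anything else is declined so the scheduler can route the op to another backend. The enum and function below are minimal stand-ins for this sketch, not the real ggml types.

// Stand-in sketch of the type gate: accept MUL_MAT_ID only for F16/F32 weights.
#include <cstdio>

enum ggml_type_stub { TYPE_F32, TYPE_F16, TYPE_Q4_0 };

static bool supports_mul_mat_id(ggml_type_stub src0_type) {
    switch (src0_type) {
        case TYPE_F16:
        case TYPE_F32:
            return true;  // handled by the new ggml_cann_mul_mat_id path
        default:
            return false; // quantized experts are left to other backends
    }
}

int main() {
    printf("F32: %d, Q4_0: %d\n",
           supports_mul_mat_id(TYPE_F32), supports_mul_mat_id(TYPE_Q4_0));
    return 0;
}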