Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions custom_ops/xpu_ops/src/ops/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -482,13 +482,14 @@ std::vector<paddle::Tensor> SpeculateGetSeqLensOutput(

void SetDataIpc(const paddle::Tensor& tmp_input, const std::string& shm_name);

void TextImageGatherScatter(paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter);
std::vector<paddle::Tensor> TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter);

void TextImageIndexOut(const paddle::Tensor& token_type_ids,
const paddle::Tensor& text_index,
Expand Down
34 changes: 18 additions & 16 deletions custom_ops/xpu_ops/src/ops/text_image_gather_scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

void TextImageGatherScatter(paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

std::vector<paddle::Tensor> TextImageGatherScatter(
paddle::Tensor& input,
paddle::Tensor& text_input,
paddle::Tensor& image_input,
paddle::Tensor& token_type_ids,
paddle::Tensor& text_index,
paddle::Tensor& image_index,
const bool is_scatter) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
Expand Down Expand Up @@ -58,22 +63,19 @@ void TextImageGatherScatter(paddle::Tensor& input,
break;
}
}
return {input, text_input, image_input};
}

PD_BUILD_OP(text_image_gather_scatter)
PD_BUILD_STATIC_OP(text_image_gather_scatter)
Comment on lines -63 to +69
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我看xpu的其他自定义算子都用的 PD_BUILD_OP 而不是 PD_BUILD_STATIC_OP,这里是否需要用 PD_BUILD_STATIC_OP 需要再确认下~

.Inputs({"input",
"text_input",
"image_input",
"token_type_ids",
"text_index",
"image_index"})
.Outputs({"text_input_out",
"image_input_out",
"text_index_out",
"image_index_out"})
.Outputs({"output", "text_input_out", "image_input_out"})
.Attrs({"is_scatter:bool"})
.SetInplaceMap({{"text_input", "text_input_out"},
{"image_input", "image_input_out"},
{"text_index", "text_index_out"},
{"image_index", "image_index_out"}})
.SetInplaceMap({{"input", "output"},
{"text_input", "text_input_out"},
{"image_input", "image_input_out"}})
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));
Loading