Skip to content

Add BuildCinnPass #36345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged

Conversation

thisjiang
Copy link
Contributor

@thisjiang thisjiang commented Oct 11, 2021

PR types

New features

PR changes

APIs

Describe

如题,添加BuildCinnPass 类,用于从Graph中筛选出所有CINN支持的子图。同时,在原Graph中用一个特殊Op替换替换,其Op名为kCinnLaunchOp,Graph中相连var node相应调整,此外CINN不支持的Op保持不变。

用例

其输入为一个Graph,输出为std::vector<std::unique_ptr<Graph>>,并通过cinn_subgraphs属性返回。

auto pass = paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
std::vector<std::unique_ptr<Graph>> cinn_subgraphs;
pass->SetNotOwned<std::vector<std::unique_ptr<Graph>>>("cinn_subgraphs", &cinn_subgraphs);
pass->Apply(graph);

上述变量cinn_subgraphs即为我们筛选出的所有CINN支持的子图。

示例

详情示例请见单测文件:build_cinn_pass_test.cc, 具体而言,分为如下几种情况:

Graph中不含任何CINN支持的Op

var1 --
       | --> fake1 --> var3 --> fake2 --> var4
var2 --

Pass前后Graph保持不变,且cinn_subgraphs为空。

Graph由全为CINN支持的OP组成

v1 --
     |
     | --> mul --> v3 --
     |                  |
v2 --                   | --> add --> v5 --> relu --> v6
                        |
                   v4 --

Pass后原Graph变为:

v1 --|
v2 --| --> kCinnLaunchOp --> v6
v4 --|

cinn_subgraphs包含一个子图:

mul --> v3 --> add --> v5 --> relu

Graph含一个CINN子图

fake1 --> v1 --
               |
               | --> mul --> v3 --> relu --> v4 --> fake2
               |
          v2 --

Pass后原Graph变为:

fake1 --> v1 --
               | --> kCinnLaunchOp --> v4 --> fake2
          v2 --

cinn_subgraphs包含一个子图:

mul --> v3 --> relu

Graph包含多个分离的CINN子图

fake1 --> v1 --
               |
               | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
               |
          v2 --

Pass后原Graph变为:

fake1 -> v1 -
             | -> CinnOp -> v3 -> fake2 -> v4 -> CinnOp ->v5 -> fake3
         v2 -

cinn_subgraphs包含2个单op子图:

subgraph1: relu
subgraph2: mul

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Member

@zhhsplendid zhhsplendid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created paddle_to_cinn_pass as a pass entry, what's the relation between paddle_to_cinn_pass and cinn_subgraph_search_pass? If they have same functionality, can you remove paddle_to_cinn_pass or merge them together? But you can do it in next PR :-)

// just for local compile
namespace cinn {
namespace frontend {
class OpMapperRegistry {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a singleton? If so, a singleton usually has a pattern:

  1. Replace your Global function by GetInstance
  2. Hide constructor (set it to private in C++)

If it is a singleton, follow the pattern so that the future readers can get more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As notes, the class is just for local compile and test, the complete class definition at https://github.com/PaddlePaddle/CINN/blob/develop/cinn/frontend/op_mapper_registry.h , and the APIs are the same as CINN. After bind paddle and CINN, the code will remove.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add some more explicit notes:

// TODO(jiangcheng05): just for local compile, remove after
// paddle and CINN have been bound.
// The APIs are the same as CINN:
// https://github.com/PaddlePaddle/CINN/blob/develop/cinn/utils/registry.h

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, my mistake, I didn't read the PR description well.


for (auto* op : cluster) {
auto sub_node = sub_graph->CreateOpNode(op->Op());
sub_node->inputs = op->inputs;
Copy link
Member

@zhhsplendid zhhsplendid Oct 12, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

op->inputs/outputs are std::vector<Node*>, will this sub_graph point to nodes of old graphs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

// and add cluster_internals node
for (auto* var_node : *cluster_outputs) {
if (cluster_inputs->count(var_node) > 0) {
// if a input node also existed in output list, remove
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: grammar "existed" -> "exists"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// Replacing Cinn subgraph to a special op node, whose op_type is
// kCinnSubgraphSearchOpName, and input is cluster_inputs and
// outputs is cluster_outputs.
// Meanwhile, remove all cluster node from cluster_inputs and cluster_outputs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: "remove all cluster" -> "remove all clusters"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// kCinnSubgraphSearchOpName, and input is cluster_inputs and
// outputs is cluster_outputs.
// Meanwhile, remove all cluster node from cluster_inputs and cluster_outputs.
void ReplaceSubGraphToSpecialOpNode(Graph* graph, const GraphNodeSet& cluster,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function will change the graph, so the graph is the output parameter and others are input parameters. Usually we would put input parameters before output parameters in coding style.

Reference: https://google.github.io/styleguide/cppguide.html#Inputs_and_Outputs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

void SearchAllSubgraphs(Graph* graph,
std::vector<std::unique_ptr<Graph>>* cinn_subgraphs) {
auto teller = [](const Node* node) {
return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) !=
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要考虑支持的设备类型吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在应该不用吧?现有的OpMapperRegistry也没有考虑设备

namespace framework {
namespace paddle2cinn {

constexpr char kCinnSubgraphSearchOpName[] = "cinn_subgraph_search_op";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个现在是初步定了名字吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

并没有。。。还想着review的时候定下来


// removing useless link from cluster_inputs to cluster
for (auto* var_node : cluster_inputs) {
auto preserved_node = get_preserved_ops(var_node->outputs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

preserved_node --> preserved_nodes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


// removing useless link from cluster to cluster_outputs
for (auto* var_node : cluster_outputs) {
auto preserved_node = get_preserved_ops(var_node->inputs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

preserved_node --> preserved_nodes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Graph* graph) {
auto special_op_node =
AddSpecialOpToGraph(graph, cluster_inputs, cluster_outputs);
RemoveUselessLink(cluster, cluster_inputs, cluster_outputs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why need this statement? Please add some comments here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RemoveUselessLink这步是为了移除原Graph中连接到subgraph内部node的那些连接,这里我改下函数名再加些注释吧~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@thisjiang thisjiang changed the title Add CinnSubgraphSearchPass Add BuildCinnPass Oct 13, 2021
@@ -71,11 +71,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"modify_op_lock_and_record_event_pass");
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass");

// Note: This pass is used to enable cinn.
if (FLAGS_use_cinn) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在这个pass是默认启用吗,还是只是注册没开启

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在build_cinn_pass是没有用flag控制的,也就是默认开启,我用这个FLAGS_use_cinn控制下。这个flag现在是默认false的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我弄错了,不用加这个判断语句,默认是只注册没开启的,由于Pass需要传入传出cinn_subgraphs属性,所有build_cinn_pass这个我觉得应该在用的时候再调用而非在这儿AppendPass

wzzju
wzzju previously approved these changes Oct 14, 2021
Copy link
Contributor

@wzzju wzzju left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Member

@zhhsplendid zhhsplendid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@XieYunshen XieYunshen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@thisjiang thisjiang merged commit b3f02c5 into PaddlePaddle:develop Oct 15, 2021
@thisjiang thisjiang deleted the add_cinn_subgraph_search_pass branch October 15, 2021 04:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants