-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Add BuildCinnPass #36345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add BuildCinnPass #36345
Conversation
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I created paddle_to_cinn_pass
as a pass entry, what's the relation between paddle_to_cinn_pass
and cinn_subgraph_search_pass
? If they have same functionality, can you remove paddle_to_cinn_pass
or merge them together? But you can do it in next PR :-)
// just for local compile | ||
namespace cinn { | ||
namespace frontend { | ||
class OpMapperRegistry { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a singleton? If so, a singleton usually has a pattern:
- Replace your
Global
function byGetInstance
- Hide constructor (set it to private in C++)
If it is a singleton, follow the pattern so that the future readers can get more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As notes, the class is just for local compile and test, the complete class definition at https://github.com/PaddlePaddle/CINN/blob/develop/cinn/frontend/op_mapper_registry.h , and the APIs are the same as CINN. After bind paddle and CINN, the code will remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add some more explicit notes:
// TODO(jiangcheng05): just for local compile, remove after
// paddle and CINN have been bound.
// The APIs are the same as CINN:
// https://github.com/PaddlePaddle/CINN/blob/develop/cinn/utils/registry.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, my mistake, I didn't read the PR description well.
|
||
for (auto* op : cluster) { | ||
auto sub_node = sub_graph->CreateOpNode(op->Op()); | ||
sub_node->inputs = op->inputs; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op->inputs/outputs
are std::vector<Node*>
, will this sub_graph
point to nodes of old graphs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
// and add cluster_internals node | ||
for (auto* var_node : *cluster_outputs) { | ||
if (cluster_inputs->count(var_node) > 0) { | ||
// if a input node also existed in output list, remove |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: grammar "existed" -> "exists"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
// Replacing Cinn subgraph to a special op node, whose op_type is | ||
// kCinnSubgraphSearchOpName, and input is cluster_inputs and | ||
// outputs is cluster_outputs. | ||
// Meanwhile, remove all cluster node from cluster_inputs and cluster_outputs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: "remove all cluster" -> "remove all clusters"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
// kCinnSubgraphSearchOpName, and input is cluster_inputs and | ||
// outputs is cluster_outputs. | ||
// Meanwhile, remove all cluster node from cluster_inputs and cluster_outputs. | ||
void ReplaceSubGraphToSpecialOpNode(Graph* graph, const GraphNodeSet& cluster, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function will change the graph
, so the graph
is the output parameter and others are input parameters. Usually we would put input parameters before output parameters in coding style.
Reference: https://google.github.io/styleguide/cppguide.html#Inputs_and_Outputs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
… add_cinn_subgraph_search_pass
void SearchAllSubgraphs(Graph* graph, | ||
std::vector<std::unique_ptr<Graph>>* cinn_subgraphs) { | ||
auto teller = [](const Node* node) { | ||
return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) != |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要考虑支持的设备类型吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在应该不用吧?现有的OpMapperRegistry
也没有考虑设备
namespace framework { | ||
namespace paddle2cinn { | ||
|
||
constexpr char kCinnSubgraphSearchOpName[] = "cinn_subgraph_search_op"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个现在是初步定了名字吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
并没有。。。还想着review的时候定下来
…ot have link to out-graph
|
||
// removing useless link from cluster_inputs to cluster | ||
for (auto* var_node : cluster_inputs) { | ||
auto preserved_node = get_preserved_ops(var_node->outputs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
preserved_node
--> preserved_nodes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
// removing useless link from cluster to cluster_outputs | ||
for (auto* var_node : cluster_outputs) { | ||
auto preserved_node = get_preserved_ops(var_node->inputs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
preserved_node
--> preserved_nodes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Graph* graph) { | ||
auto special_op_node = | ||
AddSpecialOpToGraph(graph, cluster_inputs, cluster_outputs); | ||
RemoveUselessLink(cluster, cluster_inputs, cluster_outputs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why need this statement? Please add some comments here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RemoveUselessLink
这步是为了移除原Graph中连接到subgraph内部node的那些连接,这里我改下函数名再加些注释吧~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
… add_cinn_subgraph_search_pass
@@ -71,11 +71,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { | |||
"modify_op_lock_and_record_event_pass"); | |||
// Note: This pass is used to check whether the multi_device_graph is right. | |||
AppendPass("multi_devices_check_pass"); | |||
|
|||
// Note: This pass is used to enable cinn. | |||
if (FLAGS_use_cinn) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在这个pass是默认启用吗,还是只是注册没开启
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
现在build_cinn_pass
是没有用flag控制的,也就是默认开启,我用这个FLAGS_use_cinn
控制下。这个flag现在是默认false的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我弄错了,不用加这个判断语句,默认是只注册没开启的,由于Pass需要传入传出cinn_subgraphs
属性,所有build_cinn_pass
这个我觉得应该在用的时候再调用而非在这儿AppendPass
。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…lose test_run_from_cinn
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Describe
如题,添加
BuildCinnPass
类,用于从Graph中筛选出所有CINN支持的子图。同时,在原Graph中用一个特殊Op替换替换,其Op名为kCinnLaunchOp
,Graph中相连var node相应调整,此外CINN不支持的Op保持不变。用例
其输入为一个
Graph
,输出为std::vector<std::unique_ptr<Graph>>
,并通过cinn_subgraphs
属性返回。上述变量
cinn_subgraphs
即为我们筛选出的所有CINN支持的子图。示例
详情示例请见单测文件:
build_cinn_pass_test.cc
, 具体而言,分为如下几种情况:Graph中不含任何CINN支持的Op
Pass前后Graph保持不变,且
cinn_subgraphs
为空。Graph由全为CINN支持的OP组成
Pass后原Graph变为:
cinn_subgraphs
包含一个子图:Graph含一个CINN子图
Pass后原Graph变为:
cinn_subgraphs
包含一个子图:Graph包含多个分离的CINN子图
Pass后原Graph变为:
cinn_subgraphs
包含2个单op子图: