Skip to content

Commit cb95587

Browse files
committed
"ignore some gradient of specific op"
1 parent bf4da3d commit cb95587

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

paddle/framework/op_proto.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ message VarProto {
8484
// "temporary_index": [1]
8585
// }
8686
optional bool temporary = 4 [default=false];
87+
88+
// The gradient of operator can be ignored immediately
89+
// e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2
90+
// can be ignored for the future optimized on graph.
91+
optional bool ignore_gradient = 6;
8792
}
8893

8994
// Op protocol message for 3rd-party language binding.
@@ -105,4 +110,5 @@ message OpProto {
105110

106111
// The type of that Op.
107112
required string type = 5;
113+
108114
}

paddle/framework/op_registry.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,25 +74,29 @@ class OpProtoAndCheckerMaker {
7474

7575
protected:
7676
void AddInput(const std::string& name, const std::string& comment,
77-
bool multiple = false) {
77+
bool multiple = false, bool ignore_gradient = false) {
7878
auto input = proto_->mutable_inputs()->Add();
7979
*input->mutable_name() = name;
8080
*input->mutable_comment() = comment;
81+
*input->set_ignore_gradient(ignore_gradient);
8182
input->set_multiple(multiple);
8283
if (multiple) {
8384
SetHasMultipleInput();
8485
}
8586
}
8687

87-
void AddInputs(const std::string& name, const std::string& comment) {
88-
AddInput(name, comment, true);
88+
void AddInputs(const std::string& name, const std::string& comment,
89+
bool ignore_gradient = false) {
90+
AddInput(name, comment, true, ignore_gradient);
8991
}
9092

9193
void AddOutput(const std::string& name, const std::string& comment,
92-
bool temporary = false, bool multiple = false) {
94+
bool temporary = false, bool multiple = false,
95+
bool ignore_gradient = false) {
9396
auto output = proto_->mutable_outputs()->Add();
9497
*output->mutable_name() = name;
9598
*output->mutable_comment() = comment;
99+
*output->set_ignore_gradient(ignore_gradient);
96100
output->set_multiple(multiple);
97101
if (multiple) {
98102
SetHasMultipleOutput();
@@ -104,8 +108,8 @@ class OpProtoAndCheckerMaker {
104108
}
105109

106110
void AddOutputs(const std::string& name, const std::string& comment,
107-
bool temporary = false) {
108-
AddOutput(name, comment, temporary, true);
111+
bool temporary = false, bool ignore_gradient = false) {
112+
AddOutput(name, comment, temporary, true, ignore_gradient);
109113
}
110114

111115
template <typename T>

0 commit comments

Comments
 (0)