paddle/fluid/eager/auto_code_generator/generator/eager_gen.py (2 additions, 6 deletions)
@@ -94,6 +94,7 @@
"put_along_axis_double_grad",
"masked_fill_double_grad",
"index_elementwise_put_with_tensor_double_grad",
"view_shape_double_grad",
]

# white ops list whose kernel can automatically do type promotion.
@@ -3170,12 +3171,7 @@ def GenerateNodeDefinition(
 )
 
 grad_api_args[grad_api_position] = name
-if (
-    not is_invoke_forward_api
-    or name in self.grad_api_contents['invoke']
-):
-    # NOTE: attr 'dims' is not necessary for 'invoke: view_shape(out_grad, input.shape())'
-    get_grad_in_args_list.append(get_attr_str)
+get_grad_in_args_list.append(get_attr_str)
 
 get_grad_in_args_str = "\n".join(get_grad_in_args_list)
 
[second changed file: name not captured]
@@ -43,4 +43,5 @@
     'masked_select_grad',
     'index_elementwise_get_grad',
     'index_elementwise_put_with_tensor_grad',
+    'view_shape_grad',
 ]
[third changed file: name not captured]
@@ -1296,5 +1296,16 @@ void index_elementwise_put_with_tensor_double_grad(
   }
 }
 
+template <typename T>
+void view_shape_double_grad(const Tensor& grad_input_grad,
+                            const std::vector<int64_t> dims,
+                            Tensor* grad_out_grad) {
+  if (grad_out_grad) {
+    Tensor grad_out_grad_tmp;
+    grad_out_grad_tmp = reshape<T>(grad_input_grad, dims);
+    set_output<T>(grad_out_grad_tmp, grad_out_grad);
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
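For context, a minimal dygraph sketch of how this second-order path can be reached from Python; the tensor shapes and the squared loss are illustrative only, not taken from the PR, and the comments assume Tensor.view is lowered to the view_shape op in dygraph:

    import paddle

    x = paddle.randn([2, 3])
    x.stop_gradient = False
    y = x.view([3, 2])                                    # expected to go through the view_shape op
    loss = (y * y).sum()
    (gx,) = paddle.grad(loss, [x], create_graph=True)     # first order: view_shape_grad
    (ggx,) = paddle.grad(gx.sum(), [x])                   # second order: view_shape_double_grad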
paddle/phi/ops/yaml/backward.yaml (20 additions)
@@ -4025,6 +4025,26 @@
     data_type : out_grad
   no_need_buffer: input
 
+- backward_op : view_shape_double_grad
+  forward : view_shape_grad (Tensor input, Tensor grad_out, int64_t[] dims) -> Tensor(grad_input)
+  args : (Tensor grad_input_grad, int64_t[] dims)
+  output : Tensor(grad_out_grad)
+  infer_meta :
+    func : StridedUnChangedInferMeta
+    param : [grad_input_grad]
+  composite: view_shape_double_grad(grad_input_grad, dims, grad_out_grad)
+
+- backward_op : view_shape_grad
+  forward : view_shape (Tensor input, int64_t[] dims = {}) -> Tensor(out)
+  args : (Tensor input, Tensor out_grad, int64_t[] dims = {})
+  output : Tensor(input_grad)
+  infer_meta :
+    func : StridedUnChangedInferMeta
+    param : [input]
+  kernel :
+    func : view_shape_grad
+  backward : view_shape_double_grad
+
 - backward_op : warpctc_grad
   forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank = 0, bool norm_by_times = false) -> Tensor(loss), Tensor(warpctcgrad)
   args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times)
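Read as plain Python, the two view_shape backward ops added above behave roughly as follows (illustrative only; the real kernels produce strided views rather than copies where possible):

    # view_shape_grad: reshape out_grad back to the input's shape
    # (this matches the previous `invoke: view_shape(out_grad, input.shape())` definition)
    def view_shape_grad(input, out_grad, dims):
        return out_grad.reshape(input.shape)          # -> Tensor(input_grad)

    # view_shape_double_grad: reshape the incoming grad-of-grad to `dims`,
    # matching the composite function added in this PR
    def view_shape_double_grad(grad_input_grad, dims):
        return grad_input_grad.reshape(dims)          # -> Tensor(grad_out_grad)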
paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml (7 deletions)
@@ -361,13 +361,6 @@
   composite : tile_grad(x, out_grad, repeat_times, x_grad)
   backward : tile_double_grad
 
-- backward_op : view_shape_grad
-  forward : view_shape (Tensor input, int64_t[] dims = {}) -> Tensor(out)
-  args : (Tensor input, Tensor out_grad, int64_t[] dims = {})
-  output : Tensor(input_grad)
-  invoke: view_shape(out_grad, input.shape())
-  no_need_buffer: input
-
 - backward_op: fused_gemm_epilogue_grad
   forward : fused_gemm_epilogue(Tensor x, Tensor y, Tensor bias, bool trans_x, bool trans_y, str activation) -> Tensor(out), Tensor(reserve_space)
   args : (Tensor x, Tensor y, Tensor reserve_space, Tensor out_grad, bool trans_x, bool trans_y, str activation)
paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml (9 deletions)
@@ -404,12 +404,3 @@
     data_type : x
   optional : indices, inverse, counts
   traits : paddle::dialect::ForwardOnlyTrait
-
-- op : view_shape
-  args : (Tensor input, int64_t[] dims = {})
-  output : Tensor(out)
-  infer_meta :
-    func : ViewShapeInferMeta
-  kernel :
-    func : view_shape
-  backward : view_shape_grad
paddle/phi/ops/yaml/op_compat.yaml (3 additions)
@@ -4044,6 +4044,9 @@
   get_expected_kernel_type :
     update_loss_scaling_ : GetUpdateLossScalingExpectedKernelType
 
+- op : view_shape
+  backward : view_shape_grad, view_shape_double_grad
+
 - op : viterbi_decode
   inputs :
     {potentials : Input, transition_params : Transition, lengths : Length}
paddle/phi/ops/yaml/ops.yaml (10 additions)
@@ -5817,6 +5817,16 @@
   no_need_buffer : input
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
+- op : view_shape
+  args : (Tensor input, int64_t[] dims = {})
+  output : Tensor(out)
+  infer_meta :
+    func : ViewShapeInferMeta
+  kernel :
+    func : view_shape
+  backward : view_shape_grad
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
+
 - op : view_slice
   args : (Tensor input, int64_t begin_idx, int64_t end_idx)
   output : Tensor
test/legacy_test/test_stride.py (1 addition)
@@ -971,6 +971,7 @@ def call_stride(self):
         self.call_view5()
         self.call_view6()
         self.call_view7()
+        self.call_view8()
         self.call_view9()
         self.call_view10()
         self.call_view11()
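The body of call_view8 is not part of this diff; as a rough illustration of the kind of check the re-enabled case exercises, a standalone view check might look like the following (the shapes and assertions here are assumptions, not taken from the PR):

    import numpy as np
    import paddle

    x = paddle.randn([2, 6])
    out = x.view([3, 4])                 # metadata-only reshape; no data copy expected
    assert out.shape == [3, 4]
    np.testing.assert_allclose(out.numpy().reshape([2, 6]), x.numpy())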