
#include <cstring>

-#include <executorch/runtime/core/exec_aten/exec_aten.h>
-#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
-#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
-#include <executorch/runtime/platform/assert.h>
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

-void check_cat_args(
+bool check_cat_args(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  // Ensure the input tensors list is non-empty
-  ET_CHECK(tensors.size() > 0);
+  ET_LOG_AND_RETURN_IF_FALSE(tensors.size() > 0);

  // Find the first non-empty tensor in the list to use as a reference
  size_t ref_i = 0;
@@ -39,25 +36,30 @@ void check_cat_args(
  // https://pytorch.org/docs/stable/generated/torch.cat.html
  for (size_t i = 0; i < tensors.size(); ++i) {
    // All input dtypes must be castable to the output dtype.
-    ET_CHECK(canCast(tensors[i].scalar_type(), out.scalar_type()));
+    ET_LOG_AND_RETURN_IF_FALSE(
+        canCast(tensors[i].scalar_type(), out.scalar_type()));

    // Empty tensors have no shape constraints.
    if (tensors[i].numel() == 0) {
      continue;
    }

    // All input tensors must have the same number of dimensions.
-    ET_CHECK(tensors[i].dim() == tensors[ref_i].dim());
+    ET_LOG_AND_RETURN_IF_FALSE(
+        tensor_is_rank(tensors[ref_i], tensors[i].dim()));

    for (size_t d = 0; d < tensors[i].dim(); ++d) {
      if (d != dim) {
-        ET_CHECK(tensors[i].size(d) == tensors[ref_i].size(d));
+        ET_LOG_AND_RETURN_IF_FALSE(
+            tensors_have_same_size_at_dims(tensors[i], d, tensors[ref_i], d));
      }
    }
  }

  // Ensure dim is in range.
-  ET_CHECK(dim >= 0 && dim < tensors[ref_i].dim());
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(tensors[ref_i], dim));
+
+  return true;
}

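Not part of this commit, but for context on the new return type: a minimal, hypothetical sketch of how a kernel entry point could consume the bool-returning check. The name cat_out_sketch and the bail-out-by-returning-out behavior are assumptions for illustration; the commit only changes the checks themselves.

// Hypothetical caller (illustration only, not from this commit). The check no
// longer aborts on failure, so the kernel can exit gracefully instead.
Tensor& cat_out_sketch(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  if (!check_cat_args(tensors, dim, out)) {
    // The failing condition was already logged by ET_LOG_AND_RETURN_IF_FALSE
    // inside check_cat_args; just return without touching the data.
    return out;
  }
  // ... compute the output size and copy each input tensor into `out` ...
  return out;
}
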
void get_cat_out_target_size(
@@ -86,9 +88,9 @@ void get_cat_out_target_size(
  }
}

-void check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out) {
-  ET_CHECK(in.dim() == dims.size());
-  ET_CHECK_SAME_DTYPE2(in, out);
+bool check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, dims.size()));
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));

  // Make sure no dimensions are duplicated and all in the range [-in.dim(),
  // in.dim() - 1]. Use gaussian sum to check this.
@@ -98,13 +100,15 @@ void check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out) {
    // Convert dimension to a non-negative number. dim_base is in the range
    // [0 .. in.dim() - 1].
    size_t dim = dims[i] > -1 ? dims[i] : in.dim() + dims[i];
-    ET_CHECK(dim >= 0 && dim < in.dim());
+    ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
    gauss_sum += dim + 1;
  }

-  ET_CHECK_MSG(
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      gauss_sum == expected_sum,
      "The dims passed to permute_copy must contain one of each dim!");
+
+  return true;
}

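A worked example of the gaussian-sum duplicate check above (illustration only, not from the diff): with in.dim() == 3, the sum of the 1-based dimension indices must be 1 + 2 + 3 = 6. dims = {2, 0, 1} contributes (2+1) + (0+1) + (1+1) = 6 and is accepted; dims = {0, 1, 1} contributes 1 + 2 + 2 = 5 and is rejected with the logged message. A negative entry is normalized first, so {-1, 0, 1} behaves like {2, 0, 1}.
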
void get_permute_copy_out_target_size(
@@ -119,28 +123,32 @@ void get_permute_copy_out_target_size(
  }
}

-void check_stack_args(
+bool check_stack_args(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  // Ensure the input tensors list is non-empty
-  ET_CHECK(tensors.size() > 0);
+  ET_LOG_AND_RETURN_IF_FALSE(tensors.size() > 0);

  // All input tensors need to be of the same size
  // https://pytorch.org/docs/stable/generated/torch.stack.html
  for (size_t i = 0; i < tensors.size(); i++) {
    // All input dtypes must be castable to the output dtype.
-    ET_CHECK(canCast(tensors[i].scalar_type(), out.scalar_type()));
+    ET_LOG_AND_RETURN_IF_FALSE(
+        canCast(tensors[i].scalar_type(), out.scalar_type()));

-    ET_CHECK(tensors[i].dim() == tensors[0].dim());
+    ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(tensors[i], tensors[0].dim()));
    for (size_t d = 0; d < tensors[i].dim(); d++) {
-      ET_CHECK(tensors[i].size(d) == tensors[0].size(d));
+      ET_LOG_AND_RETURN_IF_FALSE(
+          tensors_have_same_size_at_dims(tensors[i], d, tensors[0], d));
    }
  }

  // The output tensor will have a dimension inserted, so dim should be between
  // 0 and ndim_of_inputs + 1
-  ET_CHECK(dim >= 0 && dim < tensors[0].dim() + 1);
+  ET_LOG_AND_RETURN_IF_FALSE(dim >= 0 && dim < tensors[0].dim() + 1);
+
+  return true;
}

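For reference, a hypothetical illustration (not from this commit) of why the bound above allows dim == tensors[0].dim(): stack inserts a new dimension into the output.

// Stacking three 2x4 input tensors (illustration only):
//   dim == 0  ->  out shape {3, 2, 4}
//   dim == 1  ->  out shape {2, 3, 4}
//   dim == 2  ->  out shape {2, 4, 3}
//   dim == 3  ->  rejected: dim must satisfy dim < tensors[0].dim() + 1 == 3
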
void get_stack_out_target_size(