@@ -112,8 +112,66 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDExhaustive) {
112
112
EXPECT_EQ (expected, actual);
113
113
}
114
114
115
- // Here we assume that the previous tests established that padding
116
- // with leading 1s is working, and test:
115
+ // Make sure nothing is thrown off by a size-1 dim in the output:
116
+ // [] -> [1, W]
117
+ // [] -> [H, 1]
118
+ // [1] -> [1, W]
119
+ // [1] -> [H, 1]
120
+ // [W] -> [1, W]
121
+ // [1, 1] -> [1, W]
122
+ // [1, 1] -> [H, 1]
123
+ // [1, W] -> [1, W]
124
+ // [H, 1] -> [H, 1]
125
+ TEST (BroadcastIndexesRangeTest, OneAndTwoDWith1InOutputShapeExhaustive) {
126
+ TensorFactory<ScalarType::Int> tf;
127
+ constexpr auto H = 2 ;
128
+ constexpr auto W = 3 ;
129
+ Tensor out_row = tf.zeros ({1 , W});
130
+ Tensor out_col = tf.zeros ({H, 1 });
131
+ Tensor in_0d_scalar = tf.zeros ({});
132
+ Tensor in_1d_scalar = tf.zeros ({1 });
133
+ Tensor in_2d_scalar = tf.zeros ({1 , 1 });
134
+
135
+ Tensor in_row = tf.zeros ({W});
136
+ Tensor in_leading_one_row = tf.zeros ({1 , W});
137
+
138
+ Tensor in_col = tf.zeros ({H, 1 });
139
+
140
+ size_t idx = 0 ;
141
+ for (const auto
142
+ [out_idx,
143
+ in_0d_idx,
144
+ in_1d_idx,
145
+ in_2d_idx,
146
+ in_row_idx,
147
+ in_leading_one_row_idx] :
148
+ BroadcastIndexesRange<5 >(
149
+ out_row,
150
+ in_0d_scalar,
151
+ in_1d_scalar,
152
+ in_2d_scalar,
153
+ in_row,
154
+ in_leading_one_row)) {
155
+ EXPECT_EQ (out_idx, idx++);
156
+ EXPECT_EQ (in_0d_idx, 0 );
157
+ EXPECT_EQ (in_1d_idx, 0 );
158
+ EXPECT_EQ (in_2d_idx, 0 );
159
+ EXPECT_EQ (in_row_idx, out_idx);
160
+ EXPECT_EQ (in_leading_one_row_idx, out_idx);
161
+ }
162
+
163
+ idx = 0 ;
164
+ for (const auto [out_idx, in_0d_idx, in_1d_idx, in_2d_idx, in_col_idx] :
165
+ BroadcastIndexesRange<4 >(
166
+ out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col)) {
167
+ EXPECT_EQ (out_idx, idx++);
168
+ EXPECT_EQ (in_0d_idx, 0 );
169
+ EXPECT_EQ (in_1d_idx, 0 );
170
+ EXPECT_EQ (in_2d_idx, 0 );
171
+ EXPECT_EQ (in_col_idx, out_idx);
172
+ }
173
+ }
174
+
117
175
// [1, 1, 1] -> [C, H, W]
118
176
// [C, H, 1] -> [C, H, W]
119
177
// [C, 1, W] -> [C, H, W]
@@ -166,11 +224,12 @@ TEST(BroadcastIndexesRangeTest, ThreeDBroadcasting) {
166
224
// 4-D should generalize, but we will go ahead and test:
167
225
// [N, 1, H, 1] -> [N, C, H, W]
168
226
// [1, C, 1, W] -> [N, C, H, W]
169
- TEST (BroadcastIndexesRangeTest, FourDBroadcasting) {
227
+ template <size_t N, size_t C, size_t H, size_t W>
228
+ void four_d_broadcasting_test () {
170
229
TensorFactory<ScalarType::Int> tf;
171
- Tensor out = tf.zeros ({2 , 3 , 4 , 5 });
172
- Tensor in_broadcast_cw = tf.zeros ({2 , 1 , 4 , 1 });
173
- Tensor in_broadcast_nh = tf.zeros ({1 , 3 , 1 , 5 });
230
+ Tensor out = tf.zeros ({N, C, H, W });
231
+ Tensor in_broadcast_cw = tf.zeros ({N , 1 , H , 1 });
232
+ Tensor in_broadcast_nh = tf.zeros ({1 , C , 1 , W });
174
233
175
234
// Writing out all the indexes would be too cumbersome, so here we
176
235
// take the opportunity to mutation test against delinearize_index
@@ -190,3 +249,12 @@ TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {
190
249
linearize_access_indexes (out_indexes, out.dim (), in_broadcast_nh));
191
250
}
192
251
}
252
+
253
+ TEST (BroadcastIndexesRangeTest, FourDBroadcasting) {
254
+ four_d_broadcasting_test<2 , 3 , 4 , 5 >();
255
+ }
256
+
257
+ TEST (BroadcastIndexesRangeTest, FourDBroadcastingWithOneDimsInOutput) {
258
+ four_d_broadcasting_test<2 , 3 , 1 , 5 >();
259
+ four_d_broadcasting_test<2 , 1 , 3 , 1 >();
260
+ }
0 commit comments