Skip to content

Commit 3ea11cf

Browse files
committed
Update
[ghstack-poisoned]
1 parent 948fba6 commit 3ea11cf

File tree

4 files changed

+64
-100
lines changed

4 files changed

+64
-100
lines changed

kernels/portable/cpu/op_argmax.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Tensor& argmax_out(
4343
ET_KERNEL_CHECK(
4444
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
4545

46-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "argmax.out", CTYPE, [&] {
46+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmax.out", CTYPE, [&] {
4747
long* out_data = out.mutable_data_ptr<long>();
4848

4949
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {

kernels/portable/cpu/op_argmin.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Tensor& argmin_out(
4343
ET_KERNEL_CHECK(
4444
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
4545

46-
ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
46+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
4747
long* out_data = out.mutable_data_ptr<long>();
4848

4949
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {

kernels/test/op_argmax_test.cpp

+31-49
Original file line numberDiff line numberDiff line change
@@ -31,57 +31,39 @@ class OpArgmaxTest : public OperatorTest {
3131
Tensor& out) {
3232
return torch::executor::aten::argmax_outf(context_, in, dim, keepdim, out);
3333
}
34-
};
35-
36-
TEST_F(OpArgmaxTest, SanityCheckLong) {
37-
TensorFactory<ScalarType::Long> tf;
38-
39-
// clang-format off
40-
Tensor in = tf.make(
41-
{ 2, 3, 4 },
42-
{ 1, 4, 1, 6,
43-
5, 8, 5, 6,
44-
5, 3, 9, 2,
45-
46-
3, 9, 1, 4,
47-
9, 7, 5, 5,
48-
7, 7, 6, 3 });
49-
50-
Tensor out = tf.zeros({2, 4});
51-
Tensor expected = tf.make({2, 4}, {
52-
1, 1, 2, 0,
53-
1, 0, 2, 1 });
54-
Tensor ret = op_argmax_out(in, 1, false, out);
55-
56-
EXPECT_TENSOR_EQ(out, ret);
57-
EXPECT_TENSOR_EQ(out, expected);
58-
// clang-format on
59-
}
60-
61-
TEST_F(OpArgmaxTest, SanityCheckShort) {
62-
TensorFactory<ScalarType::Long> tfl;
63-
TensorFactory<ScalarType::Short> tfs;
6434

65-
// clang-format off
66-
Tensor in = tfs.make(
67-
{ 2, 3, 4 },
68-
{ 1, 4, 1, 6,
69-
5, 8, 5, 6,
70-
5, 3, 9, 2,
71-
72-
3, 9, 1, 4,
73-
9, 7, 5, 5,
74-
7, 7, 6, 3 });
75-
76-
Tensor out = tfl.zeros({2, 4});
77-
Tensor expected = tfl.make({2, 4}, {
78-
1, 1, 2, 0,
79-
1, 0, 2, 1 });
80-
Tensor ret = op_argmax_out(in, 1, false, out);
35+
template <ScalarType dtype>
36+
void test_argmax_dtype() {
37+
TensorFactory<ScalarType::Long> tfl;
38+
TensorFactory<dtype> tf_dtype;
39+
40+
// clang-format off
41+
Tensor in = tf_dtype.make(
42+
{ 2, 3, 4 },
43+
{ 1, 4, 1, 6,
44+
5, 8, 5, 6,
45+
5, 3, 9, 2,
46+
47+
3, 9, 1, 4,
48+
9, 7, 5, 5,
49+
7, 7, 6, 3 });
50+
51+
Tensor out = tfl.zeros({2, 4});
52+
Tensor expected = tfl.make({2, 4}, {
53+
1, 1, 2, 0,
54+
1, 0, 2, 1 });
55+
Tensor ret = op_argmax_out(in, 1, false, out);
56+
57+
EXPECT_TENSOR_EQ(out, ret);
58+
EXPECT_TENSOR_EQ(out, expected);
59+
// clang-format on
60+
}
61+
};
8162

82-
EXPECT_TENSOR_EQ(out, ret);
83-
EXPECT_TENSOR_EQ(out, expected);
84-
// clang-format on
63+
TEST_F(OpArgmaxTest, SanityCheck) {
64+
#define TEST_ENTRY(ctype, dtype) test_argmax_dtype<ScalarType::dtype>();
65+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
66+
#undef TEST_ENTRY
8567
}
8668

8769
TEST_F(OpArgmaxTest, SanityCheckNullDim) {

kernels/test/op_argmin_test.cpp

+31-49
Original file line numberDiff line numberDiff line change
@@ -31,57 +31,39 @@ class OpArgminTest : public OperatorTest {
3131
Tensor& out) {
3232
return torch::executor::aten::argmin_outf(context_, in, dim, keepdim, out);
3333
}
34-
};
35-
36-
TEST_F(OpArgminTest, SanityCheckLong) {
37-
TensorFactory<ScalarType::Long> tf;
38-
39-
// clang-format off
40-
Tensor in = tf.make(
41-
{ 2, 3, 4 },
42-
{ 1, 4, 1, 6,
43-
5, 8, 5, 6,
44-
5, 3, 9, 2,
45-
46-
3, 9, 1, 4,
47-
9, 7, 5, 5,
48-
7, 7, 6, 3 });
49-
50-
Tensor out = tf.zeros({2, 4});
51-
Tensor expected = tf.make({2, 4}, {
52-
0, 2, 0, 2,
53-
0, 1, 0, 2 });
54-
Tensor ret = op_argmin_out(in, 1, false, out);
55-
56-
EXPECT_TENSOR_EQ(out, ret);
57-
EXPECT_TENSOR_EQ(out, expected);
58-
// clang-format on
59-
}
60-
61-
TEST_F(OpArgminTest, SanityCheckShort) {
62-
TensorFactory<ScalarType::Long> tfl;
63-
TensorFactory<ScalarType::Short> tfs;
6434

65-
// clang-format off
66-
Tensor in = tfs.make(
67-
{ 2, 3, 4 },
68-
{ 1, 4, 1, 6,
69-
5, 8, 5, 6,
70-
5, 3, 9, 2,
71-
72-
3, 9, 1, 4,
73-
9, 7, 5, 5,
74-
7, 7, 6, 3 });
75-
76-
Tensor out = tfl.zeros({2, 4});
77-
Tensor expected = tfl.make({2, 4}, {
78-
0, 2, 0, 2,
79-
0, 1, 0, 2 });
80-
Tensor ret = op_argmin_out(in, 1, false, out);
35+
template <ScalarType dtype>
36+
void test_argmin_dtype() {
37+
TensorFactory<ScalarType::Long> tfl;
38+
TensorFactory<dtype> tf_dtype;
39+
40+
// clang-format off
41+
Tensor in = tf_dtype.make(
42+
{ 2, 3, 4 },
43+
{ 1, 4, 1, 6,
44+
5, 8, 5, 6,
45+
5, 3, 9, 2,
46+
47+
3, 9, 1, 4,
48+
9, 7, 5, 5,
49+
7, 7, 6, 3 });
50+
51+
Tensor out = tfl.zeros({2, 4});
52+
Tensor expected = tfl.make({2, 4}, {
53+
0, 2, 0, 2,
54+
0, 1, 0, 2 });
55+
Tensor ret = op_argmin_out(in, 1, false, out);
56+
57+
EXPECT_TENSOR_EQ(out, ret);
58+
EXPECT_TENSOR_EQ(out, expected);
59+
// clang-format on
60+
}
61+
};
8162

82-
EXPECT_TENSOR_EQ(out, ret);
83-
EXPECT_TENSOR_EQ(out, expected);
84-
// clang-format on
63+
TEST_F(OpArgminTest, SanityCheck) {
64+
#define TEST_ENTRY(ctype, dtype) test_argmin_dtype<ScalarType::dtype>();
65+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
66+
#undef TEST_ENTRY
8567
}
8668

8769
TEST_F(OpArgminTest, SanityCheckNullDim) {

0 commit comments

Comments
 (0)