1414namespace torch {
1515namespace executor {
1616namespace native {
17+ namespace internal {
18+
19+ template <
20+ bool can_cast,
21+ template <typename >
22+ typename OpFunc,
23+ typename CTYPE_A,
24+ typename CTYPE_B,
25+ typename CTYPE_IN,
26+ typename CTYPE_OUT>
27+ struct BitwiseOpInner ;
28+
29+ template <
30+ template <typename >
31+ typename OpFunc,
32+ typename CTYPE_A,
33+ typename CTYPE_B,
34+ typename CTYPE_IN,
35+ typename CTYPE_OUT>
36+ struct BitwiseOpInner <true , OpFunc, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37+ static void run (const Tensor& a, const Tensor& b, Tensor& out) {
38+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39+ // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
40+ [](const CTYPE_A val_a, const CTYPE_B val_b) {
41+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
42+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
43+ CTYPE_IN value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
44+
45+ return static_cast <CTYPE_OUT>(value);
46+ },
47+ a,
48+ b,
49+ out);
50+ }
51+ };
52+
53+ struct ReportCanCastBug {
54+ static void run (const Tensor&, const Tensor&, Tensor&) {
55+ ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
56+ }
57+ };
58+
59+ template <
60+ template <typename >
61+ typename OpFunc,
62+ typename CTYPE_A,
63+ typename CTYPE_B,
64+ typename CTYPE_IN,
65+ typename CTYPE_OUT>
66+ struct BitwiseOpInner <false , OpFunc, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
67+ : public ReportCanCastBug {};
68+
69+ } // namespace internal
70+
1771template <template <typename > typename OpFunc>
1872Tensor& bitwise_op_out (
1973 RuntimeContext& ctx,
@@ -36,21 +90,17 @@ Tensor& bitwise_op_out(
3690
3791 ET_SWITCH_INT_TYPES_AND (Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
3892 ET_SWITCH_INT_TYPES_AND (Bool, b_type, ctx, op_name, CTYPE_B, [&]() {
39- ET_SWITCH_INT_TYPES_AND (Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
40- ET_SWITCH_REAL_TYPES_AND (
41- Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
42- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
43- [](const CTYPE_A val_a, const CTYPE_B val_b) {
44- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
45- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
46- CTYPE_IN value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
47-
48- return static_cast <CTYPE_OUT>(value);
49- },
50- a,
51- b,
52- out);
53- });
93+ using CTYPE_IN =
94+ typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
95+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
96+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
97+ internal::BitwiseOpInner<
98+ can_cast<CTYPE_IN, CTYPE_OUT>::value,
99+ OpFunc,
100+ CTYPE_A,
101+ CTYPE_B,
102+ CTYPE_IN,
103+ CTYPE_OUT>::run (a, b, out);
54104 });
55105 });
56106 });
0 commit comments