From b29ea1c869adf273b6f59eefdc23eae2647f6833 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 14 Apr 2025 11:17:25 -0700 Subject: [PATCH] Refactor internal switch cases (#9802) Summary: There were multiple ET_INTERNAL_SWITCH_CASE_*_TYPES macros that repeated lists of types, for example: `ET_INTERNAL_SWITCH_CASE_INT_TYPES` `ET_INTERNAL_SWITCH_CASE_REAL_TYPES` `ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES` `ET_INTERNAL_SWITCH_CASE_ALL_TYPES` In this PR, we refactor them building from the bottom up, for example reusing INT and FLOAT to build REAL and reusing REAL to build ALL Reviewed By: swolchok Differential Revision: D72248082 --- .../core/exec_aten/util/scalar_type_util.h | 186 +++++++----------- 1 file changed, 72 insertions(+), 114 deletions(-) diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index d07052c2ec2..6f81146e925 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -921,55 +921,7 @@ struct promote_types { } \ }() -#define ET_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Char, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Short, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Int, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Half, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Float, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::ComplexHalf, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QInt8, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QUInt8, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QInt32, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::BFloat16, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QUInt4x2, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QUInt2x4, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Bits1x8, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Bits2x4, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Bits4x2, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Bits8, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Bits16, CTYPE_ALIAS, __VA_ARGS__) - -#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \ +#define ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__) \ ET_INTERNAL_SWITCH_CASE( \ @@ -979,12 +931,73 @@ struct promote_types { ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::Int, CTYPE_ALIAS, __VA_ARGS__) \ ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__) \ + ::executorch::aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_UINT_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::UInt16, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::UInt32, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::UInt64, CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::Float, CTYPE_ALIAS, __VA_ARGS__) \ ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__) +#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::ComplexHalf, CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_QINT_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::QInt8, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::QUInt8, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::QInt32, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::QUInt4x2, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::QUInt2x4, CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_BITS_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::Bits1x8, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::Bits2x4, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::Bits4x2, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::Bits8, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::Bits16, CTYPE_ALIAS, __VA_ARGS__) + +#define ET_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::Half, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::BFloat16, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE_QINT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE_BITS_TYPES(CTYPE_ALIAS, __VA_ARGS__) + #define ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND(ADDITIONAL, CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ ET_INTERNAL_SWITCH_CASE( \ @@ -1008,29 +1021,11 @@ struct promote_types { ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::ADDITIONAL3, CTYPE_ALIAS, __VA_ARGS__) -#define ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Char, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Short, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Int, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__) - #define ET_INTERNAL_SWITCH_CASE_INT_TYPES_AND(ADDITIONAL, CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__) -#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::Float, CTYPE_ALIAS, __VA_ARGS__) - #define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND(ADDITIONAL, CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ ET_INTERNAL_SWITCH_CASE( \ @@ -1050,32 +1045,6 @@ struct promote_types { ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::ADDITIONAL3, CTYPE_ALIAS, __VA_ARGS__) -#define ET_INTERNAL_SWITCH_CASE_QINT_TYPES(CTYPE_ALIAS, ...) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QInt8, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QUInt8, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QInt32, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QUInt4x2, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::QUInt2x4, CTYPE_ALIAS, __VA_ARGS__) - -#define ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__) - -#define ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES(CTYPE_ALIAS, ...) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::ComplexHalf, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__) - #define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_TYPES(CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__) \ @@ -1204,26 +1173,15 @@ struct promote_types { ET_SWITCH_REAL_TYPES_AND3( \ Half, Bool, BFloat16, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) -#define ET_SWITCH_REALHBBF16_AND_UINT_TYPES( \ - TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ - ET_INTERNAL_SWITCH( \ - TYPE, \ - CONTEXT, \ - NAME, \ - ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3( \ - Half, Bool, BFloat16, CTYPE_ALIAS, __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::UInt16, \ - CTYPE_ALIAS, \ - __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::UInt32, \ - CTYPE_ALIAS, \ - __VA_ARGS__) \ - ET_INTERNAL_SWITCH_CASE( \ - ::executorch::aten::ScalarType::UInt64, \ - CTYPE_ALIAS, \ - __VA_ARGS__)) +#define ET_SWITCH_REALHBBF16_AND_UINT_TYPES( \ + TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, \ + CONTEXT, \ + NAME, \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3( \ + Half, Bool, BFloat16, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE_UINT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) #define ET_SWITCH_INT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH( \