
Commit 5b88e6e

Revert "add the torch.float8_e8m0fnu dtype to PyTorch (pytorch#147466)"
This reverts commit 382fbcc.
1 parent 4708cfd · commit 5b88e6e

25 files changed: +44 −535 lines
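The user-visible effect of the revert, in a minimal sketch (hypothetical caller code, not part of this commit; assumes a build of this tree): ScalarType::Float8_e8m0fnu is removed from c10, so isFloat8Type() recognizes only the four remaining float8 dtypes, and any code that names the e8m0 type stops compiling.

// Hypothetical caller sketch -- not part of this commit.
#include <c10/core/ScalarType.h>
#include <iostream>

int main() {
  using c10::ScalarType;
  // The four float8 dtypes that survive the revert:
  std::cout << c10::isFloat8Type(ScalarType::Float8_e5m2) << "\n";   // 1
  std::cout << c10::isFloat8Type(ScalarType::Float8_e4m3fn) << "\n"; // 1
  std::cout << c10::isFloat8Type(ScalarType::BFloat16) << "\n";      // 0
  // c10::isFloat8Type(ScalarType::Float8_e8m0fnu);  // error: no such enumerator after the revert
  return 0;
}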

aten/src/ATen/DLConvertor.cpp

Lines changed: 0 additions & 2 deletions
@@ -63,12 +63,10 @@ DLDataType getDLDataType(const Tensor& t) {
     case ScalarType::BFloat16:
       dtype.code = DLDataTypeCode::kDLBfloat;
       break;
-    // TODO(#146647): use macro here instead of spelling out each shell dtype
     case ScalarType::Float8_e5m2:
     case ScalarType::Float8_e5m2fnuz:
     case ScalarType::Float8_e4m3fn:
     case ScalarType::Float8_e4m3fnuz:
-    case ScalarType::Float8_e8m0fnu:
       TORCH_CHECK(false, "float8 types are not supported by dlpack");
       break;
     case ScalarType::QInt8:

aten/src/ATen/Dispatch_v2.h

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@
 
 #define AT_FLOAT8_TYPES \
   c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
-      c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
+      c10::kFloat8_e4m3fnuz
 
 #define AT_INTEGRAL_TYPES \
   c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
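For context, a minimal sketch of how this macro is consumed (modeled on the _AT_DISPATCH_* macros in CopyKernel.cpp further down; the op name "example_op", the TensorIterator `iter`, and the identity lambda are placeholders): AT_EXPAND splices the list into an AT_DISPATCH_V2 call, so with the reverted definition no Float8_e8m0fnu kernel is instantiated.

// Sketch only -- placeholder op name and iterator; assumes the usual
// TensorIterator/cpu_kernel machinery from ATen is in scope.
AT_DISPATCH_V2(iter.dtype(), "example_op", AT_WRAP([&]() {
  // scalar_t is instantiated once per dtype spliced in by AT_EXPAND:
  // Float8_e5m2, Float8_e5m2fnuz, Float8_e4m3fn, Float8_e4m3fnuz.
  cpu_kernel(iter, [](scalar_t x) { return x; });
}), AT_EXPAND(AT_FLOAT8_TYPES));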

aten/src/ATen/native/Copy.cpp

Lines changed: 2 additions & 2 deletions
@@ -59,8 +59,8 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
 #if !defined(C10_MOBILE)
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_V2( \
-      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, \
-      AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \
+      kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 1 addition & 2 deletions
@@ -460,8 +460,7 @@ Tensor isinf(const Tensor& self) {
 
 Tensor isfinite(const Tensor& self) {
   // Note: Integral tensor values are always finite
-  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) ||
-      self.scalar_type() == kFloat8_e8m0fnu) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
   }
 
aten/src/ATen/native/cpu/CopyKernel.cpp

Lines changed: 4 additions & 4 deletions
@@ -204,12 +204,12 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
       kComplexHalf, kHalf, kBool, \
-      kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
-      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
   AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
-      kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
-      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
+      kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+      kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
 #else
 #define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \

aten/src/ATen/native/cpu/FillKernel.cpp

Lines changed: 0 additions & 3 deletions
@@ -51,9 +51,6 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
     fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
   } else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
     fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
-  } else if (iter.dtype() == ScalarType::Float8_e8m0fnu) {
-    // TODO(#146647): use macro here instead of spelling out each float8 dtype
-    fill_non_native_type<at::Float8_e8m0fnu>(iter, value_scalar);
   } else {
     AT_DISPATCH_V2(
         iter.dtype(), "fill_cpu", AT_WRAP([&]() {

aten/src/ATen/native/cpu/IndexKernel.cpp

Lines changed: 1 addition & 7 deletions
@@ -184,13 +184,7 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
       }
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    // AT_EXPAND(AT_FLOAT8_TYPES),
-    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
-    // should not be supported here, then reenable AT_FLOAT8_DTYPES
-    kFloat8_e4m3fn,
-    kFloat8_e5m2,
-    kFloat8_e4m3fnuz,
-    kFloat8_e5m2fnuz,
+    AT_EXPAND(AT_FLOAT8_TYPES),
     kComplexHalf,
     kHalf,
     kBool,

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 1 addition & 23 deletions
@@ -144,28 +144,6 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
         gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
         break;
     }
-  } else if (dtype == kFloat8_e8m0fnu) {
-    // TODO(#146647): clean this up, too much copy-pasta
-    switch (other_dtype) {
-      case kFloat:
-        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
-          return Float8_e8m0fnu(value);
-        });
-        break;
-      case kHalf:
-        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
-          return Float8_e8m0fnu(value);
-        });
-        break;
-      case kBFloat16:
-        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
-          return Float8_e8m0fnu(value);
-        });
-        break;
-      default:
-        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e8m0fnu x) { return x; });
-        break;
-    }
   } else {
     TORCH_CHECK(false, "This supposed ot be called only for Float8 types");
   }
@@ -179,7 +157,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
-  } else if (isFloat8Type(dtype)) {
+  } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
     float8_copy_kernel_cuda(iter);
   } else if (iter.dtype(1) == kFloat && (dtype == kBFloat16 || dtype == kHalf)) {
     if (dtype == kBFloat16) {

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 5 additions & 35 deletions
@@ -708,13 +708,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    // AT_EXPAND(AT_FLOAT8_TYPES),
-    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
-    // should not be supported here, then reenable AT_FLOAT8_DTYPES
-    kFloat8_e4m3fn,
-    kFloat8_e5m2,
-    kFloat8_e4m3fnuz,
-    kFloat8_e5m2fnuz,
+    AT_EXPAND(AT_FLOAT8_TYPES),
     kComplexHalf,
     kHalf,
     kBool,
@@ -740,13 +734,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    // AT_EXPAND(AT_FLOAT8_TYPES),
-    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
-    // should not be supported here, then reenable AT_FLOAT8_DTYPES
-    kFloat8_e4m3fn,
-    kFloat8_e5m2,
-    kFloat8_e4m3fnuz,
-    kFloat8_e5m2fnuz,
+    AT_EXPAND(AT_FLOAT8_TYPES),
     kComplexHalf,
     kHalf,
     kBool,
@@ -770,13 +758,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
      C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    // AT_EXPAND(AT_FLOAT8_TYPES),
-    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
-    // should not be supported here, then reenable AT_FLOAT8_DTYPES
-    kFloat8_e4m3fn,
-    kFloat8_e5m2,
-    kFloat8_e4m3fnuz,
-    kFloat8_e5m2fnuz,
+    AT_EXPAND(AT_FLOAT8_TYPES),
     kComplexHalf,
     kHalf,
     kBool,
@@ -798,13 +780,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    // AT_EXPAND(AT_FLOAT8_TYPES),
-    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
-    // should not be supported here, then reenable AT_FLOAT8_DTYPES
-    kFloat8_e4m3fn,
-    kFloat8_e5m2,
-    kFloat8_e4m3fnuz,
-    kFloat8_e5m2fnuz,
+    AT_EXPAND(AT_FLOAT8_TYPES),
     kComplexHalf,
     kHalf,
     kBool,
@@ -829,13 +805,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     }),
     AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
-    // AT_EXPAND(AT_FLOAT8_TYPES),
-    // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
-    // should not be supported here, then reenable AT_FLOAT8_DTYPES
-    kFloat8_e4m3fn,
-    kFloat8_e5m2,
-    kFloat8_e4m3fnuz,
-    kFloat8_e5m2fnuz,
+    AT_EXPAND(AT_FLOAT8_TYPES),
     kComplexHalf,
     kHalf,
     kBool,

aten/src/ATen/native/cuda/jit_utils.h

Lines changed: 0 additions & 4 deletions
@@ -228,10 +228,6 @@ template <> inline std::string typeName<at::Float8_e5m2fnuz>() {
 template <> inline std::string typeName<at::Float8_e4m3fnuz>() {
   return "at::Float8_e4m3fnuz";
 }
-template <> inline std::string typeName<at::Float8_e8m0fnu>() {
-  // TODO(#146647): Can the code here be made generic for any scalartype?
-  return "at::Float8_e8m0fnu";
-}
 
 #define TYPE_NAME_CASE(ctype, scalartype) \
   case ScalarType::scalartype: return typeName<ctype>();

c10/core/Scalar.h

Lines changed: 9 additions & 2 deletions
@@ -49,9 +49,16 @@ class C10_API Scalar {
 #define DEFINE_IMPLICIT_CTOR(type, name) \
   Scalar(type vv) : Scalar(vv, true) {}
 
-  AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_SCALAR_TYPES_AND7(
+      Half,
+      BFloat16,
+      Float8_e5m2,
+      Float8_e4m3fn,
+      Float8_e5m2fnuz,
+      Float8_e4m3fnuz,
+      ComplexHalf,
+      DEFINE_IMPLICIT_CTOR)
   AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
-  AT_FORALL_FLOAT8_TYPES(DEFINE_IMPLICIT_CTOR)
 
   // Helper constructors to allow Scalar creation from long and long long types
   // As std::is_same_v<long, long long> is false(except Android), one needs to
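AT_FORALL_SCALAR_TYPES_AND7 is an X-macro: it invokes DEFINE_IMPLICIT_CTOR once per listed type, stamping out one implicit Scalar constructor each. A standalone toy of the same pattern (illustrative names, not PyTorch's):

// Toy X-macro, mirroring how Scalar.h stamps out DEFINE_IMPLICIT_CTOR:
// the FORALL macro invokes _() once per (type, name) pair.
#include <cstring>

#define FORALL_DEMO_TYPES(_) \
  _(int, Int)                \
  _(float, Float)            \
  _(double, Double)

struct Value {
#define DEFINE_DEMO_CTOR(type, name) \
  Value(type) : kind(#name) {}
  FORALL_DEMO_TYPES(DEFINE_DEMO_CTOR) // generates Value(int), Value(float), Value(double)
#undef DEFINE_DEMO_CTOR
  const char* kind;
};

int main() {
  Value v(3.0f);                       // selects the Value(float) overload
  return std::strcmp(v.kind, "Float"); // 0 when the expected ctor ran
}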

c10/core/ScalarType.cpp

Lines changed: 0 additions & 3 deletions
@@ -222,9 +222,6 @@ std::pair<std::string, std::string> getDtypeNames(c10::ScalarType scalarType) {
       return std::make_pair("float8_e5m2fnuz", "");
     case c10::ScalarType::Float8_e4m3fnuz:
       return std::make_pair("float8_e4m3fnuz", "");
-    case c10::ScalarType::Float8_e8m0fnu:
-      // TODO(#146647): macroify all of this
-      return std::make_pair("float8_e8m0fnu", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }

c10/core/ScalarType.h

Lines changed: 3 additions & 19 deletions
@@ -7,7 +7,6 @@
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e5m2.h>
 #include <c10/util/Float8_e5m2fnuz.h>
-#include <c10/util/Float8_e8m0fnu.h>
 #include <c10/util/Half.h>
 #include <c10/util/bits.h>
 #include <c10/util/complex.h>
@@ -103,8 +102,7 @@ struct dummy_int1_7_t {};
   _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
   _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
   _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
-  _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
-  _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */
+  _(c10::dummy_int1_7_t<7>, Int7) /* 43 */
 
 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name). But beware: convert()
@@ -148,8 +146,7 @@ struct dummy_int1_7_t {};
   _(at::Float8_e5m2, Float8_e5m2) \
   _(at::Float8_e4m3fn, Float8_e4m3fn) \
   _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
-  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
-  _(at::Float8_e8m0fnu, Float8_e8m0fnu)
+  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz)
 
 enum class ScalarType : int8_t {
 #define DEFINE_ST_ENUM_VAL_(_1, n) n,
@@ -320,13 +317,6 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
   _(c10::quint4x2, QUInt4x2) \
   _(c10::quint2x4, QUInt2x4)
 
-#define AT_FORALL_FLOAT8_TYPES(_) \
-  _(at::Float8_e5m2, Float8_e5m2) \
-  _(at::Float8_e4m3fn, Float8_e4m3fn) \
-  _(at::Float8_e5m2fnuz, Float8_e5m2fnuz) \
-  _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
-  _(at::Float8_e8m0fnu, Float8_e8m0fnu)
-
 #define AT_FORALL_COMPLEX_TYPES(_) \
   _(c10::complex<float>, ComplexFloat) \
   _(c10::complex<double>, ComplexDouble)
@@ -382,8 +372,7 @@ inline bool isIntegralType(ScalarType t) {
 
 inline bool isFloat8Type(ScalarType t) {
   return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e5m2fnuz ||
-      t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz ||
-      t == ScalarType::Float8_e8m0fnu;
+      t == ScalarType::Float8_e4m3fn || t == ScalarType::Float8_e4m3fnuz;
 }
 
 inline bool isReducedFloatingType(ScalarType t) {
@@ -457,10 +446,6 @@ inline bool isSignedType(ScalarType t) {
       return std::numeric_limits< \
           ::c10::impl::ScalarTypeToCPPTypeT<ScalarType::name>>::is_signed;
 
-  // TODO(#146647): If we expect to have numeric_limits for everything,
-  // let's just have a big macro for the whole thing.
-  // If we're hardcoding it, let's just use the macro and a "true"/"false"
-  // below?
   switch (t) {
     case ScalarType::QInt8:
     case ScalarType::QUInt8:
@@ -482,7 +467,6 @@ inline bool isSignedType(ScalarType t) {
     CASE_ISSIGNED(Float8_e5m2fnuz);
     CASE_ISSIGNED(Float8_e4m3fn);
     CASE_ISSIGNED(Float8_e4m3fnuz);
-    CASE_ISSIGNED(Float8_e8m0fnu);
     CASE_ISSIGNED(Byte);
     CASE_ISSIGNED(Char);
     CASE_ISSIGNED(Short);
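Each CASE_ISSIGNED(name) in the final hunk expands to a case label that returns std::numeric_limits<T>::is_signed for the dtype's backing C++ type, so the revert simply drops the e8m0 case from isSignedType. A simplified standalone sketch of that pattern (toy enum and type mapping, not c10's):

// Standalone sketch of the CASE_ISSIGNED pattern from ScalarType.h
// (simplified enum and mapping; not PyTorch's actual definitions).
#include <cstdint>
#include <limits>

enum class Kind { Byte, Char, Short };

template <Kind K> struct KindToCpp;
template <> struct KindToCpp<Kind::Byte>  { using type = uint8_t; };
template <> struct KindToCpp<Kind::Char>  { using type = int8_t; };
template <> struct KindToCpp<Kind::Short> { using type = int16_t; };

#define CASE_IS_SIGNED(name) \
  case Kind::name:           \
    return std::numeric_limits<KindToCpp<Kind::name>::type>::is_signed;

bool isSigned(Kind k) {
  switch (k) {
    CASE_IS_SIGNED(Byte);  // false: uint8_t is unsigned
    CASE_IS_SIGNED(Char);  // true:  int8_t is signed
    CASE_IS_SIGNED(Short); // true:  int16_t is signed
  }
  return false;
}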
