Skip to content

Commit 382fbcc

Browse files
vkuzopytorchmergebot
authored andcommitted
add the torch.float8_e8m0fnu dtype to PyTorch (pytorch#147466)
Summary: Continuing the work from pytorch#146427 Adds the `torch.float8_e8m0fnu` dtype to PyTorch, as detailed in pytorch#146414 . Please see the issue for a detailed definition of the format. Example of basic functionality: ```python import torch # round trip x0 = torch.randn(4, 4, dtype=torch.float32) x1 = x0.to(torch.float8_e8m0fnu) # RNE rounding x2 = x1.to(torch.float32) # 2 ** exponent # creation with empty x0 = torch.empty(4, 4, dtype=torch.float8_e8m0fnu) # printing print(x0) ``` Done in this PR: * numerical correctness * op coverage (except for `torch._scaled_mm`): create tensor, cast to/from float32 * printing a tensor works For future PRs: * performance optimizations for casting * torch._scaled_mm * PT2 * various cleanups (detailed in comments with issue numbers) Test Plan: ``` pytest test/quantization/core/experimental/test_float8.py -s ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: pytorch#147466 Approved by: https://github.com/drisspg
1 parent 574371d commit 382fbcc

25 files changed

+535
-44
lines changed

aten/src/ATen/DLConvertor.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,12 @@ DLDataType getDLDataType(const Tensor& t) {
6363
case ScalarType::BFloat16:
6464
dtype.code = DLDataTypeCode::kDLBfloat;
6565
break;
66+
// TODO(#146647): use macro here instead of spelling out each shell dtype
6667
case ScalarType::Float8_e5m2:
6768
case ScalarType::Float8_e5m2fnuz:
6869
case ScalarType::Float8_e4m3fn:
6970
case ScalarType::Float8_e4m3fnuz:
71+
case ScalarType::Float8_e8m0fnu:
7072
TORCH_CHECK(false, "float8 types are not supported by dlpack");
7173
break;
7274
case ScalarType::QInt8:

aten/src/ATen/Dispatch_v2.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787

8888
#define AT_FLOAT8_TYPES \
8989
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
90-
c10::kFloat8_e4m3fnuz
90+
c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
9191

9292
#define AT_INTEGRAL_TYPES \
9393
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort

aten/src/ATen/native/Copy.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
5959
#if !defined(C10_MOBILE)
6060
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
6161
AT_DISPATCH_V2( \
62-
TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \
63-
kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
62+
TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, \
63+
AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
6464
#else
6565
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
6666
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,8 @@ Tensor isinf(const Tensor& self) {
460460

461461
Tensor isfinite(const Tensor& self) {
462462
// Note: Integral tensor values are always finite
463-
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
463+
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) ||
464+
self.scalar_type() == kFloat8_e8m0fnu) {
464465
return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
465466
}
466467

aten/src/ATen/native/cpu/CopyKernel.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,12 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
204204
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
205205
AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
206206
kComplexHalf, kHalf, kBool, \
207-
kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
208-
kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
207+
kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
208+
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
209209
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
210210
AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
211-
kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
212-
kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
211+
kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), \
212+
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
213213
#else
214214
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
215215
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \

aten/src/ATen/native/cpu/FillKernel.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ void fill_kernel(TensorIterator& iter, const Scalar& value_scalar) {
5151
fill_non_native_type<at::Float8_e4m3fnuz>(iter, value_scalar);
5252
} else if (iter.dtype() == ScalarType::Float8_e5m2fnuz) {
5353
fill_non_native_type<at::Float8_e5m2fnuz>(iter, value_scalar);
54+
} else if (iter.dtype() == ScalarType::Float8_e8m0fnu) {
55+
// TODO(#146647): use macro here instead of spelling out each float8 dtype
56+
fill_non_native_type<at::Float8_e8m0fnu>(iter, value_scalar);
5457
} else {
5558
AT_DISPATCH_V2(
5659
iter.dtype(), "fill_cpu", AT_WRAP([&]() {

aten/src/ATen/native/cpu/IndexKernel.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,13 @@ void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
184184
}
185185
}),
186186
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
187-
AT_EXPAND(AT_FLOAT8_TYPES),
187+
// AT_EXPAND(AT_FLOAT8_TYPES),
188+
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
189+
// should not be supported here, then reenable AT_FLOAT8_DTYPES
190+
kFloat8_e4m3fn,
191+
kFloat8_e5m2,
192+
kFloat8_e4m3fnuz,
193+
kFloat8_e5m2fnuz,
188194
kComplexHalf,
189195
kHalf,
190196
kBool,

aten/src/ATen/native/cuda/Copy.cu

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,28 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
144144
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
145145
break;
146146
}
147+
} else if (dtype == kFloat8_e8m0fnu) {
148+
// TODO(#146647): clean this up, too much copy-pasta
149+
switch (other_dtype) {
150+
case kFloat:
151+
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
152+
return Float8_e8m0fnu(value);
153+
});
154+
break;
155+
case kHalf:
156+
gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
157+
return Float8_e8m0fnu(value);
158+
});
159+
break;
160+
case kBFloat16:
161+
gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
162+
return Float8_e8m0fnu(value);
163+
});
164+
break;
165+
default:
166+
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e8m0fnu x) { return x; });
167+
break;
168+
}
147169
} else {
148170
TORCH_CHECK(false, "This supposed ot be called only for Float8 types");
149171
}
@@ -157,7 +179,7 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
157179
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
158180
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
159181
});
160-
} else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
182+
} else if (isFloat8Type(dtype)) {
161183
float8_copy_kernel_cuda(iter);
162184
} else if (iter.dtype(1) == kFloat && (dtype == kBFloat16 || dtype == kHalf)) {
163185
if (dtype == kBFloat16) {

aten/src/ATen/native/cuda/Indexing.cu

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
582582
C10_CUDA_KERNEL_LAUNCH_CHECK();
583583
}),
584584
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
585-
AT_EXPAND(AT_FLOAT8_TYPES),
585+
// AT_EXPAND(AT_FLOAT8_TYPES),
586+
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
587+
// should not be supported here, then reenable AT_FLOAT8_DTYPES
588+
kFloat8_e4m3fn,
589+
kFloat8_e5m2,
590+
kFloat8_e4m3fnuz,
591+
kFloat8_e5m2fnuz,
586592
kComplexHalf,
587593
kHalf,
588594
kBool,
@@ -606,7 +612,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
606612
C10_CUDA_KERNEL_LAUNCH_CHECK();
607613
}),
608614
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
609-
AT_EXPAND(AT_FLOAT8_TYPES),
615+
// AT_EXPAND(AT_FLOAT8_TYPES),
616+
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
617+
// should not be supported here, then reenable AT_FLOAT8_DTYPES
618+
kFloat8_e4m3fn,
619+
kFloat8_e5m2,
620+
kFloat8_e4m3fnuz,
621+
kFloat8_e5m2fnuz,
610622
kComplexHalf,
611623
kHalf,
612624
kBool,
@@ -630,7 +642,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
630642
C10_CUDA_KERNEL_LAUNCH_CHECK();
631643
}),
632644
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
633-
AT_EXPAND(AT_FLOAT8_TYPES),
645+
// AT_EXPAND(AT_FLOAT8_TYPES),
646+
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
647+
// should not be supported here, then reenable AT_FLOAT8_DTYPES
648+
kFloat8_e4m3fn,
649+
kFloat8_e5m2,
650+
kFloat8_e4m3fnuz,
651+
kFloat8_e5m2fnuz,
634652
kComplexHalf,
635653
kHalf,
636654
kBool,
@@ -652,7 +670,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
652670
C10_CUDA_KERNEL_LAUNCH_CHECK();
653671
}),
654672
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
655-
AT_EXPAND(AT_FLOAT8_TYPES),
673+
// AT_EXPAND(AT_FLOAT8_TYPES),
674+
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
675+
// should not be supported here, then reenable AT_FLOAT8_DTYPES
676+
kFloat8_e4m3fn,
677+
kFloat8_e5m2,
678+
kFloat8_e4m3fnuz,
679+
kFloat8_e5m2fnuz,
656680
kComplexHalf,
657681
kHalf,
658682
kBool,
@@ -677,7 +701,13 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Ten
677701
C10_CUDA_KERNEL_LAUNCH_CHECK();
678702
}),
679703
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
680-
AT_EXPAND(AT_FLOAT8_TYPES),
704+
// AT_EXPAND(AT_FLOAT8_TYPES),
705+
// TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
706+
// should not be supported here, then reenable AT_FLOAT8_DTYPES
707+
kFloat8_e4m3fn,
708+
kFloat8_e5m2,
709+
kFloat8_e4m3fnuz,
710+
kFloat8_e5m2fnuz,
681711
kComplexHalf,
682712
kHalf,
683713
kBool,

aten/src/ATen/native/cuda/jit_utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ template <> inline std::string typeName<at::Float8_e5m2fnuz>() {
228228
template <> inline std::string typeName<at::Float8_e4m3fnuz>() {
229229
return "at::Float8_e4m3fnuz";
230230
}
231+
template <> inline std::string typeName<at::Float8_e8m0fnu>() {
232+
// TODO(#146647): Can the code here be made generic for any scalartype?
233+
return "at::Float8_e8m0fnu";
234+
}
231235

232236
#define TYPE_NAME_CASE(ctype, scalartype) \
233237
case ScalarType::scalartype: return typeName<ctype>();

0 commit comments

Comments
 (0)