Skip to content

Commit efa0cc2

Browse files
authored
implement isinf20 and isnan20 (#17874)
1 parent abb3291 commit efa0cc2

8 files changed

Lines changed: 389 additions & 101 deletions

File tree

docs/OperatorKernels.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,10 @@ Do not modify directly.*
156156
|||[1, 10]|**B** = tensor(bool)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
157157
|ImageScaler|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
158158
|InstanceNormalization|*in* input:**T**<br> *in* scale:**T**<br> *in* B:**T**<br> *out* output:**T**|6+|**T** = tensor(float)|
159-
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
160-
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
159+
|IsInf|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
160+
|||[10, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(bool)|
161+
|IsNaN|*in* X:**T1**<br> *out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)<br/> **T2** = tensor(bool)|
162+
|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
161163
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
162164
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(float)|
163165
|||[1, 12]|**T** = tensor(float)|

include/onnxruntime/core/framework/float8.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,10 @@ struct Float8E4M3FNUZ {
208208
val = static_cast<uint8_t>((b & 0x80000000) >> 24); // sign
209209
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
210210
if (saturate) {
211+
// the highest available value
211212
val |= 0x7F;
212213
} else {
213-
// infinity
214+
// NaN
214215
val = 0x80;
215216
}
216217
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
@@ -362,8 +363,10 @@ struct Float8E5M2 {
362363
val = (b & 0x80000000) >> 24; // sign
363364
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
364365
if (saturate) {
366+
// the highest available value
365367
val |= 0x7B;
366368
} else {
369+
// the infinity
367370
val |= 0x7C;
368371
}
369372
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
365365
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice);
366366
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 11, Dropout);
367367
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression);
368-
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf);
368+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf);
369369
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign);
370370
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign);
371371
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence);
@@ -682,9 +682,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Ga
682682
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
683683
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
684684
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity);
685-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, IsNaN);
686-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, IsNaN);
687-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, IsNaN);
685+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN);
686+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN);
687+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN);
688688
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero);
689689
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero);
690690
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero);
@@ -960,6 +960,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh
960960

961961
// Opset 20
962962
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
963+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN);
964+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN);
965+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN);
966+
#if !defined(DISABLE_FLOAT8_TYPES)
967+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN);
968+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN);
969+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN);
970+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN);
971+
#endif
972+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
963973

964974
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
965975

@@ -1492,7 +1502,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
14921502
Dropout)>,
14931503
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10,
14941504
NonMaxSuppression)>,
1495-
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf)>,
1505+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf)>,
14961506
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float,
14971507
RoiAlign)>,
14981508
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double,
@@ -1981,12 +1991,12 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
19811991
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,
19821992
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND)>,
19831993
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
1984-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
1985-
IsNaN)>,
1986-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double,
1987-
IsNaN)>,
1988-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16,
1989-
IsNaN)>,
1994+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float,
1995+
IsNaN)>,
1996+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double,
1997+
IsNaN)>,
1998+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16,
1999+
IsNaN)>,
19902000
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool,
19912001
NonZero)>,
19922002
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float,
@@ -2389,6 +2399,16 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
23892399

23902400
// Opset 20
23912401
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape)>,
2402+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN)>,
2403+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN)>,
2404+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN)>,
2405+
#if !defined(DISABLE_FLOAT8_TYPES)
2406+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN)>,
2407+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN)>,
2408+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN)>,
2409+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN)>,
2410+
#endif
2411+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
23922412
};
23932413

23942414
for (auto& function_table_entry : function_table) {

onnxruntime/core/providers/cpu/tensor/isinf.cc

Lines changed: 92 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,64 @@ namespace onnxruntime {
1414
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf
1515

1616
namespace op_kernel_type_control {
17-
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
18-
kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0,
19-
float, double);
17+
using IsInfTypesOpset10 = TypeList<float, double>;
18+
19+
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
20+
kCpuExecutionProvider, kOnnxDomain, IsInf, 10, Input, 0,
21+
IsInfTypesOpset10);
22+
23+
using IsInfTypesOpset20 =
24+
TypeList<
25+
float,
26+
double
27+
#if !defined(DISABLE_FLOAT8_TYPES)
28+
,
29+
Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
30+
#endif
31+
>;
32+
33+
ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST(
34+
kCpuExecutionProvider,
35+
kOnnxDomain,
36+
IsInf,
37+
20,
38+
Input,
39+
0,
40+
IsInfTypesOpset20);
2041
} // namespace op_kernel_type_control
2142

2243
class IsInf final : public OpKernel {
2344
public:
24-
using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain,
25-
IsInf, Input, 0);
45+
using EnabledDataTypes10 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
46+
IsInf, 10, Input, 0);
47+
using EnabledDataTypes20 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain,
48+
IsInf, 20, Input, 0);
2649

2750
explicit IsInf(const OpKernelInfo& info);
2851
Status Compute(OpKernelContext* context) const override;
2952

3053
private:
3154
int64_t detect_positive_{1};
3255
int64_t detect_negative_{1};
56+
int opset_;
3357
};
3458

35-
ONNX_CPU_OPERATOR_KERNEL(
59+
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
3660
IsInf,
3761
10,
62+
19,
3863
KernelDefBuilder()
3964
.TypeConstraint("T1",
40-
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes>())
65+
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes10>())
66+
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
67+
IsInf);
68+
69+
ONNX_CPU_OPERATOR_KERNEL(
70+
IsInf,
71+
20,
72+
KernelDefBuilder()
73+
.TypeConstraint("T1",
74+
BuildKernelDefConstraintsFromTypeList<IsInf::EnabledDataTypes20>())
4175
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()),
4276
IsInf);
4377

@@ -46,6 +80,7 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) {
4680
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive");
4781
status = info.GetAttr("detect_negative", &detect_negative_);
4882
ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative");
83+
opset_ = info.node().SinceVersion();
4984
}
5085

5186
namespace isinf_internal {
@@ -78,6 +113,49 @@ struct ComputeDispatchTarget {
78113
}
79114
}
80115
};
116+
117+
#if !defined(DISABLE_FLOAT8_TYPES)
118+
template <>
119+
struct ComputeDispatchTarget<Float8E4M3FN> {
120+
void operator()(const Tensor&, Tensor& Y, bool, bool) const {
121+
EigenMap<bool>(Y).array() = false;
122+
}
123+
};
124+
125+
template <>
126+
struct ComputeDispatchTarget<Float8E4M3FNUZ> {
127+
void operator()(const Tensor&, Tensor& Y, bool, bool) const {
128+
EigenMap<bool>(Y).array() = false;
129+
}
130+
};
131+
132+
template <>
133+
struct ComputeDispatchTarget<Float8E5M2> {
134+
void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const {
135+
auto& dims = X.Shape();
136+
auto input = ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X.Data<Float8E5M2>())), onnxruntime::narrow<size_t>(dims.Size()));
137+
auto output = EigenMap<bool>(Y);
138+
139+
// S.11111.00
140+
if (detect_positive && detect_negative) {
141+
output.array() = input.array() == 0b01111100 || input.array() == 0b11111100;
142+
} else if (detect_positive) {
143+
output.array() = input.array() == 0b01111100;
144+
} else if (detect_negative) {
145+
output.array() = input.array() == 0b11111100;
146+
} else {
147+
output.array() = false;
148+
}
149+
}
150+
};
151+
152+
template <>
153+
struct ComputeDispatchTarget<Float8E5M2FNUZ> {
154+
void operator()(const Tensor&, Tensor& Y, bool, bool) const {
155+
EigenMap<bool>(Y).array() = false;
156+
}
157+
};
158+
#endif
81159
} // namespace isinf_internal
82160

83161
Status IsInf::Compute(OpKernelContext* context) const {
@@ -88,8 +166,13 @@ Status IsInf::Compute(OpKernelContext* context) const {
88166

89167
using namespace isinf_internal;
90168

91-
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes> dispatcher{X.GetElementType()};
92-
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
169+
if (opset_ < 20) {
170+
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes10> dispatcher{X.GetElementType()};
171+
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
172+
} else {
173+
utils::MLTypeCallDispatcherFromTypeList<EnabledDataTypes20> dispatcher{X.GetElementType()};
174+
dispatcher.Invoke<ComputeDispatchTarget>(X, Y, detect_positive_ != 0, detect_negative_ != 0);
175+
}
93176

94177
return Status::OK();
95178
}

onnxruntime/core/providers/cpu/tensor/isnan.cc

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,20 @@ namespace onnxruntime {
2020
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()), \
2121
IsNaN<data_type>);
2222

23+
#define ADD_TYPED_ISNAN_OP_13(data_type) \
24+
ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \
25+
IsNaN, \
26+
13, 19, \
27+
data_type, \
28+
KernelDefBuilder() \
29+
.TypeConstraint("T1", DataTypeImpl::GetTensorType<data_type>()) \
30+
.TypeConstraint("T2", DataTypeImpl::GetTensorType<bool>()), \
31+
IsNaN<data_type>);
32+
2333
#define ADD_TYPED_ISNAN_OP(data_type) \
2434
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
2535
IsNaN, \
26-
13, \
36+
20, \
2737
data_type, \
2838
KernelDefBuilder() \
2939
.TypeConstraint("T1", DataTypeImpl::GetTensorType<data_type>()) \
@@ -33,10 +43,20 @@ namespace onnxruntime {
3343
ADD_TYPED_ISNAN_OP_9(float);
3444
ADD_TYPED_ISNAN_OP_9(double);
3545
ADD_TYPED_ISNAN_OP_9(MLFloat16);
46+
ADD_TYPED_ISNAN_OP_13(float);
47+
ADD_TYPED_ISNAN_OP_13(double);
48+
ADD_TYPED_ISNAN_OP_13(MLFloat16);
3649
ADD_TYPED_ISNAN_OP(float);
3750
ADD_TYPED_ISNAN_OP(double);
3851
ADD_TYPED_ISNAN_OP(MLFloat16);
3952

53+
#if !defined(DISABLE_FLOAT8_TYPES)
54+
ADD_TYPED_ISNAN_OP(Float8E4M3FN);
55+
ADD_TYPED_ISNAN_OP(Float8E4M3FNUZ);
56+
ADD_TYPED_ISNAN_OP(Float8E5M2);
57+
ADD_TYPED_ISNAN_OP(Float8E5M2FNUZ);
58+
#endif
59+
4060
template <typename T>
4161
Status IsNaN<T>::Compute(OpKernelContext* context) const {
4262
const auto* X_ptr = context->Input<Tensor>(0);
@@ -70,4 +90,63 @@ Status IsNaN<MLFloat16>::Compute(OpKernelContext* context) const {
7090

7191
return Status::OK();
7292
}
93+
94+
#if !defined(DISABLE_FLOAT8_TYPES)
95+
template <>
96+
Status IsNaN<Float8E4M3FN>::Compute(OpKernelContext* context) const {
97+
const auto* X = context->Input<Tensor>(0);
98+
auto& dims = X->Shape();
99+
auto& Y = *context->Output(0, dims);
100+
101+
auto input = ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X->Data<Float8E4M3FN>())), onnxruntime::narrow<size_t>(dims.Size()));
102+
auto output = EigenMap<bool>(Y);
103+
104+
// S.1111.111
105+
std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return (c & 0x7f) == 0x7f; });
106+
return Status::OK();
107+
}
108+
109+
template <>
110+
Status IsNaN<Float8E4M3FNUZ>::Compute(OpKernelContext* context) const {
111+
const auto* X = context->Input<Tensor>(0);
112+
auto X_data = X->Data<Float8E4M3FNUZ>();
113+
auto& dims = X->Shape();
114+
auto shape_size = dims.Size();
115+
auto& Y = *context->Output(0, dims);
116+
117+
// 1.0000.000
118+
EigenMap<bool>(Y) =
119+
ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X_data)), onnxruntime::narrow<size_t>(shape_size)).array() == 0x80;
120+
121+
return Status::OK();
122+
}
123+
124+
template <>
125+
Status IsNaN<Float8E5M2>::Compute(OpKernelContext* context) const {
126+
const auto* X = context->Input<Tensor>(0);
127+
auto& dims = X->Shape();
128+
auto& Y = *context->Output(0, dims);
129+
130+
auto input = ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X->Data<Float8E5M2>())), onnxruntime::narrow<size_t>(dims.Size()));
131+
auto output = EigenMap<bool>(Y);
132+
133+
// S.11111.{01, 10, 11}
134+
std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); });
135+
return Status::OK();
136+
}
137+
138+
template <>
139+
Status IsNaN<Float8E5M2FNUZ>::Compute(OpKernelContext* context) const {
140+
const auto* X = context->Input<Tensor>(0);
141+
auto X_data = X->Data<Float8E5M2FNUZ>();
142+
auto& dims = X->Shape();
143+
auto shape_size = dims.Size();
144+
auto& Y = *context->Output(0, dims);
145+
146+
// 1.0000.000
147+
EigenMap<bool>(Y) = ConstEigenVectorMap<uint8_t>(static_cast<const uint8_t*>(static_cast<const void*>(X_data)), onnxruntime::narrow<size_t>(shape_size)).array() == 0x80;
148+
149+
return Status::OK();
150+
}
151+
#endif
73152
} // namespace onnxruntime

0 commit comments

Comments
 (0)