Skip to content

Commit 77fc946

Browse files
authored
Implement is_floating_point on dtypes (#31)
Migration of changes in microsoft/onnxscript#2335 Signed-off-by: Justin Chu <[email protected]>
1 parent bfa4699 commit 77fc946

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/onnx_ir/_enums.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,20 @@ def short_name(self) -> str:
142142
raise TypeError(f"Short name not available for ONNX data type: {self}")
143143
return _DATA_TYPE_TO_SHORT_NAME[self]
144144

145+
def is_floating_point(self) -> bool:
146+
"""Returns True if the data type is a floating point type."""
147+
return self in {
148+
DataType.FLOAT,
149+
DataType.FLOAT16,
150+
DataType.DOUBLE,
151+
DataType.BFLOAT16,
152+
DataType.FLOAT8E4M3FN,
153+
DataType.FLOAT8E4M3FNUZ,
154+
DataType.FLOAT8E5M2,
155+
DataType.FLOAT8E5M2FNUZ,
156+
DataType.FLOAT4E2M1,
157+
}
158+
145159
def __repr__(self) -> str:
146160
return self.name
147161

0 commit comments

Comments
 (0)