|
7 | 7 | import io |
8 | 8 | import logging |
9 | 9 | import warnings |
10 | | -from typing import Any, Dict, Optional, Sequence, Tuple, Union |
| 10 | +from typing import Any, Optional, Sequence |
11 | 11 |
|
12 | 12 | import onnx |
13 | 13 | from onnx import ValueInfoProto, helper |
@@ -185,7 +185,7 @@ def __init__(self, name: str, domain: str = "") -> None: |
185 | 185 | self.stmts: list[IRStmt] = [] |
186 | 186 | self.attrs: list[str] = [] # attribute parameters |
187 | 187 | self.attr_protos: list[ |
188 | | - onnx.AttributeProto |
| 188 | + IRAttributeValue |
189 | 189 | ] = [] # attribute parameters with default value |
190 | 190 | self.called_functions: dict[str, onnx.FunctionProto] = {} |
191 | 191 | self.docstring: str = "" |
@@ -218,7 +218,7 @@ def append_input(self, name: IRVar) -> None: |
218 | 218 | def append_output(self, name: IRVar) -> None: |
219 | 219 | self.outputs.append(name) |
220 | 220 |
|
221 | | - def add_attr_parameter(self, attr: Union[str, IRAttributeValue]) -> None: |
| 221 | + def add_attr_parameter(self, attr: str | IRAttributeValue) -> None: |
222 | 222 | if isinstance(attr, IRAttributeValue): |
223 | 223 | self.attr_protos.append(attr) |
224 | 224 | else: |
@@ -324,7 +324,7 @@ def to_proto(f): |
324 | 324 |
|
325 | 325 | def to_graph_and_functions( |
326 | 326 | self, use_default_type: bool = True |
327 | | - ) -> Tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]: |
| 327 | + ) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]: |
328 | 328 | """Converts this instance into a `onnx.GraphProto` and a map from |
329 | 329 | function-name to `onnx.FunctionProto`. |
330 | 330 |
|
@@ -360,7 +360,7 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: |
360 | 360 | graph, _ = self.to_graph_and_functions(use_default_type=use_default_type) |
361 | 361 | return graph |
362 | 362 |
|
363 | | - def get_opset_import(self) -> Dict[str, int]: |
| 363 | + def get_opset_import(self) -> dict[str, int]: |
364 | 364 | func_opset_imports = {} |
365 | 365 | for s in self.stmts: |
366 | 366 | if s.callee.opset.domain not in func_opset_imports: |
|
0 commit comments