diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 796b8c6253..5ecaaa2862 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum from typing import Any, Dict, List, Optional, Sequence, Tuple @@ -32,11 +34,11 @@ class _ShapeMode(Enum): shape: Optional[ Tuple[int, ...] | Dict[str, Tuple[int, ...]] ] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }`` - dtype: _enums.dtype = ( # type: ignore[name-defined] + dtype: _enums.dtype = ( _enums.dtype.unknown ) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32) _explicit_set_dtype: bool = False - format: _enums.TensorFormat = ( # type: ignore[name-defined] + format: _enums.TensorFormat = ( _enums.TensorFormat.contiguous ) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW) @@ -208,7 +210,7 @@ def _supported_input_size_type(input_size: Any) -> bool: return False @staticmethod - def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined] + def _parse_dtype(dtype: Any) -> _enums.dtype: if isinstance(dtype, torch.dtype): if dtype == torch.long: return _enums.dtype.long @@ -236,7 +238,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined] ) @staticmethod - def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: # type: ignore[name-defined] + def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: if dtype == _enums.dtype.long: return torch.long elif dtype == _enums.dtype.int32: @@ -255,7 +257,7 @@ def is_trt_dtype(self) -> bool: return bool(self.dtype != _enums.dtype.long) @staticmethod - def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defined] + def _parse_format(format: Any) -> _enums.TensorFormat: if isinstance(format, torch.memory_format): if format == torch.contiguous_format: return _enums.TensorFormat.contiguous diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 8a95d6eada..fc1c0d30d8 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from enum import Enum -from typing import Any, Callable, List, Optional, Sequence, Set, TypeGuard +from typing import Any, Callable, List, Optional, Sequence, Set import torch import torch.fx @@ -12,6 +14,7 @@ from torch_tensorrt.fx.lower import compile as fx_compile from torch_tensorrt.fx.utils import LowerPrecision from torch_tensorrt.ts._compiler import compile as torchscript_compile +from typing_extensions import TypeGuard def _non_fx_input_interface( diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py index 3ff41bac3d..39e1ea7c9d 100644 --- a/py/torch_tensorrt/dynamo/aten_tracer.py +++ b/py/torch_tensorrt/dynamo/aten_tracer.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import copy import sys from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import torch import torch._dynamo as torchdynamo @@ -22,7 +24,7 @@ ) from typing_extensions import TypeAlias -Value: TypeAlias = Tuple["Value", ...] | List["Value"] | Dict[str, "Value"] +Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]] class DynamoConfig: diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index ca14ad264b..5d31789b92 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from functools import partial from typing import Any, Callable, Sequence diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py index 0402a6af43..1cc3474883 100644 --- a/py/torch_tensorrt/dynamo/compile.py +++ b/py/torch_tensorrt/dynamo/compile.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections.abc import logging from typing import Any, List, Optional, Set, Tuple diff --git a/py/torch_tensorrt/dynamo/conversion/conversion.py b/py/torch_tensorrt/dynamo/conversion/conversion.py index ae3c8b66b2..5c9bbd8c70 100644 --- a/py/torch_tensorrt/dynamo/conversion/conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/conversion.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io from typing import Sequence diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py index 7275844500..c883092166 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_registry.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from dataclasses import dataclass, field from enum import Enum, auto @@ -28,7 +30,7 @@ Dict[str, Argument], str, ], - TRTTensor | Sequence[TRTTensor], + Union[TRTTensor, Sequence[TRTTensor]], ] diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py index 86c126552b..8bd137d991 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shape.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import List, Optional, Tuple import numpy as np diff --git a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py index 5cb5d118be..c6c10f475a 100644 --- a/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py +++ b/py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Optional, Sequence, Set import torch diff --git a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py index 32250607df..e69b9987c7 100644 --- a/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py +++ b/py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py @@ -1,10 +1,13 @@ +from __future__ import annotations + import logging from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, Type, TypeAlias +from typing import Any, Callable, Dict, Optional, Type import torch from torch._ops import OpOverload from torch.fx import GraphModule, Node +from typing_extensions import TypeAlias logger = logging.getLogger(__name__) diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index b5760161a6..e00b66d8f3 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Dict, List, Optional, Sequence, Tuple import torch diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index bddee9b93b..d5ad5021e2 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from typing import Any, List, Optional, Tuple diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index bbb1a4354b..cec328e84f 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from dataclasses import fields, replace from typing import Any, Callable, Dict, Optional, Sequence diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index 6803259985..b9a84152e1 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from copy import deepcopy from typing import Any, Dict, List, Optional, Set @@ -39,7 +41,7 @@ def _supported_input_size_type(input_size: Any) -> bool: ) -def _parse_op_precision(precision: Any) -> _enums.dtype: # type: ignore[name-defined] +def _parse_op_precision(precision: Any) -> _enums.dtype: if isinstance(precision, torch.dtype): if precision == torch.int8: return _enums.dtype.int8 @@ -63,7 +65,7 @@ def _parse_op_precision(precision: Any) -> _enums.dtype: # type: ignore[name-de ) -def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: # type: ignore[name-defined] +def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: parsed_precisions = set() if any(isinstance(precisions, type) for type in [list, tuple, set]): for p in precisions: @@ -73,7 +75,7 @@ def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]: # type: ig return parsed_precisions -def _parse_device_type(device: Any) -> _enums.DeviceType: # type: ignore[name-defined] +def _parse_device_type(device: Any) -> _enums.DeviceType: if isinstance(device, torch.device): if device.type == "cuda": return _C.DeviceType.gpu @@ -346,10 +348,10 @@ def TensorRTCompileSpec( device: torch.device | Device = Device._current_device(), disable_tf32: bool = False, sparse_weights: bool = False, - enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, # type: ignore[name-defined] + enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None, refit: bool = False, debug: bool = False, - capability: _enums.EngineCapability = _enums.EngineCapability.default, # type: ignore[name-defined] + capability: _enums.EngineCapability = _enums.EngineCapability.default, num_avg_timing_iters: int = 1, workspace_size: int = 0, dla_sram_size: int = 1048576, diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py index 30828ce5d8..4a9bb53dc0 100644 --- a/py/torch_tensorrt/ts/_compiler.py +++ b/py/torch_tensorrt/ts/_compiler.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, List, Optional, Sequence, Set, Tuple import torch