|
4 | 4 | import torch |
5 | 5 |
|
6 | 6 | from torch_tensorrt import _enums |
7 | | -from torch_tensorrt import _C |
8 | 7 |
|
9 | 8 |
|
10 | 9 | class Input(object): |
@@ -41,6 +40,7 @@ class _ShapeMode(Enum): |
41 | 40 | DOMAIN_OFFSET = 2.0 |
42 | 41 | low_tensor_domain_incl = 0.0 |
43 | 42 | high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET |
| 43 | + torch_dtype = None |
44 | 44 |
|
45 | 45 | def __init__(self, *args, **kwargs): |
46 | 46 | """__init__ Method for torch_tensorrt.Input |
@@ -138,6 +138,9 @@ def __init__(self, *args, **kwargs): |
138 | 138 | ) |
139 | 139 |
|
140 | 140 | if "dtype" in kwargs: |
| 141 | + if isinstance(kwargs["dtype"], torch.dtype): |
| 142 | + self.torch_dtype = kwargs["dtype"] |
| 143 | + |
141 | 144 | self.dtype = Input._parse_dtype(kwargs["dtype"]) |
142 | 145 | self._explicit_set_dtype = True |
143 | 146 |
|
@@ -173,59 +176,6 @@ def __str__(self) -> str: |
173 | 176 | else: |
174 | 177 | raise RuntimeError("Unknown input shape mode") |
175 | 178 |
|
176 | | - def _to_internal(self) -> _C.Input: |
177 | | - internal_in = _C.Input() |
178 | | - if self.shape_mode == Input._ShapeMode.DYNAMIC: |
179 | | - if not Input._supported_input_size_type(self.shape["min_shape"]): |
180 | | - raise TypeError( |
181 | | - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " |
182 | | - + str(type(self.shape["min_shape"])) |
183 | | - + " for min_shape" |
184 | | - ) |
185 | | - else: |
186 | | - internal_in.min = self.shape["min_shape"] |
187 | | - |
188 | | - if not Input._supported_input_size_type(self.shape["opt_shape"]): |
189 | | - raise TypeError( |
190 | | - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " |
191 | | - + str(type(self.shape["opt_shape"])) |
192 | | - + " for opt_shape" |
193 | | - ) |
194 | | - else: |
195 | | - internal_in.opt = self.shape["opt_shape"] |
196 | | - |
197 | | - if not Input._supported_input_size_type(self.shape["max_shape"]): |
198 | | - raise TypeError( |
199 | | - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " |
200 | | - + str(type(self.shape["max_shape"])) |
201 | | - + " for max_shape" |
202 | | - ) |
203 | | - else: |
204 | | - internal_in.max = self.shape["max_shape"] |
205 | | - internal_in.input_is_dynamic = True |
206 | | - else: |
207 | | - if not Input._supported_input_size_type(self.shape): |
208 | | - raise TypeError( |
209 | | - "Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " |
210 | | - + str(type(self.shape)) |
211 | | - + " for shape" |
212 | | - ) |
213 | | - else: |
214 | | - internal_in.opt = self.shape |
215 | | - internal_in.input_is_dynamic = False |
216 | | - |
217 | | - if self.dtype != _enums.dtype.unknown: |
218 | | - self._explicit_set_dtype = True |
219 | | - else: |
220 | | - self._explicit_set_dtype = False |
221 | | - |
222 | | - internal_in.dtype = Input._parse_dtype(self.dtype) |
223 | | - internal_in._explicit_set_dtype = self._explicit_set_dtype |
224 | | - internal_in.format = Input._parse_format(self.format) |
225 | | - |
226 | | - internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain) |
227 | | - return internal_in |
228 | | - |
229 | 179 | @staticmethod |
230 | 180 | def _supported_input_size_type(input_size: Any) -> bool: |
231 | 181 | if isinstance(input_size, torch.Size): |
@@ -304,6 +254,7 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple: |
304 | 254 | Input.low_tensor_domain_incl, |
305 | 255 | Input.high_tensor_domain_excl, |
306 | 256 | ) |
| 257 | + |
307 | 258 | elif len(domain) == 2: |
308 | 259 | domain_lo, domain_hi = domain |
309 | 260 |
|
@@ -416,8 +367,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor |
416 | 367 | ) |
417 | 368 |
|
418 | 369 | if self.shape_mode == Input._ShapeMode.STATIC: |
419 | | - return torch.randn(self.shape).to(dtype=self.dtype) |
| 370 | + return torch.randn(self.shape).to( |
| 371 | + dtype=self.dtype if not self.torch_dtype else self.torch_dtype |
| 372 | + ) |
420 | 373 | else: |
421 | 374 | return torch.randn(self.shape[optimization_profile_field]).to( |
422 | | - dtype=self.dtype |
| 375 | + dtype=self.dtype if not self.torch_dtype else self.torch_dtype |
423 | 376 | ) |
0 commit comments