Skip to content

Commit e331363

Browse files
committed
feat: Add _to_copy, operator.get and clone
- Add ATen converters for key operators in the pipeline of multiple models - Add robust testing and patch issues in interpreter - Add evaluator and casting utilities to the converter utils
1 parent ce06f6e commit e331363

File tree

10 files changed

+335
-14
lines changed

10 files changed

+335
-14
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import operator
23
from typing import Dict, Sequence, Tuple, Union
34
import torch
45
import tensorrt as trt
@@ -9,8 +10,10 @@
910
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1011
from torch_tensorrt.dynamo._SourceIR import SourceIR
1112
from torch_tensorrt.dynamo.conversion import impl
12-
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
13-
from torch_tensorrt.dynamo.conversion.converter_utils import cast_int_int_div_trt_tensor
13+
from torch_tensorrt.dynamo.conversion.converter_utils import (
14+
cast_trt_tensor,
15+
cast_int_int_div_trt_tensor,
16+
)
1417

1518
_LOGGER: logging.Logger = logging.getLogger(__name__)
1619

@@ -70,13 +73,13 @@ def aten_ops_div(
7073
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
7174
):
7275
kwargs_new["input"] = cast_trt_tensor(
73-
network, kwargs_new["input"], trt.float32, name
76+
network, kwargs_new["input"], trt.float32, name, target
7477
)
7578
elif isinstance(args[1], TRTTensor) and (
7679
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
7780
):
7881
kwargs_new["other"] = cast_trt_tensor(
79-
network, kwargs_new["other"], trt.float32, name
82+
network, kwargs_new["other"], trt.float32, name, target
8083
)
8184
rounding_mode = kwargs.get("rounding_mode")
8285
if rounding_mode is None:
@@ -377,3 +380,77 @@ def aten_ops_permute(
377380
args[0],
378381
args[1],
379382
)
383+
384+
385+
def to_copy_dtype_validator(to_copy_node: Node):
386+
allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
387+
388+
# Validate input node has convertible kwargs
389+
if "dtype" in to_copy_node.kwargs:
390+
if to_copy_node.kwargs["dtype"] in allowed_casts:
391+
return True
392+
else:
393+
_LOGGER.debug(
394+
f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
395+
)
396+
return False
397+
else:
398+
_LOGGER.debug(
399+
f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
400+
)
401+
return False
402+
403+
404+
@dynamo_tensorrt_converter(
405+
torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
406+
)
407+
def aten_ops_to_copy_dtype(
408+
network: TRTNetwork,
409+
target: Target,
410+
args: Tuple[Argument, ...],
411+
kwargs: Dict[str, Argument],
412+
name: str,
413+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
414+
return impl.cast.to_copy(
415+
network,
416+
target,
417+
SourceIR.ATEN,
418+
name,
419+
args[0],
420+
kwargs["dtype"],
421+
)
422+
423+
424+
@dynamo_tensorrt_converter(operator.getitem)
425+
def operator_getitem(
426+
network: TRTNetwork,
427+
target: Target,
428+
args: Tuple[Argument, ...],
429+
kwargs: Dict[str, Argument],
430+
name: str,
431+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
432+
return impl.evaluators.getitem(
433+
network,
434+
target,
435+
SourceIR.ATEN,
436+
name,
437+
args[0],
438+
args[1],
439+
)
440+
441+
442+
@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
443+
def aten_ops_clone(
444+
network: TRTNetwork,
445+
target: Target,
446+
args: Tuple[Argument, ...],
447+
kwargs: Dict[str, Argument],
448+
name: str,
449+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
450+
return impl.evaluators.clone(
451+
network,
452+
target,
453+
SourceIR.ATEN,
454+
name,
455+
args[0],
456+
)

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import torch
22

3+
from torch.fx.node import _get_qualified_name, Target
4+
35
from torch_tensorrt.fx.types import (
46
TRTDataType,
57
TRTNetwork,
@@ -12,7 +14,9 @@
1214
)
1315

1416
import tensorrt as trt
15-
from typing import List
17+
from typing import List, Optional
18+
19+
from .._SourceIR import SourceIR
1620

1721

1822
def dynamic_unsupported(node: torch.fx.Node) -> bool:
@@ -49,24 +53,35 @@ def cast_trt_tensor(
4953
input_val: TRTTensor,
5054
dtype: TRTDataType,
5155
name: str,
56+
target: Target = "",
57+
source_ir: Optional[SourceIR] = None,
5258
) -> TRTTensor:
5359
"""
5460
Given a TRT Tensor, convert that Tensor to the specified dtype
5561
Adds an Identity layer to the network which performs the conversion
5662
Args:
5763
network (TRTNetwork): A TensorRT network
5864
input_val (TRTTensor): A TRT Tensor to cast to a new data type
59-
dtype (TRTDataType): The TRTDataType to cast the input Tensor to
65+
dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to
6066
name (str): Name of the calling layer
67+
target (Target): Target of calling node
68+
source_ir (SourceIR): SourceIR of calling converter
6169
Returns:
6270
A TensorRT ITensor which has been casted to the specified dtype
6371
"""
6472
trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)
6573

6674
if input_val.dtype != trt_dtype:
75+
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
76+
target_name = (
77+
f"{source_ir}_ops{'.' + target if target else ''}"
78+
if (isinstance(target, str))
79+
else f"{source_ir}_ops.{_get_qualified_name(target)}"
80+
)
81+
6782
identity_layer = network.add_identity(input_val)
6883
identity_layer.set_output_type(0, trt_dtype)
69-
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}"
84+
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} -{name}-[{target_name}]-[{name}]"
7085
return identity_layer.get_output(0)
7186
else:
7287
return input_val

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212
from . import squeeze
1313
from . import unsqueeze
1414
from . import permutation
15+
from . import cast
16+
from . import evaluators
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Optional
2+
from torch.fx.node import Target
3+
4+
from torch_tensorrt.dynamo._SourceIR import SourceIR
5+
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
6+
7+
from torch_tensorrt.fx.types import (
8+
TRTNetwork,
9+
TRTTensor,
10+
TRTDataType,
11+
)
12+
13+
14+
def to_copy(
15+
network: TRTNetwork,
16+
target: Target,
17+
source_ir: Optional[SourceIR],
18+
name: str,
19+
input: TRTTensor,
20+
dtype: TRTDataType,
21+
) -> TRTTensor:
22+
if not isinstance(input, TRTTensor):
23+
raise RuntimeError(
24+
f"to_copy received input {input} that is not a TensorRT ITensor"
25+
)
26+
27+
casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
28+
return casted_tensor

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,13 @@ def convert_binary_elementwise(
137137
trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)
138138

139139
if trt_promoted_type != lhs_val.dtype:
140-
lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
140+
lhs_val = cast_trt_tensor(
141+
network, lhs_val, trt_promoted_type, name, target, source_ir
142+
)
141143
if trt_promoted_type != rhs_val.dtype:
142-
rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
144+
rhs_val = cast_trt_tensor(
145+
network, rhs_val, trt_promoted_type, name, target, source_ir
146+
)
143147

144148
# Check the limitation in the doc string.
145149
if network.has_implicit_batch_dimension:
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import operator
2+
import logging
3+
from typing import Optional, Sequence
4+
from torch.fx.node import Target
5+
6+
from torch_tensorrt.dynamo._SourceIR import SourceIR
7+
8+
from torch_tensorrt.fx.types import (
9+
TRTNetwork,
10+
TRTTensor,
11+
)
12+
13+
14+
LOGGER: logging.Logger = logging.getLogger(__name__)
15+
16+
17+
def getitem(
18+
network: TRTNetwork,
19+
target: Target,
20+
source_ir: Optional[SourceIR],
21+
name: str,
22+
input: Sequence[TRTTensor],
23+
index: int,
24+
) -> TRTTensor:
25+
LOGGER.debug(f"Evaluating getitem on object with name: {name}")
26+
27+
# Directly index the input sequence and return the value
28+
return operator.getitem(input, index)
29+
30+
31+
def clone(
32+
network: TRTNetwork,
33+
target: Target,
34+
source_ir: Optional[SourceIR],
35+
name: str,
36+
input: TRTTensor,
37+
) -> TRTTensor:
38+
if not isinstance(input, TRTTensor):
39+
raise RuntimeError(
40+
f"clone received input {input} that is not a TensorRT ITensor"
41+
)
42+
43+
LOGGER.debug(f"Evaluating clone on object with name: {name}")
44+
45+
return input

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
3030

3131

32+
class UnsupportedOperatorException(RuntimeError):
33+
pass
34+
35+
3236
class TRTInterpreterResult(NamedTuple):
3337
engine: Any
3438
input_names: Sequence[str]
@@ -288,7 +292,7 @@ def call_module(self, target, args, kwargs):
288292
converter = CONVERTERS.get(self._cur_node)
289293

290294
if not converter:
291-
raise RuntimeError(
295+
raise UnsupportedOperatorException(
292296
f"Conversion of module of type {submod_type} not currently supported!"
293297
)
294298

@@ -298,7 +302,7 @@ def call_module(self, target, args, kwargs):
298302
def call_function(self, target, args, kwargs):
299303
converter = CONVERTERS.get(self._cur_node)
300304
if not converter:
301-
raise RuntimeError(
305+
raise UnsupportedOperatorException(
302306
f"Conversion of function {torch.typename(target)} not currently supported!"
303307
)
304308

@@ -310,7 +314,7 @@ def call_method(self, target, args, kwargs):
310314
converter = CONVERTERS.get(self._cur_node)
311315

312316
if not converter:
313-
raise RuntimeError(
317+
raise UnsupportedOperatorException(
314318
f"Conversion of method {target} not currently supported!"
315319
)
316320

py/torch_tensorrt/dynamo/test_utils.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def generate_graph(
217217
expected_ops: Set[Callable],
218218
unexpected_ops: Optional[Set[Callable]] = None,
219219
customized_passes: List[Callable] = None,
220+
disable_passes: bool = False,
220221
):
221222
# Torchdynamo+aot proxytensor tracer
222223
# Below are common passes
@@ -234,6 +235,10 @@ def generate_graph(
234235
# Combine with customized passes specific to any model
235236
if customized_passes:
236237
passes_list.extend(customized_passes)
238+
239+
if disable_passes:
240+
passes_list = []
241+
237242
fx_module, _ = aten_tracer.trace(mod, original_inputs)
238243
for passes in passes_list:
239244
pr: PassResult = passes(fx_module)
@@ -261,9 +266,17 @@ def run_test(
261266
atol=1e-03,
262267
precision=torch.float,
263268
check_dtype=True,
269+
disable_passes=False,
264270
):
265271
mod.eval()
266-
mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
272+
mod = self.generate_graph(
273+
mod,
274+
inputs,
275+
expected_ops,
276+
unexpected_ops,
277+
None,
278+
disable_passes=disable_passes,
279+
)
267280

268281
if apply_passes is not None:
269282
pass_tracer = chain_passes(*apply_passes)
@@ -293,10 +306,18 @@ def run_test_with_dynamic_shape(
293306
unexpected_ops=None,
294307
rtol=1e-03,
295308
atol=1e-03,
309+
disable_passes=False,
296310
):
297311
mod.eval()
298312
inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
299-
mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
313+
mod = self.generate_graph(
314+
mod,
315+
inputs,
316+
expected_ops,
317+
unexpected_ops,
318+
None,
319+
disable_passes=disable_passes,
320+
)
300321

301322
interp = TRTInterpreter(
302323
mod,

0 commit comments

Comments
 (0)