Commit 768417a (parent: 6f345cf)

fix: Replay all FX changes in Dynamo

- Add multiple fixes to make FX changes appear in the Dynamo directory, using the Dynamo registry
- All converters with open PRs are linked and shown
- Update references, imports, code, merges, and rebases accordingly
- Add new test cases to Dynamo for the converters

58 files changed: +4829 −0 lines
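The "Dynamo registry" the commit message refers to is the decorator-based registration visible throughout this diff: each converter function is decorated with its aten target, so importing the converters module is what populates the lookup table. A minimal sketch of that mechanism, using a stand-in dict registry (the real DYNAMO_CONVERTERS internals are not part of this diff):

import torch
from torch.fx.node import Target

# Stand-in registry for illustration only; the actual DYNAMO_CONVERTERS
# object in torch_tensorrt may be structured differently.
DYNAMO_CONVERTERS = {}

def dynamo_tensorrt_converter(key: Target):
    """Register a converter function under its aten target."""
    def register(converter):
        DYNAMO_CONVERTERS[key] = converter
        return converter
    return register

@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(network, target, args, kwargs, name):
    ...  # would emit the corresponding TensorRT layer here

assert torch.ops.aten.relu.default in DYNAMO_CONVERTERS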

.circleci/config.yml (17 additions, 0 deletions)

@@ -797,6 +797,22 @@ commands:
       - store_artifacts:
           path: /tmp/testlogs
 
+  test-dynamo_converters:
+    description: "Test the Dynamo aten converters"
+    steps:
+      - run:
+          name: Run Dynamo converter tests
+          command: |
+            set -e
+            cd py/torch_tensorrt/dynamo/converters/test
+            TESTS_TO_RUN=$(circleci tests glob "test_*.py" | circleci tests split --split-by=timings)
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/test_results.xml $TESTS_TO_RUN
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
 # =================== Dynamo tests end ======================== #
 
 # Define a job to be invoked later in a workflow.

@@ -1056,6 +1072,7 @@ jobs:
       - test-dynamo-compile
       - test-dynamo-compile-core
       - test-dynamo-fx_ts
+      - test-dynamo_converters
 
   package-x86_64-linux:
     parameters:
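The test files this step globs (test_*.py under py/torch_tensorrt/dynamo/converters/test) are hidden in this large-commit view. A self-contained sketch of the shape such a converter test could take follows; the harness call is left as a comment because the project's actual test base class is not visible here:

import unittest

import torch


class TestReluConverter(unittest.TestCase):
    def test_relu(self):
        class Relu(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        inputs = [torch.randn(1, 10)]
        # With the project's dispatch-test harness this would look roughly like:
        #   self.run_test(Relu(), inputs, expected_ops={torch.ops.aten.relu.default})
        # That harness is not shown in this commit view, so this sketch only
        # sanity-checks eager behavior and therefore runs standalone.
        self.assertTrue(torch.equal(Relu()(inputs[0]), torch.relu(inputs[0])))


if __name__ == "__main__":
    unittest.main()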

py/torch_tensorrt/dynamo/__init__.py (1 addition, 0 deletions)

@@ -2,6 +2,7 @@
     DYNAMO_CONVERTERS,
     dynamo_tensorrt_converter,
 )
+from .converters import *
 
 from torch_tensorrt.dynamo import fx_ts_compat
 from .backend import compile
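Re-exporting the converters package here matters because registration happens as an import side effect: until from .converters import * runs, the registry is empty. A hedged sketch of a lookup after import, assuming mapping-style access (the concrete type of DYNAMO_CONVERTERS is not shown in this diff):

import torch
import torch_tensorrt.dynamo  # the import itself runs "from .converters import *"

# Assumption: DYNAMO_CONVERTERS supports mapping-style lookup keyed by the
# aten target; only the name, not the type, appears in this diff.
converter = torch_tensorrt.dynamo.DYNAMO_CONVERTERS[torch.ops.aten.relu.default]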

py/torch_tensorrt/dynamo/backend/lowering/_decompositions.py (10 additions, 0 deletions)

@@ -51,6 +51,16 @@ def rsqrt_replacement(*args, **kwargs) -> torch.Tensor:
     return torch.reciprocal(torch.sqrt(*args, **kwargs))
 
 
+@register_decomposition(aten._unsafe_view, registry=DECOMPOSITIONS)
+def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+    return torch.reshape(x, *args, **kwargs)
+
+
+@register_decomposition(torch.ops.aten.lift_fresh_copy, registry=DECOMPOSITIONS)
+def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor:
+    return x
+
+
 @register_decomposition(aten.alias, registry=DECOMPOSITIONS)
 def alias_replacement(x: torch.Tensor) -> torch.Tensor:
     return x
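These decompositions rewrite ops that lack direct converters into ones that have them: aten._unsafe_view becomes a plain reshape, and lift_fresh_copy becomes the identity. A quick plain-PyTorch check of the reshape equivalence:

import torch

x = torch.randn(2, 3, 4)

# aten._unsafe_view is semantically a reshape; the decomposition replaces it
# with torch.reshape, which produces identical values
assert torch.equal(
    torch.ops.aten._unsafe_view(x, [6, 4]),
    torch.reshape(x, [6, 4]),
)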
py/torch_tensorrt/dynamo/converters/__init__.py (new file, 1 addition)

@@ -0,0 +1 @@
from .aten_ops_converters import *
py/torch_tensorrt/dynamo/converters/aten_ops_converters.py (new file, 303 additions)

@@ -0,0 +1,303 @@
import logging
from typing import Dict, Sequence, Tuple, Union
import torch
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.dynamo import dynamo_tensorrt_converter
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.dynamo.converters.converter_utils import SourceIR
from torch_tensorrt.dynamo.converters.impl.elementwise import (
    trunc_div,
    rsqrt,
    fmod,
    rsub,
    clamp,
)
from torch_tensorrt.dynamo.converters.impl.normalization import (
    batch_norm,
    layer_norm,
    softmax,
)
from torch_tensorrt.fx.converters.impl import activation
from torch_tensorrt.dynamo.converters.impl.squeeze import squeeze
from torch_tensorrt.dynamo.converters.impl.select import select
from torch_tensorrt.dynamo.converters.impl.slice import slice_op
from torch_tensorrt.dynamo.converters.impl.matmul import matrix_multiply
from torch_tensorrt.dynamo.converters.impl.condition import where
from torch_tensorrt.dynamo.converters.impl.unsqueeze import unsqueeze

_LOGGER: logging.Logger = logging.getLogger(__name__)


def or_none(args, i):
    return args[i] if len(args) > i else None


@dynamo_tensorrt_converter(torch.ops.aten.batch_norm)
def aten_ops_batch_norm(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return batch_norm(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        args[2],
        args[3],
        args[4],
        args[5],
        args[6],
        args[7],
    )


@dynamo_tensorrt_converter(torch.ops.aten.div.default)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
def aten_ops_div(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    kwargs_new = {
        "input": args[0],
        "other": args[1],
    }
    rounding_mode = kwargs.get("rounding_mode")
    if rounding_mode is None:
        return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
    elif rounding_mode == "floor":
        return acc_ops_converters.acc_ops_floor_div(
            network, target, None, kwargs_new, name
        )
    elif rounding_mode == "trunc":
        return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1])
    else:
        raise RuntimeError(
            f"Target {target} does not support rounding mode {rounding_mode}"
        )


@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar)
@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor)
def aten_ops_fmod(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return fmod(network, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.gelu.default)
def aten_ops_gelu(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return activation.gelu(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.matmul)
@dynamo_tensorrt_converter(torch.ops.aten.mm.default)
def aten_ops_matmul(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return matrix_multiply(network, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)
def aten_ops_layernorm(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return layer_norm(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        args[2],
        args[3],
        args[4],
    )


@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
def aten_ops_relu(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return activation.relu(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default)
def aten_ops_rsqrt(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return rsqrt(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
def aten_ops_squeeze(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default)
def aten_ops_unsqueeze(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return unsqueeze(network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1])


@dynamo_tensorrt_converter(torch.ops.aten.rsub.Tensor)
def aten_ops_rsub(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    alpha = None
    if "alpha" in kwargs:
        alpha = kwargs["alpha"]
    return rsub(network, target, SourceIR.ATEN, name, args[0], args[1], alpha)


@dynamo_tensorrt_converter(torch.ops.aten._softmax.default)
def aten_ops_softmax(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return softmax(network, target, SourceIR.ATEN, name, args[0], args[1])


@dynamo_tensorrt_converter(torch.ops.aten.where.self)
def aten_ops_where(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return where(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[1],
        args[2],
        args[0],
    )


@dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
def aten_ops_clamp(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return clamp.clamp(
        network,
        target,
        SourceIR.ACC,
        name,
        input_val=args[0],
        min_val=or_none(args, 1),
        max_val=or_none(args, 2),
    )


@dynamo_tensorrt_converter(torch.ops.aten.select.int)
def aten_ops_select(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return select(network, target, SourceIR.ATEN, name, args[0], args[1], args[2])


@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
def aten_ops_slice(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return slice_op(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
        args[2],
        args[3],
        args[4],
    )
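A note on the or_none helper above: some aten schemas, clamp in particular, have trailing optional arguments that may simply be absent from args, so the helper maps a missing position to None instead of raising IndexError. A standalone illustration:

def or_none(args, i):
    return args[i] if len(args) > i else None

# torch.clamp(x, min=0.0) can reach the converter as aten.clamp.default with
# only two positional args; the max position is absent rather than None
args = ("input_tensor", 0.0)
print(or_none(args, 1))  # 0.0  -> becomes min_val
print(or_none(args, 2))  # None -> max_val falls back to None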
