
Commit d6a14d9

bowang007 authored and gs-olive committed
support argmax converter
Signed-off-by: Bo Wang <[email protected]>
1 parent 0e4c5d8 commit d6a14d9

File tree

3 files changed: +67 −0 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 19 additions & 0 deletions
@@ -1722,3 +1722,22 @@ def aten_ops_reshape(
         input=args[0],
         shape=args[1],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)  # type: ignore[misc]
+def aten_ops_argmax(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.argmax.argmax(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        dim=args_bounds_check(args, 1),
+        keep_dim=args_bounds_check(args, 2, False),
+    )
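For context, args_bounds_check guards the optional positional arguments of aten.argmax (dim at index 1, keep_dim at index 2), which may be absent on the FX node. A minimal sketch of the assumed behavior, not necessarily the exact helper in the repository:

    from typing import Any, Optional, Sequence

    def args_bounds_check(
        args: Sequence[Any], i: int, replacement: Optional[Any] = None
    ) -> Any:
        # Return args[i] when present, otherwise the supplied default, so
        # converters can treat trailing schema arguments as optional.
        return args[i] if len(args) > i else replacement

    assert args_bounds_check((1,), 1) is None
    assert args_bounds_check((1, 2, False), 2, True) is False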

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from . import (
     activation,
     attention,
+    argmax,
     cast,
     cat,
     condition,
py/torch_tensorrt/dynamo/conversion/impl/argmax.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+from typing import Optional
+
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    cast_trt_tensor,
+    get_axes_for_reduce_op,
+)
+from torch_tensorrt.fx.converters.converter_utils import (
+    get_positive_dim,
+    set_layer_name,
+)
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from . import squeeze
+
+
+def argmax(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int = 0,
+    keep_dim: bool = False,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"argmax received input {input} that is not part " "of the TensorRT region!"
+        )
+    if input.dtype == trt.int32:
+        input = cast_trt_tensor(network, input, trt.float32, name)
+    if dim < 0:
+        dim = len(tuple(input.shape)) + dim
+    reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
+    topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
+    set_layer_name(topk_layer, target, name)
+
+    out = topk_layer.get_output(1)
+
+    if not keep_dim:
+        out = squeeze.squeeze(
+            network, target, SourceIR.ATEN, name + "_squeeze", out, dim
+        )
+
+    return out
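The conversion mirrors the eager semantics below: a TopK layer with k=1 over the reduced dimension produces a values output and an indices output, the converter keeps the indices (output index 1), and the size-1 dimension is squeezed away unless keep_dim is set. A hedged eager-PyTorch reference for illustration only, not part of the commit:

    import torch

    def argmax_reference(
        x: torch.Tensor, dim: int = 0, keep_dim: bool = False
    ) -> torch.Tensor:
        # topk with k=1 returns (values, indices); argmax only needs the indices.
        _, indices = torch.topk(x, k=1, dim=dim)
        if not keep_dim:
            # Drop the size-1 dimension left behind by k=1, matching keep_dim=False.
            indices = indices.squeeze(dim)
        return indices

    x = torch.tensor([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]])
    assert torch.equal(argmax_reference(x, dim=1), torch.argmax(x, dim=1))
    assert torch.equal(
        argmax_reference(x, dim=1, keep_dim=True), torch.argmax(x, dim=1, keepdim=True)
    )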

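With the converter registered, torch.argmax in a model compiled through the Dynamo path should lower to this TensorRT implementation. A hedged usage sketch (requires a CUDA device and an installed torch_tensorrt; min_block_size=1 is set only so the single-op graph is not skipped by partitioning, and the exact compile options are assumptions, not taken from this commit):

    import torch
    import torch_tensorrt

    class ArgMaxModule(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.argmax(x, dim=1)

    model = ArgMaxModule().eval().cuda()
    inputs = [torch.randn(8, 10, device="cuda")]

    # Compile through the Dynamo frontend so aten.argmax.default reaches
    # the converter added in this commit.
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
    print(trt_model(*inputs))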