Skip to content

Commit 9e62066

Browse files
committed
resolve reviews
1 parent 0047b3d commit 9e62066

File tree

3 files changed

+10
-35
lines changed

3 files changed

+10
-35
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1395,5 +1395,5 @@ def aten_ops_argmax(
13951395
name,
13961396
input=args[0],
13971397
dim=args_bounds_check(args, 1),
1398-
keep_dim=args_bounds_check(args, 2),
1398+
keep_dim=args_bounds_check(args, 2, False),
13991399
)

py/torch_tensorrt/dynamo/conversion/impl/argmax.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
import tensorrt as trt
44
from torch.fx.node import Target
55
from torch_tensorrt.dynamo._SourceIR import SourceIR
6-
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
7-
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
6+
from torch_tensorrt.dynamo.conversion.converter_utils import (
7+
cast_trt_tensor,
8+
get_axes_for_reduce_op,
9+
)
10+
from torch_tensorrt.fx.converters.converter_utils import (
11+
get_positive_dim,
12+
set_layer_name,
13+
)
814
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
915

1016
from . import squeeze
@@ -27,7 +33,7 @@ def argmax(
2733
input = cast_trt_tensor(network, input, trt.float32, name)
2834
if dim < 0:
2935
dim = len(tuple(input.shape)) + dim
30-
reduce_mask = 1 << dim
36+
reduce_mask = get_axes_for_reduce_op(get_positive_dim(dim, len(input.shape)))
3137
topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
3238
set_layer_name(topk_layer, target, name)
3339

tests/py/dynamo/converters/test_argmax.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

0 commit comments

Comments
 (0)