Commit be15018 (1 parent: 0f8f23d)

Support aten._log_softmax dynamo converter

3 files changed: +48 −1 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
14 additions & 0 deletions
@@ -693,6 +693,20 @@ def aten_ops_softmax(
     )


+@dynamo_tensorrt_converter(
+    torch.ops.aten._log_softmax.default, supports_dynamic_shapes=True
+)
+def aten_ops_log_softmax(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    softmax = aten_ops_softmax(ctx, target, args, kwargs, name)
+    return impl.unary.log(ctx, target, SourceIR.ATEN, name, softmax)
+
+
 @dynamo_tensorrt_converter(
     torch.ops.aten.split.Tensor,
     capability_validator=has_static_shapes_in_args([1]),
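The new converter reuses the existing softmax converter and applies an elementwise log on top, relying on the identity log_softmax(x, dim) = log(softmax(x, dim)). A quick eager-mode sanity check of that identity (the tensor shape and dim here are arbitrary, not taken from the commit):

import torch

x = torch.randn(2, 3, 5)
# What the converter builds: softmax followed by an elementwise log.
composed = torch.log(torch.softmax(x, dim=1))
# The fused op the converter now handles (third arg is half_to_float=False).
fused = torch.ops.aten._log_softmax.default(x, 1, False)
assert torch.allclose(composed, fused, atol=1e-6)

The fused kernel is the numerically safer formulation for large-magnitude inputs; the composed form matches it only to floating-point tolerance.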

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
0 additions & 1 deletion

@@ -77,7 +77,6 @@
     aten.logit_backward,
     aten.log_sigmoid_backward,
     aten.log_sigmoid_forward,
-    aten._log_softmax,
     aten._log_softmax_backward_data,
     aten.logspace,
     aten.logsumexp.default,
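This deletion is what routes the op to the converter: as long as aten._log_softmax sits in the decomposition group, the dynamo lowering rewrites it into simpler aten ops before conversion, so the graph never contains a node for the new converter to match. A minimal sketch of exercising the path end to end (the module and shapes are illustrative, not from the commit):

import torch
import torch_tensorrt

class LogSoftmaxModule(torch.nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten._log_softmax.default in the dynamo graph.
        return torch.nn.functional.log_softmax(x, dim=1)

model = LogSoftmaxModule().eval().cuda()
inputs = [torch.randn(1, 3, 5, 7, device="cuda")]
# With the op no longer decomposed, the dynamo frontend hands it to
# aten_ops_log_softmax instead of lowering it to simpler aten primitives.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)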
New test file
34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import torch
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestLogSoftmaxConverter(DispatchTestCase):
+    def test_log_softmax(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._log_softmax.default(x, 1, False)
+
+        inputs = [torch.randn(1, 3, 5, 7)]
+        self.run_test(TestModule(), inputs)
+
+    def test_log_softmax_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten._log_softmax.default(x, 2, False)
+
+        input_specs = [
+            Input(
+                min_shape=(1, 1, 1, 1),
+                opt_shape=(2, 4, 6, 8),
+                max_shape=(8, 8, 8, 8),
+                dtype=torch.float32,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(TestModule(), input_specs)
+
+
+if __name__ == "__main__":
+    run_tests()
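Both tests go through the repo's DispatchTestCase harness, which builds a TensorRT engine for TestModule and checks its output against eager PyTorch. To run them locally, something along these lines should work, assuming the new file lands next to the other ATen converter tests (the diff above does not preserve its path, so this location is a guess):

python -m pytest tests/py/dynamo/conversion/test_log_softmax_aten.py -v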
