
Commit de04072

feat: support aten.clamp.Tensor and update aten.clamp.default dynamo converters

1 parent: f0e6d2d

3 files changed: +40 −58 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 0 deletions

@@ -683,6 +683,7 @@ def aten_ops_where(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
+@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
 def aten_ops_clamp(
     ctx: ConversionContext,
     target: Target,
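For context, the two overloads differ only in how the bounds are typed: `aten.clamp.default` takes scalar `min`/`max`, while `aten.clamp.Tensor` accepts (broadcastable) tensor bounds. A minimal eager-mode sketch of the difference, independent of the converter itself:

```python
import torch

x = torch.tensor([-2.0, 0.0, 2.0])

# aten.clamp.default: scalar bounds.
print(torch.ops.aten.clamp.default(x, -1.0, 1.0))  # tensor([-1., 0., 1.])

# aten.clamp.Tensor: tensor bounds, broadcast against the input.
lo = torch.tensor([-1.0, -0.5, 0.0])
hi = torch.tensor([0.0, 0.5, 1.0])
print(torch.ops.aten.clamp.Tensor(x, lo, hi))  # tensor([-1., 0., 1.])
```

Registering both overloads on the same `aten_ops_clamp` function works because the rewritten `clamp` implementation below no longer distinguishes scalar from tensor bounds.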

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 14 additions & 57 deletions

@@ -1,6 +1,5 @@
 from typing import Optional, Union
 
-import numpy as np
 import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
@@ -17,7 +16,6 @@
 )
 from torch_tensorrt.dynamo.conversion.impl.unary import sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
@@ -186,63 +184,22 @@ def clamp(
     source_ir: Optional[SourceIR],
     name: str,
     input_val: TRTTensor,
-    min_val: Optional[float] = None,
-    max_val: Optional[float] = None,
+    min_val: Optional[Union[int, float, TRTTensor]] = None,
+    max_val: Optional[Union[int, float, TRTTensor]] = None,
 ) -> TRTTensor:
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Clamp received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def _add_layer(
-        ctx: ConversionContext,
-        input: TRTTensor,
-        val: float,
-        op: trt.ElementWiseOperation,
-        name: str,
-    ) -> (
-        trt.ILayer
-    ):  # TODO: Simplify and merge implementations, should just be max and min stacked
-        if not len(input.shape):
-            # clamping scalar
-            acc_ops_clamp_trt = get_trt_tensor(
-                ctx,
-                squeeze_left(
-                    np.array(
-                        [val],
-                        dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-                    )
-                ),
-                f"{name}_clamp_{val}",
-            )
-        else:
-            acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
-            acc_ops_clamp_tensor = np.full(
-                acc_ops_clamp_shape,
-                val,
-                dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-            )
-            acc_ops_clamp_trt = ctx.net.add_constant(
-                acc_ops_clamp_shape, acc_ops_clamp_tensor
-            ).get_output(0)
-        layer = ctx.net.add_elementwise(input, acc_ops_clamp_trt, op)
-        return layer
-
-    if min_val is not None:
-        clamp_min_layer = _add_layer(
-            ctx, input_val, min_val, trt.ElementWiseOperation.MAX, name
-        )
-        set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
-        input_val = clamp_min_layer.get_output(0)
-    if max_val is not None:
-        clamp_max_layer = _add_layer(
-            ctx, input_val, max_val, trt.ElementWiseOperation.MIN, name
-        )
-        set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
-        input_val = clamp_max_layer.get_output(0)
+    if min_val is None:
+        min_val = float("-inf")
+    if max_val is None:
+        max_val = float("inf")
 
-    return input_val
+    return impl.elementwise.min(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_min",
+        impl.elementwise.max(ctx, target, source_ir, f"{name}_max", input_val, min_val),
+        max_val,
+    )
 
 
 def add(
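The rewrite replaces the hand-rolled constant-layer plumbing with the identity clamp(x, lo, hi) = min(max(x, lo), hi), delegating broadcasting and scalar-vs-tensor handling to the existing `impl.elementwise.min`/`max` helpers. A minimal eager-mode sketch of the same idea (plain PyTorch, not the TensorRT code path):

```python
import torch

def clamp_via_min_max(x, min_val=None, max_val=None):
    # Missing bounds become +/-inf, so clamp reduces to one max followed
    # by one min; bounds may be scalars or broadcastable tensors.
    if min_val is None:
        min_val = float("-inf")
    if max_val is None:
        max_val = float("inf")
    lower = torch.maximum(x, torch.as_tensor(min_val, dtype=x.dtype))
    return torch.minimum(lower, torch.as_tensor(max_val, dtype=x.dtype))

x = torch.tensor([-2.0, 0.5, 3.0])
assert torch.equal(clamp_via_min_max(x, -1.0, 1.0), torch.clamp(x, -1.0, 1.0))
```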

tests/py/dynamo/conversion/test_clamp_aten.py

Lines changed: 25 additions & 1 deletion

@@ -49,7 +49,7 @@ def forward(self, x):
 
         class TestScalarModule(torch.nn.Module):
             def forward(self, x):
-                y = torch.ops.aten.mean.default(x)
+                y = torch.ops.aten.mean.dim(x, None, True)
                 return torch.ops.aten.clamp.default(y, min, max)
 
         input_specs = [
@@ -63,6 +63,30 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
         self.run_test_with_dynamic_shape(TestScalarModule(), input_specs)
 
+    @parameterized.expand(
+        [
+            param("default", min=-1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)),
+            param("min", min=0.5 * torch.randn(3, 4)),
+            param("max", max=0.5 * torch.randn(3, 4)),
+            param(
+                "minBiggerThanMax", min=1 * torch.randn(3, 4), max=0 * torch.randn(3, 4)
+            ),
+            param("float32Boundary", min=-3.4028234663852886e38 * torch.randn(3, 4)),
+        ]
+    )
+    def test_clamp_tensor(
+        self,
+        test_name,
+        min=None,
+        max=None,
+    ):
+        class TestModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.clamp.Tensor(x, min, max)
+
+        inputs = [torch.randn(3, 4)]
+        self.run_test(TestModule(), inputs)
+
 
 if __name__ == "__main__":
     run_tests()
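The `minBiggerThanMax` case is worth calling out: PyTorch defines clamp with min greater than max to saturate at max, and the min-after-max composition used by the converter reproduces that. A quick eager-mode check (the variable names are illustrative only):

```python
import torch

x = torch.zeros(3)
lo = torch.full((3,), 1.0)   # min bound
hi = torch.full((3,), -1.0)  # max bound, elementwise smaller than lo

# max(x, lo) = 1, then min(1, hi) = -1: the result saturates at the
# max bound, matching torch's documented clamp semantics.
print(torch.ops.aten.clamp.Tensor(x, lo, hi))  # tensor([-1., -1., -1.])
```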
