@@ -1,6 +1,5 @@
 from typing import Optional, Union

-import numpy as np
 import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
@@ -17,7 +16,6 @@
 )
 from torch_tensorrt.dynamo.conversion.impl.unary import sign
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name, squeeze_left
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

@@ -186,63 +184,22 @@ def clamp(
     source_ir: Optional[SourceIR],
     name: str,
     input_val: TRTTensor,
-    min_val: Optional[float] = None,
-    max_val: Optional[float] = None,
+    min_val: Optional[Union[int, float, TRTTensor]] = None,
+    max_val: Optional[Union[int, float, TRTTensor]] = None,
 ) -> TRTTensor:
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Clamp received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    def _add_layer(
-        ctx: ConversionContext,
-        input: TRTTensor,
-        val: float,
-        op: trt.ElementWiseOperation,
-        name: str,
-    ) -> (
-        trt.ILayer
-    ):  # TODO: Simplify and merge implementations, should just be max and min stacked
-        if not len(input.shape):
-            # clamping scalar
-            acc_ops_clamp_trt = get_trt_tensor(
-                ctx,
-                squeeze_left(
-                    np.array(
-                        [val],
-                        dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-                    )
-                ),
-                f"{name}_clamp_{val}",
-            )
-        else:
-            acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
-            acc_ops_clamp_tensor = np.full(
-                acc_ops_clamp_shape,
-                val,
-                dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY),
-            )
-            acc_ops_clamp_trt = ctx.net.add_constant(
-                acc_ops_clamp_shape, acc_ops_clamp_tensor
-            ).get_output(0)
-        layer = ctx.net.add_elementwise(input, acc_ops_clamp_trt, op)
-        return layer
-
-    if min_val is not None:
-        clamp_min_layer = _add_layer(
-            ctx, input_val, min_val, trt.ElementWiseOperation.MAX, name
-        )
-        set_layer_name(clamp_min_layer, target, f"{name}_clamp_min")
-        input_val = clamp_min_layer.get_output(0)
-    if max_val is not None:
-        clamp_max_layer = _add_layer(
-            ctx, input_val, max_val, trt.ElementWiseOperation.MIN, name
-        )
-        set_layer_name(clamp_max_layer, target, f"{name}_clamp_max")
-        input_val = clamp_max_layer.get_output(0)
+    if min_val is None:
+        min_val = float("-inf")
+    if max_val is None:
+        max_val = float("inf")

-    return input_val
+    return impl.elementwise.min(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_min",
+        impl.elementwise.max(ctx, target, source_ir, f"{name}_max", input_val, min_val),
+        max_val,
+    )


 def add(
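Note (illustration only, not part of the commit): the rewritten converter relies on the identity clamp(x, lo, hi) == min(max(x, lo), hi), and defaults a missing bound to -inf/+inf so the corresponding max/min becomes a no-op. A minimal PyTorch sketch of that identity:

# Hypothetical check, separate from the converter code above.
import torch

x = torch.tensor([-3.0, -0.5, 0.0, 0.5, 3.0])
lo, hi = -1.0, 1.0

# clamp(x, lo, hi) == min(max(x, lo), hi)
composed = torch.minimum(torch.maximum(x, torch.tensor(lo)), torch.tensor(hi))
assert torch.equal(composed, torch.clamp(x, min=lo, max=hi))

# A missing bound defaults to an infinity, which leaves the elementwise
# max (or min) unchanged, so the unbounded side passes through untouched.
assert torch.equal(torch.maximum(x, torch.tensor(float("-inf"))), x)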