Skip to content

Commit 6a511f6

Browse files
committed
Fix incomplete upsample dynamo converter
1 parent 2e6b3a2 commit 6a511f6

File tree

4 files changed

+370
-95
lines changed

4 files changed

+370
-95
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 131 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2818,9 +2818,43 @@ def aten_ops_pad(
28182818
)
28192819

28202820

2821+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest1d.default)
28212822
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.default)
2823+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest3d.default)
2824+
@enforce_tensor_types(
2825+
{
2826+
0: (TRTTensor,),
2827+
}
2828+
)
2829+
def aten_ops_upsample_nearest_default(
2830+
ctx: ConversionContext,
2831+
target: Target,
2832+
args: Tuple[Argument, ...],
2833+
kwargs: Dict[str, Argument],
2834+
name: str,
2835+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2836+
return impl.upsample.upsample(
2837+
ctx,
2838+
target,
2839+
SourceIR.ATEN,
2840+
name,
2841+
args[0],
2842+
size=args[1],
2843+
scale_factor=None,
2844+
mode="nearest",
2845+
align_corners=False,
2846+
)
2847+
2848+
2849+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest1d.vec)
28222850
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest2d.vec)
2823-
def upsample_nearest2d(
2851+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_nearest3d.vec)
2852+
@enforce_tensor_types(
2853+
{
2854+
0: (TRTTensor,),
2855+
}
2856+
)
2857+
def aten_ops_upsample_nearest_vec(
28242858
ctx: ConversionContext,
28252859
target: Target,
28262860
args: Tuple[Argument, ...],
@@ -2832,17 +2866,51 @@ def upsample_nearest2d(
28322866
target,
28332867
SourceIR.ATEN,
28342868
name,
2835-
input=args[0],
2836-
out_shape=args_bounds_check(args, 1),
2837-
scale_factors=args_bounds_check(args, 2),
2838-
resize_mode="nearest",
2869+
args[0],
2870+
size=args_bounds_check(args, 1),
2871+
scale_factor=args_bounds_check(args, 2),
2872+
mode="nearest",
28392873
align_corners=False,
28402874
)
28412875

28422876

2877+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_linear1d.default)
28432878
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.default)
2879+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_trilinear3d.default)
2880+
@enforce_tensor_types(
2881+
{
2882+
0: (TRTTensor,),
2883+
}
2884+
)
2885+
def aten_ops_upsample_linear_default(
2886+
ctx: ConversionContext,
2887+
target: Target,
2888+
args: Tuple[Argument, ...],
2889+
kwargs: Dict[str, Argument],
2890+
name: str,
2891+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2892+
return impl.upsample.upsample(
2893+
ctx,
2894+
target,
2895+
SourceIR.ATEN,
2896+
name,
2897+
args[0],
2898+
size=args[1],
2899+
scale_factor=None,
2900+
mode="linear",
2901+
align_corners=args[2],
2902+
)
2903+
2904+
2905+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_linear1d.vec)
28442906
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bilinear2d.vec)
2845-
def upsample_bilinear2d(
2907+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_trilinear3d.vec)
2908+
@enforce_tensor_types(
2909+
{
2910+
0: (TRTTensor,),
2911+
}
2912+
)
2913+
def aten_ops_upsample_linear_vec(
28462914
ctx: ConversionContext,
28472915
target: Target,
28482916
args: Tuple[Argument, ...],
@@ -2854,11 +2922,63 @@ def upsample_bilinear2d(
28542922
target,
28552923
SourceIR.ATEN,
28562924
name,
2857-
input=args[0],
2858-
out_shape=args_bounds_check(args, 1),
2859-
scale_factors=args_bounds_check(args, 3),
2860-
resize_mode="bilinear",
2861-
align_corners=args_bounds_check(args, 2),
2925+
args[0],
2926+
size=args_bounds_check(args, 1),
2927+
scale_factor=args_bounds_check(args, 3),
2928+
mode="linear",
2929+
align_corners=args[2],
2930+
)
2931+
2932+
2933+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bicubic2d.default)
2934+
@enforce_tensor_types(
2935+
{
2936+
0: (TRTTensor,),
2937+
}
2938+
)
2939+
def aten_ops_upsample_bicubic_default(
2940+
ctx: ConversionContext,
2941+
target: Target,
2942+
args: Tuple[Argument, ...],
2943+
kwargs: Dict[str, Argument],
2944+
name: str,
2945+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2946+
return impl.upsample.upsample(
2947+
ctx,
2948+
target,
2949+
SourceIR.ATEN,
2950+
name,
2951+
args[0],
2952+
size=args[1],
2953+
scale_factor=None,
2954+
mode="bicubic",
2955+
align_corners=args[2],
2956+
)
2957+
2958+
2959+
@dynamo_tensorrt_converter(torch.ops.aten.upsample_bicubic2d.vec)
2960+
@enforce_tensor_types(
2961+
{
2962+
0: (TRTTensor,),
2963+
}
2964+
)
2965+
def aten_ops_upsample_bicubic_vec(
2966+
ctx: ConversionContext,
2967+
target: Target,
2968+
args: Tuple[Argument, ...],
2969+
kwargs: Dict[str, Argument],
2970+
name: str,
2971+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2972+
return impl.upsample.upsample(
2973+
ctx,
2974+
target,
2975+
SourceIR.ATEN,
2976+
name,
2977+
args[0],
2978+
size=args_bounds_check(args, 1),
2979+
scale_factor=args_bounds_check(args, 3),
2980+
mode="bicubic",
2981+
align_corners=args[2],
28622982
)
28632983

28642984

py/torch_tensorrt/dynamo/conversion/impl/upsample.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,54 +14,35 @@ def upsample(
1414
source_ir: Optional[SourceIR],
1515
name: str,
1616
input: TRTTensor,
17-
out_shape: Optional[Sequence[int]],
18-
scale_factors: Optional[Sequence[float]],
19-
resize_mode: str,
17+
size: Optional[Sequence[int]],
18+
scale_factor: Optional[Sequence[float]],
19+
mode: str,
2020
align_corners: bool,
2121
) -> TRTTensor:
22-
resize_layer = ctx.net.add_resize(input)
23-
# output size calculation
24-
# Pytorch assumes that one of out_shape/scale_factor is None
25-
# Pytorch assumes that dimensions match for out_shape/scale factor
26-
if out_shape is not None:
27-
resize_layer.shape = list(input.shape)[:2] + list(out_shape)
28-
elif scale_factors is not None:
29-
resize_layer.scales = [1.0, 1.0] + list(scale_factors)
30-
else:
31-
raise RuntimeError(
32-
"At least one of out_shape and scale_factors should be specified."
33-
)
22+
layer = ctx.net.add_resize(input)
3423

35-
# interpolate mode
36-
if resize_mode == "nearest" or None:
37-
resize_layer.resize_mode = trt.InterpolationMode.NEAREST
38-
elif resize_mode == "bilinear":
39-
resize_layer.resize_mode = trt.InterpolationMode.LINEAR
40-
if align_corners is None or not align_corners:
41-
raise RuntimeError(
42-
f"Interpolation works differently is align_corners is False for {resize_mode} mode in PyTorch and TensorRT."
43-
)
24+
if size is not None:
25+
layer.shape = list(input.shape)[:2] + list(size)
4426
else:
45-
raise RuntimeError(
46-
f"Interpolation mode is {resize_mode} which is not supported by TensorRT."
47-
)
27+
layer.scales = [1.0, 1.0] + list(scale_factor)
4828

49-
if resize_mode == "nearest":
50-
resize_layer.coordinate_transformation = (
51-
trt.ResizeCoordinateTransformation.ASYMMETRIC
29+
if mode == "nearest":
30+
layer.resize_mode = trt.InterpolationMode.NEAREST
31+
layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ASYMMETRIC
32+
elif mode in ("linear", "bilinear", "trilinear"):
33+
layer.resize_mode = trt.InterpolationMode.LINEAR
34+
layer.coordinate_transformation = (
35+
trt.ResizeCoordinateTransformation.ALIGN_CORNERS
36+
if align_corners
37+
else trt.ResizeCoordinateTransformation.HALF_PIXEL
38+
)
39+
elif mode == "bicubic":
40+
layer.resize_mode = trt.InterpolationMode.CUBIC
41+
layer.coordinate_transformation = (
42+
trt.ResizeCoordinateTransformation.ALIGN_CORNERS
43+
if align_corners
44+
else trt.ResizeCoordinateTransformation.HALF_PIXEL
5245
)
53-
elif resize_mode == "bilinear":
54-
# align corners
55-
if align_corners is not None and align_corners:
56-
resize_layer.coordinate_transformation = (
57-
trt.ResizeCoordinateTransformation.ALIGN_CORNERS
58-
)
59-
else:
60-
resize_layer.coordinate_transformation = (
61-
trt.ResizeCoordinateTransformation.ASYMMETRIC
62-
)
63-
64-
set_layer_name(resize_layer, target, name, source_ir)
6546

66-
out = resize_layer.get_output(0)
67-
return out
47+
set_layer_name(layer, target, name, source_ir)
48+
return layer.get_output(0)

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,6 @@
152152
aten.unfold_backward,
153153
aten.unfold_copy,
154154
aten._unsafe_index,
155-
aten.upsample_bilinear2d,
156-
aten.upsample_bilinear2d.vec,
157155
aten.upsample_nearest2d_backward,
158156
aten.var,
159157
aten.var_mean,

0 commit comments

Comments
 (0)