From 04ccd061e9e6006649adf06264c17b7e9dba73ad Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 16 Jan 2024 11:04:45 -0800 Subject: [PATCH 1/6] Fix upsample_bilinear to respect align_corner argument --- onnxscript/function_libs/torch_lib/ops/nn.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index e9ee699aa3..d47ebb2a12 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2302,12 +2302,17 @@ def aten_upsample_bilinear2d( ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + coordinate_transformation_mode = "align_corners" if align_corners else "pytorch_half_pixel" if output_size is not None: - result = _aten_upsample_bilinear2d_output_size(self, output_size) + result = _aten_upsample_bilinear2d_output_size( + self, output_size, coordinate_transformation_mode + ) else: assert scales_h is not None assert scales_h == scales_w, f"scale_h({scales_h}) != scale_w({scales_w})" - result = _aten_upsample_bilinear2d_scales(self, scales_h, scales_w) + result = _aten_upsample_bilinear2d_scales( + self, scales_h, scales_w, coordinate_transformation_mode + ) return result @@ -2327,6 +2332,7 @@ def aten_upsample_bilinear2d_vec( def _aten_upsample_bilinear2d_output_size( self: TReal, output_size: INT64, + coordinate_transformation_mode: str, ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" @@ -2341,7 +2347,8 @@ def _aten_upsample_bilinear2d_output_size( None, output_size, mode="linear", - coordinate_transformation_mode="align_corners", + coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", ) @@ -2350,6 +2357,7 @@ def _aten_upsample_bilinear2d_scales( self: TReal, scales_h: float, scales_w: float, + coordinate_transformation_mode: str, ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" @@ -2366,7 +2374,8 @@ def _aten_upsample_bilinear2d_scales( scales, # format should be: [1.0, 1.0, scale_h, scale_w] None, mode="linear", - coordinate_transformation_mode="align_corners", + coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", ) From 60c9b97ceec30ac45a592eac1185d9f01b9f48b2 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 16 Jan 2024 11:18:33 -0800 Subject: [PATCH 2/6] Remove 'unused-argument' tag --- onnxscript/function_libs/torch_lib/ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index d47ebb2a12..37f8020330 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2298,7 +2298,7 @@ def aten_upsample_bilinear2d( output_size: Optional[INT64] = None, scales_h: Optional[float] = None, scales_w: Optional[float] = None, - align_corners: bool = True, # pylint: disable=unused-argument + align_corners: bool = True, ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" From 2cb42450627740e215b81ec4ffbdb5a16f89b411 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 16 Jan 2024 11:28:23 -0800 Subject: [PATCH 3/6] Fix align_corner default value --- onnxscript/function_libs/torch_lib/ops/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 37f8020330..af7c9aff66 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2298,7 +2298,7 @@ def aten_upsample_bilinear2d( output_size: Optional[INT64] = None, scales_h: Optional[float] = None, scales_w: Optional[float] = None, - align_corners: bool = True, + align_corners: bool = False, ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" @@ -2320,7 +2320,7 @@ def aten_upsample_bilinear2d( def aten_upsample_bilinear2d_vec( self: TReal, output_size: Optional[INT64] = None, - align_corners: bool = True, + align_corners: bool = False, scale_factors: Optional[Sequence[float]] = None, ) -> TReal: scales_h = scale_factors[0] if scale_factors is not None else None From 7d14542ed331a9d67a727a6b44fcb73de59d8d78 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 16 Jan 2024 12:24:24 -0800 Subject: [PATCH 4/6] Fix signature to match aten spec; Fix wrangler to bridge from source function torch.nn.functional.upsample_bilinear --- onnxscript/function_libs/torch_lib/ops/nn.py | 13 ++++---- .../function_libs/torch_lib/ops_test_data.py | 30 +++++++++++++++++-- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index af7c9aff66..59421625a3 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2295,10 +2295,10 @@ def aten_upsample_bicubic2d_backward( @torch_op("aten::upsample_bilinear2d", trace_only=True) def aten_upsample_bilinear2d( self: TReal, - output_size: Optional[INT64] = None, + output_size: Optional[INT64], + align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None, - align_corners: bool = False, ) -> TReal: """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" @@ -2319,13 +2319,14 @@ def aten_upsample_bilinear2d( @torch_op("aten::upsample_bilinear2d.vec", trace_only=True) def aten_upsample_bilinear2d_vec( self: TReal, - output_size: Optional[INT64] = None, - align_corners: bool = False, - scale_factors: Optional[Sequence[float]] = None, + output_size: Optional[INT64], + align_corners: bool, + scale_factors: Optional[Sequence[float]], ) -> TReal: + """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" scales_h = scale_factors[0] if scale_factors is not None else None scales_w = scale_factors[1] if scale_factors is not None else None - return aten_upsample_bilinear2d(self, output_size, scales_h, scales_w, align_corners) + return aten_upsample_bilinear2d(self, output_size, align_corners, scales_h, scales_w) @torch_op("aten::upsample_bilinear2d", private=True) diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index cabc13268a..affc2997dc 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -417,9 +417,21 @@ def _sum_input_wrangler( def _upsample_bilinear2d_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: + # Wrangler for the signature difference between + # 'nn.functional.upsample_bilinear' + # and + # 'aten::upsample_bilinear2d' + # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html if "size" in kwargs: args.append(np.array(kwargs["size"], dtype=np.int64)) del kwargs["size"] # promote tensor type kwargs to args + else: + args.append(None) + if "align_corners" in kwargs: + args.append(kwargs["align_corners"]) + del kwargs["align_corners"] + else: + args.append(True) # Fill in the default value if "scale_factor" in kwargs: kwargs["scales_h"] = kwargs["scale_factor"] kwargs["scales_w"] = kwargs["scale_factor"] @@ -430,12 +442,26 @@ def _upsample_bilinear2d_input_wrangler( def _upsample_bilinear2d_vec_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: + # Wrangler for the signature difference between + # 'nn.functional.upsample_bilinear' + # and + # 'aten::upsample_bilinear2d.vec' + # https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html if "size" in kwargs: args.append(np.array(kwargs["size"], dtype=np.int64)) del kwargs["size"] # promote tensor type kwargs to args + else: + args.append(None) + if "align_corners" in kwargs: + args.append(kwargs["align_corners"]) + del kwargs["align_corners"] + else: + args.append(True) # Fill in the default value if "scale_factor" in kwargs: - kwargs["scale_factors"] = [kwargs["scale_factor"]] * 2 - del kwargs["scale_factor"] # adapt the function signature + args.append([kwargs["scale_factor"]] * 2) + del kwargs["scale_factor"] # promote tensor type kwargs to args + else: + args.append(None) return args, kwargs From 0ba5f40bc0f9b30fd954e4bfa01d31710c900fa1 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 16 Jan 2024 12:31:40 -0800 Subject: [PATCH 5/6] minor fix on signature comment --- onnxscript/function_libs/torch_lib/ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 59421625a3..efd1fc80d3 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2323,7 +2323,7 @@ def aten_upsample_bilinear2d_vec( align_corners: bool, scale_factors: Optional[Sequence[float]], ) -> TReal: - """upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor""" + """upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor""" scales_h = scale_factors[0] if scale_factors is not None else None scales_w = scale_factors[1] if scale_factors is not None else None return aten_upsample_bilinear2d(self, output_size, align_corners, scales_h, scales_w) From ab8ec04e43c2acff00a80d9d3e777a9e556b2125 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Tue, 16 Jan 2024 13:56:09 -0800 Subject: [PATCH 6/6] Update onnxscript/function_libs/torch_lib/ops/nn.py Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index efd1fc80d3..9a105482c5 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2295,7 +2295,7 @@ def aten_upsample_bicubic2d_backward( @torch_op("aten::upsample_bilinear2d", trace_only=True) def aten_upsample_bilinear2d( self: TReal, - output_size: Optional[INT64], + output_size: INT64, align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None,