Skip to content

Fix upsample_bilinear to respect align_corner argument #1254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2295,38 +2295,45 @@ def aten_upsample_bicubic2d_backward(
@torch_op("aten::upsample_bilinear2d", trace_only=True)
def aten_upsample_bilinear2d(
self: TReal,
output_size: Optional[INT64] = None,
output_size: INT64,
align_corners: bool,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
align_corners: bool = True, # pylint: disable=unused-argument
) -> TReal:
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""

coordinate_transformation_mode = "align_corners" if align_corners else "pytorch_half_pixel"
if output_size is not None:
result = _aten_upsample_bilinear2d_output_size(self, output_size)
result = _aten_upsample_bilinear2d_output_size(
self, output_size, coordinate_transformation_mode
)
else:
assert scales_h is not None
assert scales_h == scales_w, f"scale_h({scales_h}) != scale_w({scales_w})"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still valid?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unclear to me, looks like different scale values are not covered in tests.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines 2306 to 2312
Copy link
Collaborator

@justinchuby justinchuby Jan 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we reverse the check - omit output_size when the scales are not None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have link to spec whether which one takes precedence over the other?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But testing suggested otherwise

result = _aten_upsample_bilinear2d_scales(self, scales_h, scales_w)
result = _aten_upsample_bilinear2d_scales(
self, scales_h, scales_w, coordinate_transformation_mode
)
return result


@torch_op("aten::upsample_bilinear2d.vec", trace_only=True)
def aten_upsample_bilinear2d_vec(
self: TReal,
output_size: Optional[INT64] = None,
align_corners: bool = True,
scale_factors: Optional[Sequence[float]] = None,
output_size: Optional[INT64],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> TReal:
"""upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor"""
scales_h = scale_factors[0] if scale_factors is not None else None
scales_w = scale_factors[1] if scale_factors is not None else None
return aten_upsample_bilinear2d(self, output_size, scales_h, scales_w, align_corners)
return aten_upsample_bilinear2d(self, output_size, align_corners, scales_h, scales_w)


@torch_op("aten::upsample_bilinear2d", private=True)
def _aten_upsample_bilinear2d_output_size(
self: TReal,
output_size: INT64,
coordinate_transformation_mode: str,
) -> TReal:
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""

Expand All @@ -2341,7 +2348,8 @@ def _aten_upsample_bilinear2d_output_size(
None,
output_size,
mode="linear",
coordinate_transformation_mode="align_corners",
coordinate_transformation_mode=coordinate_transformation_mode,
nearest_mode="floor",
)


Expand All @@ -2350,6 +2358,7 @@ def _aten_upsample_bilinear2d_scales(
self: TReal,
scales_h: float,
scales_w: float,
coordinate_transformation_mode: str,
) -> TReal:
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""

Expand All @@ -2366,7 +2375,8 @@ def _aten_upsample_bilinear2d_scales(
scales, # format should be: [1.0, 1.0, scale_h, scale_w]
None,
mode="linear",
coordinate_transformation_mode="align_corners",
coordinate_transformation_mode=coordinate_transformation_mode,
nearest_mode="floor",
)


Expand Down
30 changes: 28 additions & 2 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,21 @@ def _sum_input_wrangler(
def _upsample_bilinear2d_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Wrangler for the signature difference between
# 'nn.functional.upsample_bilinear'
# and
# 'aten::upsample_bilinear2d'
# https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html
if "size" in kwargs:
args.append(np.array(kwargs["size"], dtype=np.int64))
del kwargs["size"] # promote tensor type kwargs to args
else:
args.append(None)
if "align_corners" in kwargs:
args.append(kwargs["align_corners"])
del kwargs["align_corners"]
else:
args.append(True) # Fill in the default value
if "scale_factor" in kwargs:
kwargs["scales_h"] = kwargs["scale_factor"]
kwargs["scales_w"] = kwargs["scale_factor"]
Expand All @@ -431,12 +443,26 @@ def _upsample_bilinear2d_input_wrangler(
def _upsample_bilinear2d_vec_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Wrangler for the signature difference between
# 'nn.functional.upsample_bilinear'
# and
# 'aten::upsample_bilinear2d.vec'
# https://pytorch.org/docs/stable/generated/torch.nn.functional.upsample_bilinear.html
if "size" in kwargs:
args.append(np.array(kwargs["size"], dtype=np.int64))
del kwargs["size"] # promote tensor type kwargs to args
else:
args.append(None)
if "align_corners" in kwargs:
args.append(kwargs["align_corners"])
del kwargs["align_corners"]
else:
args.append(True) # Fill in the default value
if "scale_factor" in kwargs:
kwargs["scale_factors"] = [kwargs["scale_factor"]] * 2
del kwargs["scale_factor"] # adapt the function signature
args.append([kwargs["scale_factor"]] * 2)
del kwargs["scale_factor"] # promote tensor type kwargs to args
else:
args.append(None)
return args, kwargs


Expand Down