Skip to content

Commit 4de365c

Browse files
authored
Merge branch 'main' into xiaowu/AddOp(upsample_linear1d)
2 parents f7d4a1e + 4a85d3f commit 4de365c

File tree

8 files changed

+282
-136
lines changed

8 files changed

+282
-136
lines changed

noxfile.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
ONNX = "onnx==1.14.1"
3030
ONNX_RUNTIME = "onnxruntime==1.16.1"
3131
PYTORCH = "torch==2.1.0"
32+
TORCHVISON = "torchvision==0.16"
3233
ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
3334
"flatbuffers",
3435
"coloredlogs",
@@ -52,6 +53,7 @@ def test(session):
5253
session.install(
5354
*COMMON_TEST_DEPENDENCIES,
5455
PYTORCH,
56+
TORCHVISON,
5557
ONNX,
5658
ONNX_RUNTIME,
5759
)
@@ -78,7 +80,7 @@ def test_torch_nightly(session):
7880
@nox.session(tags=["test-onnx-weekly"])
7981
def test_onnx_weekly(session):
8082
"""Test with ONNX weekly (preview) build."""
81-
session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH)
83+
session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON)
8284
session.install("-r", "requirements/ci/requirements-onnx-weekly.txt")
8385
session.install(".", "--no-deps")
8486
session.run("pip", "list")
@@ -89,7 +91,11 @@ def test_onnx_weekly(session):
8991
def test_ort_nightly(session):
9092
"""Test with ONNX Runtime nightly builds."""
9193
session.install(
92-
*COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
94+
*COMMON_TEST_DEPENDENCIES,
95+
PYTORCH,
96+
TORCHVISON,
97+
ONNX,
98+
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
9399
)
94100
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
95101
session.install(".", "--no-deps")
@@ -101,7 +107,11 @@ def test_ort_nightly(session):
101107
def test_experimental_torchlib_tracing(session):
102108
"""Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on."""
103109
session.install(
104-
*COMMON_TEST_DEPENDENCIES, PYTORCH, ONNX, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES
110+
*COMMON_TEST_DEPENDENCIES,
111+
PYTORCH,
112+
TORCHVISON,
113+
ONNX,
114+
*ONNX_RUNTIME_NIGHTLY_DEPENDENCIES,
105115
)
106116
session.install("-r", "requirements/ci/requirements-ort-nightly.txt")
107117
session.install(".", "--no-deps")

onnxscript/function_libs/torch_lib/ops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
"prims",
88
"sparse",
99
"special",
10+
"vision",
1011
]
1112

12-
from . import core, fft, linalg, nested, nn, prims, sparse, special
13+
from . import core, fft, linalg, nested, nn, prims, sparse, special, vision

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 99 additions & 102 deletions
Original file line number | Diff line number | Diff line change
@@ -2197,85 +2197,99 @@ def aten_unflatten_dense_tensors(
21972197
raise NotImplementedError()
21982198

21992199

2200-
@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bicubic2d.vec"), trace_only=True)
2201-
def aten_upsample_bicubic2d(
2202-
self: TReal,
2203-
output_size: INT64,
2204-
align_corners: bool,
2205-
scale_factors: Optional[TFloat] = None,
2206-
) -> TReal:
2207-
"""upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
2208-
upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
2209-
"""
2210-
2211-
if output_size is not None:
2212-
result = _aten_upsample_output_size(self, output_size, align_corners, "cubic")
2213-
else:
2214-
result = _aten_upsample_scales(self, scale_factors, align_corners, "cubic")
2215-
return result
2200+
def _get_upsample_align_corners_mode(align_corners: bool) -> str:
2201+
return "align_corners" if align_corners else "pytorch_half_pixel"
22162202

22172203

2218-
@torch_op("aten::upsample_bicubic2d", private=True)
2204+
@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True)
22192205
def _aten_upsample_output_size(
22202206
self: TReal,
22212207
output_size: INT64,
2222-
align_corners: bool,
2223-
str_mode: str,
2208+
mode: str,
2209+
coordinate_transformation_mode: str,
22242210
) -> TReal:
22252211
self_shape = op.Shape(self)
22262212
starts = op.Constant(value_ints=[0])
22272213
ends = op.Constant(value_ints=[2])
22282214
batch_channel = op.Slice(self_shape, starts, ends)
22292215
output_size = op.Concat(batch_channel, output_size, axis=0)
2230-
if align_corners:
2231-
result = op.Resize(
2232-
self,
2233-
None,
2234-
None,
2235-
output_size,
2236-
mode=str_mode,
2237-
coordinate_transformation_mode="align_corners",
2238-
)
2239-
else:
2240-
result = op.Resize(
2241-
self,
2242-
None,
2243-
None,
2244-
output_size,
2245-
mode=str_mode,
2246-
coordinate_transformation_mode="pytorch_half_pixel",
2247-
)
2248-
2249-
return result
2216+
return op.Resize(
2217+
self,
2218+
None,
2219+
None,
2220+
output_size,
2221+
mode=mode,
2222+
coordinate_transformation_mode=coordinate_transformation_mode,
2223+
nearest_mode="floor",
2224+
)
22502225

22512226

2252-
@torch_op("aten::upsample_bicubic2d", private=True)
2227+
@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True)
22532228
def _aten_upsample_scales(
22542229
self: TReal,
22552230
scale_factors: TFloat,
2256-
align_corners: bool,
2257-
str_mode: str,
2231+
mode: str,
2232+
coordinate_transformation_mode: str,
22582233
) -> TReal:
22592234
scale_factors = op.Cast(scale_factors, to=FLOAT.dtype)
22602235
scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0)
2261-
if align_corners:
2262-
result = op.Resize(
2236+
return op.Resize(
2237+
self,
2238+
None,
2239+
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
2240+
None,
2241+
mode=mode,
2242+
coordinate_transformation_mode=coordinate_transformation_mode,
2243+
nearest_mode="floor",
2244+
)
2245+
2246+
2247+
@torch_op("aten::upsample_bicubic2d", trace_only=True)
2248+
def aten_upsample_bicubic2d(
2249+
self: TReal,
2250+
output_size: INT64,
2251+
align_corners: bool,
2252+
scales_h: Optional[float] = None,
2253+
scales_w: Optional[float] = None,
2254+
) -> TReal:
2255+
"""upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
2256+
2257+
# NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch,
2258+
# unless when align_corners is True, in which case we do not know what is going on.
2259+
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
2260+
return _aten_upsample_output_size(
2261+
self,
2262+
output_size,
2263+
mode="cubic",
2264+
coordinate_transformation_mode=coordinate_transformation_mode,
2265+
)
2266+
2267+
2268+
@torch_op("aten::upsample_bicubic2d.vec", trace_only=True)
2269+
def aten_upsample_bicubic2d_vec(
2270+
self: TReal,
2271+
output_size: INT64,
2272+
align_corners: bool,
2273+
scale_factors: Optional[Sequence[float]],
2274+
) -> TReal:
2275+
"""upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor"""
2276+
2277+
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
2278+
if scale_factors is not None:
2279+
result = _aten_upsample_scales(
22632280
self,
2264-
None,
2265-
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
2266-
None,
2267-
mode=str_mode,
2268-
coordinate_transformation_mode="align_corners",
2281+
op.Constant(value_floats=scale_factors),
2282+
mode="cubic",
2283+
coordinate_transformation_mode=coordinate_transformation_mode,
22692284
)
22702285
else:
2271-
result = op.Resize(
2286+
result = _aten_upsample_output_size(
22722287
self,
2273-
None,
2274-
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
2275-
None,
2276-
mode=str_mode,
2277-
coordinate_transformation_mode="pytorch_half_pixel",
2288+
output_size,
2289+
mode="cubic",
2290+
coordinate_transformation_mode=coordinate_transformation_mode,
22782291
)
2292+
22792293
return result
22802294

22812295

@@ -2295,67 +2309,50 @@ def aten_upsample_bicubic2d_backward(
22952309
@torch_op("aten::upsample_bilinear2d", trace_only=True)
22962310
def aten_upsample_bilinear2d(
22972311
self: TReal,
2298-
output_size: Optional[INT64] = None,
2312+
output_size: INT64,
2313+
align_corners: bool,
22992314
scales_h: Optional[float] = None,
23002315
scales_w: Optional[float] = None,
2301-
align_corners: bool = True, # pylint: disable=unused-argument
2302-
) -> TReal:
2303-
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
2304-
2305-
if output_size is not None:
2306-
result = _aten_upsample_bilinear2d_output_size(self, output_size)
2307-
else:
2308-
assert scales_h is not None
2309-
assert scales_h == scales_w
2310-
result = _aten_upsample_bilinear2d_scales(self, scales_h, scales_w)
2311-
return result
2312-
2313-
2314-
@torch_op("aten::upsample_bilinear2d", private=True)
2315-
def _aten_upsample_bilinear2d_output_size(
2316-
self: TReal,
2317-
output_size: INT64,
23182316
) -> TReal:
23192317
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
23202318

2321-
self_shape = op.Shape(self)
2322-
starts = op.Constant(value_ints=[0])
2323-
ends = op.Constant(value_ints=[2])
2324-
batch_channel = op.Slice(self_shape, starts, ends)
2325-
output_size = op.Concat(batch_channel, output_size, axis=0)
2326-
return op.Resize(
2319+
# NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch,
2320+
# unless when align_corners is True, in which case we do not know what is going on.
2321+
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
2322+
return _aten_upsample_output_size(
23272323
self,
2328-
None,
2329-
None,
23302324
output_size,
2325+
coordinate_transformation_mode=coordinate_transformation_mode,
23312326
mode="linear",
2332-
coordinate_transformation_mode="align_corners",
23332327
)
23342328

23352329

2336-
@torch_op("aten::upsample_bilinear2d", private=True)
2337-
def _aten_upsample_bilinear2d_scales(
2330+
@torch_op("aten::upsample_bilinear2d.vec", trace_only=True)
2331+
def aten_upsample_bilinear2d_vec(
23382332
self: TReal,
2339-
scales_h: float,
2340-
scales_w: float,
2333+
output_size: Optional[INT64],
2334+
align_corners: bool,
2335+
scale_factors: Optional[Sequence[float]],
23412336
) -> TReal:
2342-
"""upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
2337+
"""upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor"""
23432338

2344-
neg_1 = op.Constant(value_ints=[-1])
2345-
scales = op.Concat(
2346-
op.Constant(value_floats=[1.0, 1.0]),
2347-
op.Reshape(op.Constant(value_float=scales_h), neg_1),
2348-
op.Reshape(op.Constant(value_float=scales_w), neg_1),
2349-
axis=0,
2350-
)
2351-
return op.Resize(
2352-
self,
2353-
None,
2354-
scales, # format should be: [1.0, 1.0, scale_h, scale_w]
2355-
None,
2356-
mode="linear",
2357-
coordinate_transformation_mode="align_corners",
2358-
)
2339+
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
2340+
if scale_factors is not None:
2341+
result = _aten_upsample_scales(
2342+
self,
2343+
op.Constant(value_floats=scale_factors),
2344+
mode="linear",
2345+
coordinate_transformation_mode=coordinate_transformation_mode,
2346+
)
2347+
else:
2348+
result = _aten_upsample_output_size(
2349+
self,
2350+
output_size,
2351+
mode="linear",
2352+
coordinate_transformation_mode=coordinate_transformation_mode,
2353+
)
2354+
2355+
return result
23592356

23602357

23612358
def aten_upsample_bilinear2d_backward(
onnxscript/function_libs/torch_lib/ops/vision.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,24 @@
1+
# --------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License.
4+
# --------------------------------------------------------------------------
5+
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
6+
"""torchvision operators."""
7+
from __future__ import annotations
8+
9+
from onnxscript.function_libs.torch_lib.registration import torch_op
10+
from onnxscript.onnx_opset import opset18 as op
11+
from onnxscript.onnx_types import FLOAT, INT64
12+
13+
_INT64_MAX = 0x7FFFFFFFFFFFFFFF
14+
15+
16+
@torch_op("torchvision::nms")
17+
def torchvision_nms(boxes: FLOAT, scores: FLOAT, iou_threshold: float) -> INT64:
18+
# boxes: [num_batches, spatial_dimension, 4]
19+
boxes = op.Unsqueeze(boxes, [0])
20+
# scores: [num_batches, num_classes, spatial_dimension]
21+
scores = op.Unsqueeze(scores, [0, 1])
22+
# nms_out: [num_selected_indices, 3] where each column is [batch_index, class_index, box_index]
23+
nms_out = op.NonMaxSuppression(boxes, scores, _INT64_MAX, iou_threshold)
24+
return op.Reshape(op.Slice(nms_out, axes=[1], starts=[2], ends=[3]), [-1])

0 commit comments

Comments (0)