Skip to content

Commit 1231cc0

Browse files
AddOp(upsample_bicubic2d) | feat(torchlib) (#1208)
Co-authored-by: Justin Chu <[email protected]>
1 parent 1fa1ed6 commit 1231cc0

File tree

3 files changed

+139
-6
lines changed

3 files changed

+139
-6
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,16 +2197,86 @@ def aten_unflatten_dense_tensors(
21972197
raise NotImplementedError()
21982198

21992199

2200+
@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bicubic2d.vec"), trace_only=True)
22002201
def aten_upsample_bicubic2d(
2201-
self: TensorType,
2202+
self: TReal,
22022203
output_size: INT64,
22032204
align_corners: bool,
2204-
scales_h: Optional[float] = None,
2205-
scales_w: Optional[float] = None,
2206-
) -> TensorType:
2207-
"""upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor"""
2205+
scale_factors: Optional[TFloat] = None,
2206+
) -> TReal:
2207+
"""upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
2208+
upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
2209+
"""
22082210

2209-
raise NotImplementedError()
2211+
if output_size is not None:
2212+
result = _aten_upsample_output_size(self, output_size, align_corners, "cubic")
2213+
else:
2214+
result = _aten_upsample_scales(self, scale_factors, align_corners, "cubic")
2215+
return result
2216+
2217+
2218+
@torch_op("aten::upsample_bicubic2d", private=True)
2219+
def _aten_upsample_output_size(
2220+
self: TReal,
2221+
output_size: INT64,
2222+
align_corners: bool,
2223+
str_mode: str,
2224+
) -> TReal:
2225+
self_shape = op.Shape(self)
2226+
starts = op.Constant(value_ints=[0])
2227+
ends = op.Constant(value_ints=[2])
2228+
batch_channel = op.Slice(self_shape, starts, ends)
2229+
output_size = op.Concat(batch_channel, output_size, axis=0)
2230+
if align_corners:
2231+
result = op.Resize(
2232+
self,
2233+
None,
2234+
None,
2235+
output_size,
2236+
mode=str_mode,
2237+
coordinate_transformation_mode="align_corners",
2238+
)
2239+
else:
2240+
result = op.Resize(
2241+
self,
2242+
None,
2243+
None,
2244+
output_size,
2245+
mode=str_mode,
2246+
coordinate_transformation_mode="pytorch_half_pixel",
2247+
)
2248+
2249+
return result
2250+
2251+
2252+
@torch_op("aten::upsample_bicubic2d", private=True)
2253+
def _aten_upsample_scales(
2254+
self: TReal,
2255+
scale_factors: TFloat,
2256+
align_corners: bool,
2257+
str_mode: str,
2258+
) -> TReal:
2259+
scale_factors = op.Cast(scale_factors, to=FLOAT.dtype)
2260+
scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0)
2261+
if align_corners:
2262+
result = op.Resize(
2263+
self,
2264+
None,
2265+
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
2266+
None,
2267+
mode=str_mode,
2268+
coordinate_transformation_mode="align_corners",
2269+
)
2270+
else:
2271+
result = op.Resize(
2272+
self,
2273+
None,
2274+
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
2275+
None,
2276+
mode=str_mode,
2277+
coordinate_transformation_mode="pytorch_half_pixel",
2278+
)
2279+
return result
22102280

22112281

22122282
def aten_upsample_bicubic2d_backward(

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,6 +1409,57 @@ def sample_inputs_unfold(op_info, device, dtype, requires_grad, **kwargs):
14091409
yield opinfo_core.SampleInput(t, args=(dimension, size, step))
14101410

14111411

1412+
def sample_inputs_upsample_bicubic2d(op_info, device, dtype, requires_grad, **kwargs):
1413+
del op_info
1414+
del kwargs
1415+
1416+
N, C = 2, 3
1417+
D = 4
1418+
SS = 3
1419+
L = 5
1420+
1421+
align_corners_options = (True, False)
1422+
rank = 2
1423+
1424+
def shape(size, rank, with_batch_channel=True):
1425+
if with_batch_channel:
1426+
return tuple([N, C] + ([size] * rank))
1427+
return tuple([size] * rank)
1428+
1429+
make_arg = functools.partial(
1430+
torch_testing.make_tensor,
1431+
device=device,
1432+
dtype=dtype,
1433+
requires_grad=requires_grad,
1434+
low=-1,
1435+
high=1,
1436+
)
1437+
1438+
yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True)
1439+
1440+
for align_corners in align_corners_options:
1441+
yield opinfo_core.SampleInput(
1442+
make_arg(shape(D, rank)), shape(S, rank, False), align_corners
1443+
)
1444+
yield opinfo_core.SampleInput(
1445+
make_arg(shape(D, rank)),
1446+
shape(L, rank, False),
1447+
align_corners,
1448+
)
1449+
yield opinfo_core.SampleInput(
1450+
make_arg(shape(D, rank)),
1451+
None, # output_size
1452+
align_corners,
1453+
(1.7, 1.7), # scaler
1454+
)
1455+
yield opinfo_core.SampleInput(
1456+
make_arg(shape(D, rank)),
1457+
None, # if this is None, the scalar must be list
1458+
align_corners,
1459+
(0.6, 0.6),
1460+
)
1461+
1462+
14121463
class _TestParamsMaxPoolEmptyStrideBase:
14131464
# Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
14141465
def __init__(self):
@@ -1874,6 +1925,13 @@ def __init__(self):
18741925
sample_inputs_func=sample_inputs_unfold,
18751926
supports_out=False,
18761927
),
1928+
opinfo_core.OpInfo(
1929+
"ops.aten.upsample_bicubic2d",
1930+
aten_name="upsample_bicubic2d",
1931+
dtypes=common_dtype.floating_types_and(torch.bfloat16),
1932+
sample_inputs_func=sample_inputs_upsample_bicubic2d,
1933+
supports_out=False,
1934+
),
18771935
opinfo_core.OpInfo(
18781936
"nn.functional.max_pool1d_with_indices",
18791937
aten_name="max_pool1d_with_indices",

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2122,6 +2122,11 @@ def _where_input_wrangler(
21222122
input_wrangler=_upsample_bilinear2d_input_wrangler,
21232123
trace_only=True,
21242124
),
2125+
TorchLibOpInfo(
2126+
"ops.aten.upsample_bicubic2d",
2127+
nn_ops.aten_upsample_bicubic2d,
2128+
trace_only=True,
2129+
),
21252130
TorchLibOpInfo(
21262131
"nn.functional.upsample_nearest2d",
21272132
nn_ops.aten_upsample_nearest2d,

0 commit comments

Comments
 (0)