Skip to content

Commit 8785231

Browse files
Add op(upsample linear1d) | feat(torchlib) (#1245)
Co-authored-by: Justin Chu <[email protected]>
1 parent 4a85d3f commit 8785231

File tree

3 files changed

+70
-11
lines changed

3 files changed

+70
-11
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2368,12 +2368,20 @@ def aten_upsample_bilinear2d_backward(
23682368
raise NotImplementedError()
23692369

23702370

2371+
@torch_op("aten::upsample_linear1d", trace_only=True)
23712372
def aten_upsample_linear1d(
2372-
self: TensorType, output_size: INT64, align_corners: bool, scales: Optional[float] = None
2373-
) -> TensorType:
2373+
self: TReal, output_size: INT64, align_corners: bool, scales: Optional[float] = None
2374+
) -> TReal:
23742375
"""upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor"""
2375-
2376-
raise NotImplementedError()
2376+
# FIXME(justinchuby): Support when scales is provided and align_corners is False
2377+
del scales
2378+
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
2379+
return _aten_upsample_output_size(
2380+
self,
2381+
output_size,
2382+
mode="linear",
2383+
coordinate_transformation_mode=coordinate_transformation_mode,
2384+
)
23772385

23782386

23792387
def aten_upsample_linear1d_backward(

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,9 +1464,7 @@ def shape(size, rank, with_batch_channel=True):
14641464
make_arg(shape(D, rank)), shape(S, rank, False), align_corners
14651465
)
14661466
yield opinfo_core.SampleInput(
1467-
make_arg(shape(D, rank)),
1468-
shape(L, rank, False),
1469-
align_corners,
1467+
make_arg(shape(D, rank)), shape(L, rank, False), align_corners
14701468
)
14711469
yield opinfo_core.SampleInput(
14721470
make_arg(shape(D, rank)),
@@ -1513,10 +1511,7 @@ def shape(size, rank, with_batch_channel=True):
15131511
make_arg(shape(D, rank)), shape(S, rank, False), align_corners, None
15141512
)
15151513
yield opinfo_core.SampleInput(
1516-
make_arg(shape(D, rank)),
1517-
shape(L, rank, False),
1518-
align_corners,
1519-
None,
1514+
make_arg(shape(D, rank)), shape(L, rank, False), align_corners, None
15201515
)
15211516
yield opinfo_core.SampleInput(
15221517
make_arg(shape(D, rank)),
@@ -1544,6 +1539,46 @@ def shape(size, rank, with_batch_channel=True):
15441539
)
15451540

15461541

1542+
def sample_inputs_upsample_linear1d(op_info, device, dtype, requires_grad, **kwargs):
1543+
del op_info
1544+
del kwargs
1545+
1546+
N, C = 2, 3
1547+
D = 4
1548+
SS = 3
1549+
L = 5
1550+
1551+
align_corners_options = (True, False)
1552+
rank = 1
1553+
1554+
def shape(size, rank, with_batch_channel=True):
1555+
if with_batch_channel:
1556+
return tuple([N, C] + ([size] * rank))
1557+
return tuple([size] * rank)
1558+
1559+
make_arg = functools.partial(
1560+
torch_testing.make_tensor,
1561+
device=device,
1562+
dtype=dtype,
1563+
requires_grad=requires_grad,
1564+
low=-1,
1565+
high=1,
1566+
)
1567+
1568+
yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True)
1569+
1570+
for align_corners in align_corners_options:
1571+
yield opinfo_core.SampleInput(
1572+
make_arg(shape(D, rank)), shape(S, rank, False), align_corners
1573+
)
1574+
yield opinfo_core.SampleInput(
1575+
make_arg(shape(D, rank)), shape(L, rank, False), align_corners
1576+
)
1577+
yield opinfo_core.SampleInput(
1578+
make_arg(shape(D, rank)), shape(L, rank, False), align_corners, scales=4.2
1579+
)
1580+
1581+
15471582
class _TestParamsMaxPoolEmptyStrideBase:
15481583
# Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203
15491584
def __init__(self):
@@ -2037,6 +2072,13 @@ def __init__(self):
20372072
sample_inputs_func=sample_inputs_upsample_2d_vec,
20382073
supports_out=False,
20392074
),
2075+
opinfo_core.OpInfo(
2076+
"ops.aten.upsample_linear1d",
2077+
aten_name="upsample_linear1d",
2078+
dtypes=common_dtype.floating_types_and(torch.bfloat16),
2079+
sample_inputs_func=sample_inputs_upsample_linear1d,
2080+
supports_out=False,
2081+
),
20402082
opinfo_core.OpInfo(
20412083
"nn.functional.max_pool1d_with_indices",
20422084
aten_name="max_pool1d_with_indices",

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,6 +2132,15 @@ def _where_input_wrangler(
21322132
nn_ops.aten_upsample_bicubic2d_vec,
21332133
trace_only=True,
21342134
),
2135+
TorchLibOpInfo(
2136+
"ops.aten.upsample_linear1d",
2137+
nn_ops.aten_upsample_linear1d,
2138+
trace_only=True,
2139+
).xfail(
2140+
matcher=lambda sample: sample.args[1] is False
2141+
and sample.kwargs.get("scales") is not None,
2142+
reason="fixme: align_corners=False output mismatch when scales are provided",
2143+
),
21352144
TorchLibOpInfo(
21362145
"nn.functional.upsample_nearest2d",
21372146
nn_ops.aten_upsample_nearest2d,

0 commit comments

Comments
 (0)