Skip to content

Commit b635d24

Browse files
authored
add reshape/view/expand/div/clone/eq/equal (#265)
add reshape, view, expand, div, clone, eq, equal
1 parent dd3d747 commit b635d24

File tree

3 files changed

+51
-19
lines changed

3 files changed

+51
-19
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -803,10 +803,13 @@ def aten_clamp_min(self: TReal, min_: TReal) -> TReal:
803803
return result
804804

805805

806-
def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
806+
@torch_op("aten::clone")
807+
def aten_clone(
808+
self: TTensor, memory_format: str = "" # pylint: disable=unused-argument
809+
) -> TTensor:
807810
# clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
808811

809-
raise NotImplementedError()
812+
return op.Identity(self)
810813

811814

812815
def aten_coalesce(self: TensorType) -> TensorType:
@@ -1406,10 +1409,11 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2) -> TensorType:
14061409
raise NotImplementedError()
14071410

14081411

1409-
def aten_div(self: TensorType, other: TensorType) -> TensorType:
1412+
@torch_op("aten::div")
1413+
def aten_div(self: TReal, other: TReal) -> TReal:
14101414
# div.Tensor(Tensor self, Tensor other) -> Tensor
14111415

1412-
raise NotImplementedError()
1416+
return op.Div(self, other)
14131417

14141418

14151419
def aten_divide(self: TensorType, other: TensorType) -> TensorType:
@@ -1529,16 +1533,21 @@ def aten_empty_strided(size: INT64, stride: INT64) -> TensorType:
15291533
raise NotImplementedError()
15301534

15311535

1532-
def aten_eq(self: TensorType, other: TensorType) -> TensorType:
1536+
@torch_op("aten::eq")
1537+
def aten_eq(self: TTensor, other: TTensor) -> BOOL:
15331538
# eq.Tensor(Tensor self, Tensor other) -> Tensor
15341539

1535-
raise NotImplementedError()
1540+
return op.Equal(self, other)
15361541

15371542

1538-
def aten_equal(self: TensorType, other: TensorType) -> bool:
1543+
@torch_op("aten::equal")
1544+
def aten_equal(self: TTensor, other: TTensor) -> BOOL:
15391545
# equal(Tensor self, Tensor other) -> bool
15401546

1541-
raise NotImplementedError()
1547+
sub_self_other = op.Sub(self, other)
1548+
abs_sub = op.Abs(sub_self_other)
1549+
sum_of_abs = op.ReduceSum(abs_sub, keepdims=0)
1550+
return op.Equal(sum_of_abs, 0)
15421551

15431552

15441553
@torch_op("aten::erf")
@@ -1576,10 +1585,12 @@ def aten_exp2(self: TFloat) -> TFloat:
15761585
return op.Pow(two, self)
15771586

15781587

1579-
def aten_expand(self: TensorType, size: INT64, implicit: bool = False) -> TensorType:
1588+
@torch_op("aten::expand")
1589+
def aten_expand(self: TTensor, size: INT64) -> TTensor:
15801590
# expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
15811591

1582-
raise NotImplementedError()
1592+
size = op.Cast(size, to=INT64.dtype) # to INT64
1593+
return op.Expand(self, size)
15831594

15841595

15851596
def aten_expand_as(self: TensorType, other: TensorType) -> TensorType:
@@ -4046,10 +4057,12 @@ def aten_repeat_interleave(
40464057
raise NotImplementedError()
40474058

40484059

4049-
def aten_reshape(self: TensorType, shape: INT64) -> TensorType:
4060+
@torch_op("aten::reshape")
4061+
def aten_reshape(self: TTensor, shape: INT64) -> TTensor:
40504062
# reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
40514063

4052-
raise NotImplementedError()
4064+
shape = op.Cast(shape, to=INT64.dtype) # Reshape only support INT64 as 'shape'
4065+
return op.Reshape(self, shape)
40534066

40544067

40554068
def aten_reshape_as(self: TensorType, other: TensorType) -> TensorType:
@@ -4484,7 +4497,9 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1) -> Tens
44844497
raise NotImplementedError()
44854498

44864499

4487-
def aten_sum(self: TensorType, dtype: Optional[int] = None) -> TensorType:
4500+
def aten_sum(
4501+
self: TensorType, dim: Optional[int] = None, keepdim: bool = False, dtype: int = -1
4502+
) -> TensorType:
44884503
# sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
44894504

44904505
raise NotImplementedError()
@@ -4903,10 +4918,12 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
49034918
raise NotImplementedError()
49044919

49054920

4906-
def aten_view(self: TensorType, size: INT64) -> TensorType:
4921+
@torch_op("aten::view")
4922+
def aten_view(self: TTensor, size: INT64) -> TTensor:
49074923
# view(Tensor(a) self, SymInt[] size) -> Tensor(a)
49084924

4909-
raise NotImplementedError()
4925+
size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
4926+
return op.Reshape(self, size)
49104927

49114928

49124929
def aten_view_as(self: TensorType, other: TensorType) -> TensorType:

onnxscript/function_libs/torch_aten/ops/special.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,7 @@ def aten_special_log_ndtr(self: TensorType) -> TensorType:
202202
raise NotImplementedError()
203203

204204

205-
def aten_special_log_softmax(
206-
self: TensorType, dim: int, dtype: Optional[int] = None
207-
) -> TensorType:
205+
def aten_special_log_softmax(self: TensorType, dim: int, dtype: int = -1) -> TensorType:
208206
# special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
209207

210208
raise NotImplementedError()

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,17 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
204204
"clamp_max": core_ops.aten_clamp_max,
205205
"clamp_min": core_ops.aten_clamp_min,
206206
"clamp": core_ops.aten_clamp,
207+
"clone": core_ops.aten_clone,
207208
"cos": core_ops.aten_cos,
208209
"cosh": core_ops.aten_cosh,
210+
"div": core_ops.aten_div,
209211
"dot": core_ops.aten_dot,
210-
"erf": core_ops.aten_erf,
212+
"eq": core_ops.aten_eq,
213+
"equal": core_ops.aten_equal,
211214
"exp": core_ops.aten_exp,
212215
"exp2": core_ops.aten_exp2,
216+
"expand": core_ops.aten_expand,
217+
"erf": core_ops.aten_erf,
213218
"fmod": core_ops.aten_fmod,
214219
# TODO(justinchuby): Test aten::full
215220
"full_like": core_ops.aten_full_like,
@@ -242,6 +247,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
242247
"reciprocal": core_ops.aten_reciprocal,
243248
"remainder": core_ops.aten_remainder,
244249
"repeat": core_ops.aten_repeat,
250+
"reshape": core_ops.aten_reshape,
245251
"round": core_ops.aten_round,
246252
"rsqrt": core_ops.aten_rsqrt,
247253
"rsub": core_ops.aten_rsub,
@@ -256,6 +262,7 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
256262
"tanh": core_ops.aten_tanh,
257263
"transpose": core_ops.aten_transpose,
258264
"unsqueeze": core_ops.aten_unsqueeze,
265+
"view": core_ops.aten_view,
259266
"where": core_ops.aten_where,
260267
"zeros": core_ops.aten_zeros,
261268
"zeros_like": core_ops.aten_zeros_like,
@@ -282,6 +289,16 @@ def _upsample_kwargs_wrangler(kwargs: dict[str, Any]) -> dict[str, Any]:
282289

283290

284291
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = (
292+
skip(
293+
"div",
294+
matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
295+
reason="rounding_mode is not yet supported",
296+
),
297+
skip(
298+
"expand",
299+
matcher=lambda sample: (np.array(sample.args[0]) > 0).all() is np.bool_(False),
300+
reason="Negative value is not supported",
301+
),
285302
skip(
286303
"nonzero",
287304
matcher=lambda sample: sample.kwargs.get("as_tuple") is not None,

0 commit comments

Comments (0)