Skip to content

Commit 8bd6dbd

Browse files
Add 2 ops zeros and zeros_like. (#251)
Try to implement 2 ops in aten lib functions: zeros and zeros_like. Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 2512120 commit 8bd6dbd

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4858,13 +4858,23 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
48584858
raise NotImplementedError()
48594859

48604860

4861-
def aten_zeros(size: INT64) -> TensorType:
4861+
def aten_zeros(size, dtype: int = -1):
48624862
# zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
48634863

4864-
raise NotImplementedError()
4864+
zero = op.Constant(value_float=0)
4865+
if dtype != -1:
4866+
zero = op.Cast(zero, to=dtype) # type: ignore[arg-type]
48654867

4868+
return op.Expand(zero, size) # type: ignore[arg-type]
48664869

4867-
def aten_zeros_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
4870+
4871+
def aten_zeros_like(self, dtype: int = -1):
48684872
# zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
48694873

4870-
raise NotImplementedError()
4874+
shape = op.Shape(self)
4875+
if dtype == -1:
4876+
zero = op.CastLike(0, self) # type: ignore[arg-type]
4877+
else:
4878+
zero = op.Cast(0, to=dtype) # type: ignore[arg-type]
4879+
4880+
return op.Expand(zero, shape)

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ def wrapped(fn):
199199
"t": core_ops.aten_t,
200200
"tan": core_ops.aten_tan,
201201
"tanh": core_ops.aten_tanh,
202-
# "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed
202+
# "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed,
203+
"zeros": core_ops.aten_zeros,
204+
"zeros_like": core_ops.aten_zeros_like,
203205
}
204206

205207
TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

0 commit comments

Comments
 (0)