Skip to content

Commit 677ca1f

Browse files
authored
feat(atenlib): where, full (#286)
Implement where, full, new_full, full_like
1 parent a733ec2 commit 677ca1f

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1841,18 +1841,23 @@ def aten_from_file(
18411841
raise NotImplementedError()
18421842

18431843

@torch_op("aten::full")
def aten_full(size: INT64, fill_value, dtype: int = FLOAT.dtype):
    # full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    # Cast the scalar fill value to the requested element type, then
    # broadcast it to the requested shape via Expand.
    casted_value = op.Cast(fill_value, to=dtype)
    return op.Expand(casted_value, size)
18491851

@torch_op("aten::full_like")
def aten_full_like(self, fill_value, dtype: int = FLOAT.dtype):
    # full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    # NOTE(review): when `dtype` is omitted the result is FLOAT, not the
    # dtype of `self` — confirm this matches the intended ATen default.
    target_shape = op.Shape(self)
    casted_value = op.Cast(fill_value, to=dtype)
    # Broadcast the casted scalar to the shape of `self`.
    return op.Expand(casted_value, target_shape)
18561861

18571862

18581863
def aten_fused_moving_avg_obs_fake_quant(
@@ -3447,10 +3452,15 @@ def aten_new_empty_strided(self: TensorType, size: INT64, stride: INT64) -> Tens
34473452
raise NotImplementedError()
34483453

34493454

@torch_op("aten::new_full")
def aten_new_full(
    self, size: INT64, fill_value, dtype: int = FLOAT.dtype
):  # pylint: disable=unused-argument
    # new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    # `self` is kept for schema compatibility only; the ONNX graph needs
    # just the shape and the casted scalar.
    casted_value = op.Cast(fill_value, to=dtype)
    return op.Expand(casted_value, size)
34543464

34553465

34563466
def aten_new_ones(self: TensorType, size: INT64) -> TensorType:
@@ -4928,10 +4938,11 @@ def aten_vstack(tensors: Sequence[TensorType]) -> TensorType:
49284938
raise NotImplementedError()
49294939

49304940

@torch_op("aten::where")
def aten_where(self: TTensor, condition: BOOL, other: TTensor) -> TTensor:
    # where.self(Tensor condition, Tensor self, Tensor other) -> Tensor

    # Element-wise select: take `self` where `condition` is true, else `other`.
    # NOTE(review): parameter order differs from the ATen schema in the
    # comment above (condition comes first there) — confirm the dispatcher
    # binds arguments by name rather than by position.
    result = op.Where(condition, self, other)
    return result
49354946

49364947

49374948
def aten_xlogy(self: TensorType, other: TensorType) -> TensorType:

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def wrapped(fn):
169169
"exp": core_ops.aten_exp,
170170
"exp2": core_ops.aten_exp2,
171171
"fmod": core_ops.aten_fmod,
172+
# TODO(justinchuby): Test aten::full
173+
"full_like": core_ops.aten_full_like,
172174
"gt": core_ops.aten_gt,
173175
"isinf": core_ops.aten_isinf,
174176
"lt": core_ops.aten_lt,
@@ -177,6 +179,7 @@ def wrapped(fn):
177179
"mul": core_ops.aten_mul,
178180
"ne": core_ops.aten_ne,
179181
"neg": core_ops.aten_neg,
182+
"new_full": core_ops.aten_new_full,
180183
"nn.functional.elu": nn_ops.aten_elu,
181184
"nn.functional.leaky_relu": nn_ops.aten_leaky_relu,
182185
"nn.functional.linear": nn_ops.aten_linear,
@@ -202,6 +205,7 @@ def wrapped(fn):
202205
"tanh": core_ops.aten_tanh,
203206
# "transpose": core_ops.aten_transpose, # TODO(justinchuby): Enable when onnxscript errors are fixed,
204207
"unsqueeze": core_ops.aten_unsqueeze,
208+
"where": core_ops.aten_where,
205209
"zeros": core_ops.aten_zeros,
206210
"zeros_like": core_ops.aten_zeros_like,
207211
}

0 commit comments

Comments
 (0)