Commit eba70a6

feat(atenlib): ops 2/n
ghstack-source-id: 23e8ab8
Pull Request resolved: #252
1 parent d99e4b8 commit eba70a6

File tree

3 files changed: +43 −14 lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 19 additions & 9 deletions
@@ -3463,16 +3463,13 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
     raise NotImplementedError()


-def aten_numpy_T(self: TensorType) -> TensorType:
-    # numpy_T(Tensor(a) self) -> Tensor(a)
-
-    raise NotImplementedError()
-
-
-def aten_ones(size: INT64) -> TensorType:
+def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
     # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

-    raise NotImplementedError()
+    one = op.Constant(value_float=1)
+    if dtype != -1:
+        one = op.Cast(one, to=dtype)  # type: ignore[arg-type]
+    return op.Expand(one, size)  # type: ignore[arg-type]


 def aten_ones_like(self, dtype: int = -1):
@@ -4461,7 +4458,13 @@ def aten_symeig(
 def aten_t(self: TensorType) -> TensorType:
     # t(Tensor(a) self) -> Tensor(a)

-    raise NotImplementedError()
+    # TODO(justinchuby): Make rank a function
+    rank = op.Shape(op.Shape(self))
+    if rank == 0 or rank == 1:
+        result = self
+    else:
+        result = op.Transpose(self, perm=[1, 0])
+    return result


 def aten_t_copy(self: TensorType) -> TensorType:
@@ -4606,6 +4609,13 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
     raise NotImplementedError()


+def aten_transpose(self, dim0: int, dim1: int):
+    # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+
+    # FIXME(justinchuby): onnxscript raises Unsupported expression type
+    return op.Transpose(self, [dim0, dim1])
+
+
 def aten_triangular_solve(
     self: TensorType,
     A: TensorType,
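
For readers skimming the hunks above, here is a minimal numpy sketch of the semantics the new aten_ones and aten_t bodies are meant to reproduce. The helper names are illustrative only, not part of onnxscript:

import numpy as np

# aten_ones: op.Constant creates a scalar 1, op.Cast applies the requested
# dtype, and op.Expand broadcasts it to the requested shape.
def ones_reference(size, dtype=np.float32):  # hypothetical helper, not library code
    one = np.array(1, dtype=dtype)
    return np.broadcast_to(one, size)

# aten_t: the rank is derived from Shape(Shape(self)); rank-0 and rank-1
# inputs pass through unchanged, rank-2 inputs are transposed with perm=[1, 0].
def t_reference(x):  # hypothetical helper, not library code
    if x.ndim <= 1:
        return x
    return x.transpose(1, 0)

assert ones_reference([2, 3]).shape == (2, 3)
assert t_reference(np.zeros((2, 3))).shape == (3, 2)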

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 9 additions & 4 deletions
@@ -407,12 +407,17 @@ def aten_leaky_relu_backward(
     raise NotImplementedError()


-def aten_linear(
-    input: TensorType, weight: TensorType, bias: Optional[TensorType] = None
-) -> TensorType:
+def aten_linear(input, weight, bias=None) -> TensorType:
     # linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor

-    raise NotImplementedError()
+    result = op.MatMul(input, weight)
+    if op.OptionalHasElement(bias):
+        # FIXME(justinchuby): INVALID_GRAPH : This is an invalid model.
+        # In Node, ("", OptionalHasElement, "", -1) : () -> ("output0",) ,
+        # Error Node () has input size 0 not in range [min=1, max=1]
+        bias = op.OptionalGetElement(bias)
+        result = op.Add(result, bias)  # type: ignore[arg-type]
+    return result


 def aten_log_sigmoid(self: TensorType) -> TensorType:
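
A sketch of what the graph built by the new aten_linear computes, with numpy standing in for the ONNX ops; the corresponding OpInfo test is registered below but marked xfail because ONNX Runtime rejects the OptionalHasElement node, as the FIXME notes. The helper name is hypothetical, not library code:

from typing import Optional

import numpy as np

def linear_reference(input, weight, bias: Optional[np.ndarray] = None):  # hypothetical helper
    result = input @ weight          # op.MatMul(input, weight)
    if bias is not None:             # op.OptionalHasElement(bias)
        result = result + bias       # op.OptionalGetElement + op.Add
    return result

x = np.ones((4, 3), dtype=np.float32)
w = np.ones((3, 5), dtype=np.float32)
b = np.zeros(5, dtype=np.float32)
assert linear_reference(x, w, b).shape == (4, 5)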

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 15 additions & 1 deletion
@@ -162,25 +162,30 @@ def wrapped(fn):
 # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
 OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
     "add": core_ops.aten_add,
-    # "clamp": core_ops.aten_clamp,  # TODO(justinchuby): Enable
     "clamp_max": core_ops.aten_clamp_max_tensor,
     "clamp_min": core_ops.aten_clamp_min_tensor,
+    "clamp": core_ops.aten_clamp,
     "gt": core_ops.aten_gt,
     "lt": core_ops.aten_lt,
     "mul": core_ops.aten_mul,
     "nn.functional.elu": nn_ops.aten_elu,
+    "nn.functional.linear": nn_ops.aten_linear,
     "nn.functional.relu6": nn_ops.aten_relu6,
     "nn.functional.selu": core_ops.aten_selu,
     "ones_like": core_ops.aten_ones_like,
+    "ones": core_ops.aten_ones,
     "repeat": core_ops.aten_repeat,
     "round": core_ops.aten_round,
     "sub": core_ops.aten_sub,
+    "t": core_ops.aten_t,
+    # "transpose": core_ops.aten_transpose,  # TODO(justinchuby): Enable when onnxscript errors are fixed
 }

 TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)

 EXPECTED_SKIPS_OR_FAILS = (
     xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
+    skip("clamp", reason="Enable when onnxscript errors are fixed"),
     xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
     xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
     xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
@@ -191,6 +196,10 @@ def wrapped(fn):
         dtypes=dtypes_except(torch.float16, torch.float32),
         reason="ONNX Runtime doesn't support float64 for Elu",
     ),
+    xfail(
+        "nn.functional.linear",
+        reason="ONNX Runtime thinks the graph is invalid",
+    ),
     xfail(
         "nn.functional.relu6",
         dtypes=dtypes_except(torch.float16, torch.float32),
@@ -213,6 +222,7 @@ def wrapped(fn):
         "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
     ),
     xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
+    xfail("transpose", reason="Enable when onnxscript errors are fixed"),
 )


@@ -240,6 +250,10 @@ def wrapped(fn):

 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

+ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
+# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
+assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
+

 TORCH_TYPE_TO_ONNX = {
     torch.bool: onnx.TensorProto.BOOL,
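
The new import-time guard in the last hunk keeps the OpInfo mapping honest: every key in OPINFO_FUNCTION_MAPPING must name an op that actually exists in torch's op database, so a typo fails immediately instead of silently skipping tests. A tiny self-contained sketch of the same check; the sets below are stand-ins, not the real OPS_DB:

# Stand-in data; the real code derives these from OPINFO_FUNCTION_MAPPING and OPS_DB.
ops_in_db = frozenset({"add", "ones", "t", "nn.functional.linear"})
tested_ops = frozenset({"add", "ones", "t"})

# Same shape as the new assert: a subset check whose message lists the offending names.
assert tested_ops.issubset(ops_in_db), f"{tested_ops - ops_in_db} not in OPS_DB"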
