
Commit b2d3d27

feat(atenlib): ops 2/n (#252)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #256
* #255
* __->__ #252

1 parent d99e4b8 · commit b2d3d27

File tree

3 files changed: +75 -18 lines changed


onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 24 additions & 13 deletions
```diff
@@ -2680,10 +2680,10 @@ def aten_masked_select_backward(
     raise NotImplementedError()
 
 
-def aten_matmul(self: TensorType, other: TensorType) -> TensorType:
+def aten_matmul(self, other):
     # matmul(Tensor self, Tensor other) -> Tensor
 
-    raise NotImplementedError()
+    return op.MatMul(self, other)
 
 
 def aten_matmul_backward(
```
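For reference, ONNX MatMul is specified to behave like `numpy.matmul`, including batched broadcasting, so the new `aten_matmul` is expected to line up with `torch.matmul` on the supported dtypes. A minimal sketch of that correspondence (illustrative only, not code from this commit):

```python
import numpy as np
import torch

# ONNX MatMul follows numpy.matmul semantics, so a broadcasted batched
# product should agree with torch.matmul on the same float inputs.
a = np.random.rand(2, 3, 4).astype(np.float32)
b = np.random.rand(4, 5).astype(np.float32)  # broadcast across the batch dim

np.testing.assert_allclose(
    np.matmul(a, b),
    torch.matmul(torch.from_numpy(a), torch.from_numpy(b)).numpy(),
    rtol=1e-5,
)
```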
```diff
@@ -3080,10 +3080,11 @@ def aten_mkldnn_max_pool3d_backward(
     raise NotImplementedError()
 
 
-def aten_mm(self: TensorType, mat2: TensorType) -> TensorType:
+def aten_mm(self, mat2):
     # mm(Tensor self, Tensor mat2) -> Tensor
 
-    raise NotImplementedError()
+    # TODO(justinchuby): Specify type conversion for uint8/int8/int16
+    return op.MatMul(self, mat2)
 
 
 def aten_mode(
```
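The TODO exists because ONNX MatMul is not defined for uint8/int8/int16 inputs, which is also why the test file below xfails those dtypes for `matmul` and `mm`. One hypothetical shape for that conversion, sketched here rather than taken from the commit (function name invented; assumes the same `op` opset alias used in core.py):

```python
from onnx import TensorProto

# Hypothetical sketch only: promote small integer types to INT32, which
# ONNX MatMul does support, then cast the product back to the input type.
# Note this does not reproduce torch's wrap-around overflow for int8/uint8.
def aten_mm_int_promoted(self, mat2, dtype: int):
    # `dtype` would be the original element type, e.g. TensorProto.UINT8
    self_i32 = op.Cast(self, to=TensorProto.INT32)
    mat2_i32 = op.Cast(mat2, to=TensorProto.INT32)
    return op.Cast(op.MatMul(self_i32, mat2_i32), to=dtype)
```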
```diff
@@ -3463,16 +3464,13 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
     raise NotImplementedError()
 
 
-def aten_numpy_T(self: TensorType) -> TensorType:
-    # numpy_T(Tensor(a) self) -> Tensor(a)
-
-    raise NotImplementedError()
-
-
-def aten_ones(size: INT64) -> TensorType:
+def aten_ones(size: INT64, dtype: int = -1) -> TensorType:
     # ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 
-    raise NotImplementedError()
+    one = op.Constant(value_float=1)
+    if dtype != -1:
+        one = op.Cast(one, to=dtype)  # type: ignore[arg-type]
+    return op.Expand(one, size)  # type: ignore[arg-type]
 
 
 def aten_ones_like(self, dtype: int = -1):
```
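The Constant → Cast → Expand chain in `aten_ones` builds a scalar one, optionally converts its element type, and broadcasts it to the requested shape (ONNX Expand follows numpy broadcasting rules). A rough eager-mode analogue, as an illustrative sketch rather than code from the commit:

```python
import numpy as np

# Eager analogue of the new aten_ones graph:
#   op.Constant(value_float=1) -> scalar one
#   op.Cast(one, to=dtype)     -> optional dtype conversion
#   op.Expand(one, size)       -> broadcast to the requested shape
def ones_reference(size, dtype=None):
    one = np.array(1.0, dtype=np.float32)
    if dtype is not None:
        one = one.astype(dtype)
    return np.broadcast_to(one, size)

assert ones_reference([2, 3], np.int64).tolist() == [[1, 1, 1], [1, 1, 1]]
```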
```diff
@@ -4461,7 +4459,13 @@ def aten_symeig(
 def aten_t(self: TensorType) -> TensorType:
     # t(Tensor(a) self) -> Tensor(a)
 
-    raise NotImplementedError()
+    # TODO(justinchuby): Make rank a function
+    rank = op.Size(op.Shape(self))  # type: ignore[arg-type]
+    if rank == 0 or rank == 1:  # pylint: disable=consider-using-in
+        result = self
+    else:
+        result = op.Transpose(self, perm=[1, 0])  # type: ignore[arg-type]
+    return result
 
 
 def aten_t_copy(self: TensorType) -> TensorType:
```
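The rank check mirrors `torch.t`, which returns 0-d and 1-d tensors as-is and only swaps axes for 2-d inputs. A quick illustration of the semantics being targeted:

```python
import torch

# torch.t semantics that aten_t reproduces:
assert torch.equal(torch.t(torch.tensor(5)), torch.tensor(5))            # rank 0
assert torch.equal(torch.t(torch.tensor([1, 2])), torch.tensor([1, 2]))  # rank 1
assert torch.equal(
    torch.t(torch.tensor([[1, 2], [3, 4]])),
    torch.tensor([[1, 3], [2, 4]]),  # rank 2: axes swapped
)
```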
```diff
@@ -4606,6 +4610,13 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
     raise NotImplementedError()
 
 
+def aten_transpose(self, dim0: int, dim1: int):
+    # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+
+    # FIXME(justinchuby): onnxscript raises Unsupported expression type
+    return op.Transpose(self, [dim0, dim1])
+
+
 def aten_triangular_solve(
     self: TensorType,
     A: TensorType,
```
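A likely source of the FIXME: ONNX Transpose takes its permutation as the static `perm` attribute covering every dimension, not as a second positional input, so a correct lowering needs the input rank to build a full permutation. A hypothetical sketch of that fix (function name invented; assumes the rank is known at graph-build time and the same `op` alias as core.py):

```python
# Hypothetical sketch: build a full permutation once the rank is known,
# then pass it as the static `perm` attribute that ONNX Transpose expects.
def aten_transpose_fixed(self, dim0: int, dim1: int, rank: int):
    perm = list(range(rank))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return op.Transpose(self, perm=perm)
```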

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -407,12 +407,21 @@ def aten_leaky_relu_backward(
     raise NotImplementedError()
 
 
-def aten_linear(
-    input: TensorType, weight: TensorType, bias: Optional[TensorType] = None
-) -> TensorType:
+def aten_linear(input, weight, bias=None) -> TensorType:
     # linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
 
-    raise NotImplementedError()
+    # FIXME(justinchuby): Enable the test
+    # INVALID_GRAPH : This is an invalid model.
+    # In Node, ("", OptionalHasElement, "", -1) : () -> ("output0",) ,
+    # Error Node () has input size 0 not in range [min=1, max=1]
+
+    # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
+    # Optimizers may consider this path and replace it with Gemm
+    result = op.MatMul(input, weight)
+    if op.OptionalHasElement(bias):
+        bias = op.OptionalGetElement(bias)
+        result = op.Add(result, bias)  # type: ignore[arg-type]
+    return result
 
 
 def aten_log_sigmoid(self: TensorType) -> TensorType:
```
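The quoted ORT error points at the `OptionalHasElement` node receiving no input when `bias` is None. One hypothetical workaround, sketched here rather than taken from the commit (function name invented; assumes the same `op` alias as nn.py), is to resolve the optional in Python while the function is traced:

```python
# Hypothetical sketch: decide bias handling with a plain Python check at
# trace time, so no Optional* nodes appear in the exported graph at all.
def aten_linear_traced_bias(input, weight, bias=None):
    result = op.MatMul(input, weight)
    if bias is not None:  # evaluated during tracing, not in the graph
        result = op.Add(result, bias)
    return result
```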

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 38 additions & 1 deletion
```diff
@@ -162,35 +162,56 @@ def wrapped(fn):
 # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
 OPINFO_FUNCTION_MAPPING: dict[str, Callable[..., Any]] = {
     "add": core_ops.aten_add,
-    # "clamp": core_ops.aten_clamp,  # TODO(justinchuby): Enable
     "clamp_max": core_ops.aten_clamp_max_tensor,
     "clamp_min": core_ops.aten_clamp_min_tensor,
+    "clamp": core_ops.aten_clamp,
     "gt": core_ops.aten_gt,
     "lt": core_ops.aten_lt,
+    "matmul": core_ops.aten_matmul,
+    "mm": core_ops.aten_mm,
     "mul": core_ops.aten_mul,
     "nn.functional.elu": nn_ops.aten_elu,
+    "nn.functional.linear": nn_ops.aten_linear,
     "nn.functional.relu6": nn_ops.aten_relu6,
     "nn.functional.selu": core_ops.aten_selu,
     "ones_like": core_ops.aten_ones_like,
+    "ones": core_ops.aten_ones,
     "repeat": core_ops.aten_repeat,
     "round": core_ops.aten_round,
     "sub": core_ops.aten_sub,
+    "t": core_ops.aten_t,
+    # "transpose": core_ops.aten_transpose,  # TODO(justinchuby): Enable when onnxscript errors are fixed
 }
 
 TESTED_OPS = frozenset(OPINFO_FUNCTION_MAPPING)
 
 EXPECTED_SKIPS_OR_FAILS = (
     xfail("add", dtypes=BOOL_TYPES, reason="Add is not defined on bool tensors"),
+    skip("clamp", reason="Enable when onnxscript errors are fixed"),
     xfail("clamp_max", dtypes=BOOL_TYPES, reason="Min is not defined on bool tensors"),
     xfail("clamp_min", dtypes=BOOL_TYPES, reason="Max is not defined on bool tensors"),
     xfail("gt", dtypes=BOOL_TYPES, reason="Greater is not defined on bool tensors"),
     xfail("lt", dtypes=BOOL_TYPES, reason="Less is not defined on bool tensors"),
+    xfail(
+        "matmul",
+        dtypes=[torch.uint8, torch.int8, torch.int16],
+        reason="MatMul is not defined on int16/int8/uint8 tensors",
+    ),
+    xfail(
+        "mm",
+        dtypes=[torch.uint8, torch.int8, torch.int16],
+        reason="MatMul is not defined on int16/int8/uint8 tensors",
+    ),
     xfail("mul", dtypes=BOOL_TYPES, reason="Mul is not defined on bool tensors"),
     xfail(
         "nn.functional.elu",
         dtypes=dtypes_except(torch.float16, torch.float32),
         reason="ONNX Runtime doesn't support float64 for Elu",
     ),
+    xfail(
+        "nn.functional.linear",
+        reason="ONNX Runtime thinks the graph is invalid",
+    ),
     xfail(
         "nn.functional.relu6",
         dtypes=dtypes_except(torch.float16, torch.float32),
```
```diff
@@ -213,6 +234,7 @@ def wrapped(fn):
         "round", variant_name="decimals_neg_3", reason="The ATen op does not support decimals"
     ),
     xfail("sub", dtypes=BOOL_TYPES, reason="Sub is not defined on bool tensors"),
+    xfail("transpose", reason="Enable when onnxscript errors are fixed"),
 )
 
 
```
```diff
@@ -240,6 +262,10 @@ def wrapped(fn):
 
 OPS_DB = copy.deepcopy(common_methods_invocations.op_db)
 
+ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
+# Assert all ops in OPINFO_FUNCTION_MAPPING are in the OPS_DB
+assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"
+
 
 TORCH_TYPE_TO_ONNX = {
     torch.bool: onnx.TensorProto.BOOL,
```
```diff
@@ -369,10 +395,21 @@ def test_output_match(self, device: str, dtype: torch.dtype, op):
             )
             # pylint: enable=c-extension-no-member
 
+            if dtype == torch.float32:
+                # Relax atol and rtol for float32 based on empirical results
+                # The current most relaxed values are for aten::matmul
+                rtol = 3.7e-6
+                atol = 1.8e-5
+            else:
+                rtol = None
+                atol = None
+
             # Use torch testing to ensure dtypes and shapes match
             torch.testing.assert_close(
                 torch.tensor(function_output),
                 output_torch,
+                rtol=rtol,
+                atol=atol,
             )
```
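For context, `torch.testing.assert_close` falls back to dtype-based default tolerances when both `rtol` and `atol` are None, so the non-float32 branch keeps the old behavior. A small example of how the relaxed float32 bounds behave (illustrative values only):

```python
import torch

# allowed error = atol + rtol * |expected|; with the relaxed float32
# values this tolerates the small drift observed for aten::matmul.
expected = torch.tensor([1.0, 2.0])
actual = expected + 1e-6
torch.testing.assert_close(actual, expected, rtol=3.7e-6, atol=1.8e-5)  # passes

# Passing rtol=None and atol=None keeps the dtype-based defaults.
torch.testing.assert_close(actual, expected, rtol=None, atol=None)
```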
