Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
55f8dd9
fix: annotate script()
justinchuby Dec 9, 2022
8a3a587
feat(atenlib): clamp, lt, gt
justinchuby Dec 9, 2022
f821b6a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 9, 2022
aecc148
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
6555a55
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
f8385b0
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
468f86f
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
47b8380
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
00f1760
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
060f9db
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
497cb16
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 10, 2022
9bb4038
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
d24110a
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
cbfb867
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
875f235
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
27008e1
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
3a8737d
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
c5871c8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
012905c
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 12, 2022
49be5ec
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
3a9c5f6
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
d4f09e8
Update base for Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
ee3143e
Update on "feat(atenlib): implement aten functions 1/n"
justinchuby Dec 13, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxscript/backend/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _read_proto_from_file(full):
seq = onnx.SequenceProto()
try:
seq.ParseFromString(serialized)
loaded = to_list(seq)
loaded = to_list(seq) # type: ignore[assignment]
except Exception: # pylint: disable=W0703
try:
loaded = onnx.load_model_from_string(serialized)
Expand Down
5 changes: 4 additions & 1 deletion onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,10 @@ def _python_make_node_graph(self, graph, opsets, indent=0, output_names=None):
if hasattr(graph, "initializer"):
for init in graph.initializer:
node = make_node(
"Constant", [], [self._rename_variable(init.name)], value=init
"Constant",
[],
[self._rename_variable(init.name)], # type: ignore[list-item]
value=init,
)
code.append(self._python_make_node(node, opsets, indent=indent))
if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def format_arg_name(arg: torchgen.model.Argument) -> str:
"""Returns the python compatible name of the given argument."""
if arg.name == "from":
return f"{arg.name}_"
return arg.name # type: ignore[no-any-return]
return arg.name


def create_signature(func: torchgen.model.NativeFunction) -> cg.FunctionDef:
Expand Down
68 changes: 54 additions & 14 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from typing import Any, Optional, Sequence

import onnx
Comment thread Fixed
Comment thread Fixed

from onnxscript import BOOL, INT64
from onnxscript.onnx_opset import opset18 as op
from onnxscript.onnx_types import TensorType
Expand Down Expand Up @@ -747,16 +749,31 @@ def aten_clamp(
raise NotImplementedError()


def aten_clamp_max(self: TensorType, max: float) -> TensorType:
def aten_clamp_max_scalar(self, max_):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to validate the shape of max_ to know it is a scalar or tensor instead of having such information in the function name?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't thought of a good way. Any suggestions?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, for validation, we can specify it in the signature / via some data structure. We can discuss today.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They still need to be two functions though.

# clamp_max(Tensor self, Scalar max) -> Tensor

raise NotImplementedError()
max_ = op.CastLike(max_, self)
return op.Clip(self, None, max_)


def aten_clamp_max_tensor(self, max_):
    """Clamp all elements in `self` to be at most `max_` (tensor overload).

    Implemented as an element-wise minimum with `max_`.
    """
    # clamp_max.Tensor(Tensor self, Tensor max) -> Tensor
    # NOTE(review): the original comment said "Scalar max", but this is the
    # tensor-argument variant (the scalar one is aten_clamp_max_scalar).

    return op.Min(self, max_)


def aten_clamp_min(self: TensorType, min: float) -> TensorType:
def aten_clamp_min_scalar(self, min_):
    """Clamp all elements in `self` to be at least `min_` (scalar overload)."""
    # clamp_min(Tensor self, Scalar min) -> Tensor
    # NOTE: min_ is a rank 0 tensor.
    # TODO(justinchuby): Specify the type constraints.
    lower_bound = op.CastLike(min_, self)
    return op.Clip(self, lower_bound, None)

raise NotImplementedError()

def aten_clamp_min_tensor(self, min_):
    """Clamp all elements in `self` to be at least `min_` (tensor overload).

    Implemented as an element-wise maximum with `min_`.
    """
    # clamp_min(Tensor self, Tensor min) -> Tensor
    # TODO(justinchuby): Specify the type constraints.
    return op.Max(self, min_)


def aten_clip(
Expand Down Expand Up @@ -1958,10 +1975,12 @@ def aten_gru_cell(
raise NotImplementedError()


def aten_gt(self: TensorType, other: TensorType) -> TensorType:
def aten_gt(self, other):
    """Element-wise `self > other`, returning a boolean tensor."""
    # gt.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    result = op.Greater(self, other)
    return result


def aten_hamming_window(window_length: int) -> TensorType:
Expand Down Expand Up @@ -2572,10 +2591,12 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


def aten_lt(self: TensorType, other: TensorType) -> TensorType:
def aten_lt(self, other):
    """Element-wise `self < other`, returning a boolean tensor."""
    # lt.Tensor(Tensor self, Tensor other) -> Tensor

    # TODO(justinchuby): Input spec: non bool tensor
    # Boolean inputs can be pre-casted by policy
    result = op.Less(self, other)
    return result


def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType:
Expand Down Expand Up @@ -3440,10 +3461,20 @@ def aten_ones(size: INT64) -> TensorType:
raise NotImplementedError()


def aten_ones_like(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
def aten_ones_like(self, dtype: int = -1):
    """Return a tensor of ones with the same shape as `self`.

    Note: `dtype` is an ONNX data-type enum value. Users should convert the
    torch dtype to an ONNX dtype before calling this function; the default
    of -1 means "use the same dtype as `self`".
    """
    # ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

    shape = op.Shape(self)
    if dtype != -1:
        one = op.Cast(1, to=dtype)  # type: ignore[arg-type]
    else:
        one = op.CastLike(1, self)  # type: ignore[arg-type]
    return op.Expand(one, shape)


def aten_or(self: TensorType, other: TensorType) -> TensorType:
Expand Down Expand Up @@ -3916,10 +3947,19 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
raise NotImplementedError()


def aten_repeat(self: TensorType, repeats: INT64) -> TensorType:
def aten_repeat(self, repeats: INT64):
# repeat(Tensor self, SymInt[] repeats) -> Tensor

raise NotImplementedError()
# FIXME(justinchuby): When repeats.shape == [0]
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an example where we need shape dependent logic.


# TODO(justinchuby): Make ones_like a function when onnxscript supports it
# shape = ones_like(repeats) := {
one = op.Constant(value_int=1)
repeats_shape = op.Shape(repeats)
shape = op.Expand(one, repeats_shape)
# }
self_expanded = op.Expand(self, shape) # type: ignore[arg-type]
return op.Tile(self_expanded, repeats)


def aten_repeat_interleave(
Expand Down Expand Up @@ -4012,10 +4052,10 @@ def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> Te
raise NotImplementedError()


def aten_round(self: TensorType) -> TensorType:
def aten_round(self):
    """Round each element of `self` to the nearest integer.

    Delegates to the ONNX Round operator (ties round to even per its spec).
    """
    # round(Tensor self) -> Tensor
    rounded = op.Round(self)
    return rounded


def aten_row_indices(self: TensorType) -> TensorType:
Expand Down
14 changes: 8 additions & 6 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import io
import logging
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Optional, Sequence

import onnx
from onnx import ValueInfoProto, helper
Expand Down Expand Up @@ -185,7 +185,7 @@ def __init__(self, name: str, domain: str = "") -> None:
self.stmts: list[IRStmt] = []
self.attrs: list[str] = [] # attribute parameters
self.attr_protos: list[
onnx.AttributeProto
IRAttributeValue
] = [] # attribute parameters with default value
self.called_functions: dict[str, onnx.FunctionProto] = {}
self.docstring: str = ""
Expand Down Expand Up @@ -218,7 +218,7 @@ def append_input(self, name: IRVar) -> None:
def append_output(self, name: IRVar) -> None:
self.outputs.append(name)

def add_attr_parameter(self, attr: Union[str, IRAttributeValue]) -> None:
def add_attr_parameter(self, attr: str | IRAttributeValue) -> None:
if isinstance(attr, IRAttributeValue):
self.attr_protos.append(attr)
else:
Expand Down Expand Up @@ -324,7 +324,7 @@ def to_proto(f):

def to_graph_and_functions(
self, use_default_type: bool = True
) -> Tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]:
) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]:
"""Converts this instance into a `onnx.GraphProto` and a map from
function-name to `onnx.FunctionProto`.

Expand Down Expand Up @@ -360,7 +360,7 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto:
graph, _ = self.to_graph_and_functions(use_default_type=use_default_type)
return graph

def get_opset_import(self) -> Dict[str, int]:
def get_opset_import(self) -> dict[str, int]:
func_opset_imports = {}
for s in self.stmts:
if s.callee.opset.domain not in func_opset_imports:
Expand Down Expand Up @@ -472,5 +472,7 @@ def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttribut
a = onnx.AttributeProto()
a.name = attrname
a.ref_attr_name = refname
a.type = ta.pytype_to_attrtype(pytype)
type_ = ta.pytype_to_attrtype(pytype)
assert type_ is not None
a.type = type_
return IRAttributeValue(a)
7 changes: 6 additions & 1 deletion onnxscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
import sys
import textwrap
from typing import Any, Callable, Optional

import onnx.helper

Expand Down Expand Up @@ -52,7 +53,11 @@ def script_check(f: ast.FunctionDef, opset, global_names, source, default_opset=
return convert.top_level_stmt(f)


def script(opset=None, default_opset=None, **kwargs):
def script(
opset: Optional[values.Opset] = None,
default_opset: Optional[values.Opset] = None,
**kwargs: Any,
) -> Callable[[Callable[..., Any]], onnxscript.OnnxFunction]:
"""Main decorator. Declares a function as an onnx function.

Args:
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __float__(self) -> float:
return float(self.value)

def __index__(self) -> int:
return self.value.__index__() # type: ignore[no-any-return]
return self.value.__index__()

def __getitem__(self, index):
op = self._opset
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/test/common/onnx_script_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ def setUpClass(cls):
try:
# experimental version
# pylint: disable=no-value-for-parameter
cls.all_test_cases = node_test.collect_testcases() # type: ignore[attr-defined]
cls.all_test_cases = node_test.collect_testcases() # type: ignore[attr-defined,call-arg]
# pylint: enable=no-value-for-parameter
except TypeError:
# official version
cls.all_test_cases = node_test.collect_testcases(None) # type: ignore[attr-defined]
cls.all_test_cases = node_test.collect_testcases(None) # type: ignore[attr-defined,arg-type]

def _create_model_from_param(
self, param: FunctionTestParams, onnx_case_model: onnx.ModelProto
Expand Down
Loading