Skip to content

Create the OpLike protocol and refactor Op | feat(values) #692

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions onnxscript/function_libs/torch_lib/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
import inspect
import textwrap
import types
from typing import Optional
import typing
from typing import Optional, Tuple

import onnx

import onnxscript
from onnxscript import converter as ons_converter
from onnxscript._internal import version_utils

if typing.TYPE_CHECKING:
from onnxscript import irbuilder

_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14")


Expand All @@ -33,7 +37,7 @@ def _get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
return src, f_ast


class TraceOnlyFunction:
class TraceOnlyFunction(onnxscript.values.OpLike):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the differences between a TraceOnlyFunction and an OnnxFunction? Eg., would it make sense for an OnnxFunction to be an extension (or derived class) of a TraceOnlyFunction?

Copy link
Collaborator Author

@justinchuby justinchuby Apr 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A TraceOnlyFunction is simply a python function with added attributes to make it more like an onnx function. To me it is a semantically different thing that happens to implement the same protocol for convenience. But I could be wrong.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me see if a TraceOnlyFunction can be derived from Op

"""TraceOnlyFunction.

Attributes:
Expand All @@ -44,9 +48,11 @@ class TraceOnlyFunction:
def __init__(self, opset: onnxscript.values.Opset, func: types.FunctionType):
self._opset = opset
self._func = func
self._opschema: Optional[onnx.defs.OpSchema] = None
# Set the signature of the class to function's
self.__signature__ = inspect.signature(func)
# Cached computed fields
self._opschema: Optional[onnx.defs.OpSchema] = None
self._param_schemas: Optional[Tuple[onnxscript.values.ParamSchema, ...]] = None

def __call__(self, *args, **kwargs):
return self._func(*args, **kwargs)
Expand All @@ -72,13 +78,32 @@ def opset(self) -> onnxscript.values.Opset:
@property
def opschema(self) -> Optional[onnx.defs.OpSchema]:
"""Return the opschema."""

if self._opschema is not None:
return self._opschema

if not _ONNX_OP_SCHEMA_WRITABLE:
return None

# FIXME(justinchuby): outputs are empty. Need to fix.
self._opschema = onnxscript.values.op_schema_from_function_ir(
self._function_ir(), self._opset
)

return self._opschema

def param_schemas(self) -> tuple[onnxscript.values.ParamSchema, ...]:
"""Generate param_schemas for the TraceOnlyFunction."""
if self._param_schemas is None:
self._param_schemas = onnxscript.values.param_schemas_from_function_ir(
self._function_ir()
)

return self._param_schemas

def _function_ir(self) -> irbuilder.IRFunction:
"""Return the IRFunction of the function.

This IRFunction contains only the function signature.
"""
src, func_ast = _get_src_and_ast(self._func)
module = inspect.getmodule(self._func)
closure = inspect.getclosurevars(self._func)
Expand All @@ -90,9 +115,4 @@ def opschema(self) -> Optional[onnx.defs.OpSchema]:
source=src,
)

function_ir = converter.translate_function_signature(func_ast)

# FIXME(justinchuby): outputs are empty. Need to fix.
self._opschema = onnxscript.values.op_schema_from_function_ir(function_ir, self._opset)

return self._opschema
return converter.translate_function_signature(func_ast)
4 changes: 2 additions & 2 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __str__(self):

args = _format(self.args, "(", ", ", ")", _opt_var_to_str)
domain = self.callee.opset.domain
opname = self.callee.opname
opname = self.callee.name
callee = f"{domain}.{opname}" if (domain != "") else opname
return f"{lhs} = {callee} {attrs}{args}"

Expand All @@ -212,7 +212,7 @@ def debug_print(self):

def to_node_proto(self, node_name: str) -> onnx.NodeProto:
n = helper.make_node(
self.callee.opname,
self.callee.name,
[_opt_var_to_str(x) for x in self.args],
[str(x) for x in self.result],
domain=self.callee.opset.domain,
Expand Down
196 changes: 122 additions & 74 deletions onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import inspect
import logging
import types
import typing
from enum import IntFlag
from typing import _GenericAlias # type: ignore[attr-defined]
from typing import Any, Optional, Sequence
from typing import Any, Optional, Protocol, Sequence

import onnx
import onnx.defs
Expand Down Expand Up @@ -162,21 +163,105 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any:
return onnx.helper.get_attribute_value(attr_proto)


class Op:
def param_schemas_from_op_schema(
op_schema: onnx.defs.OpSchema,
) -> tuple[ParamSchema, ...]:
"""Get the parameter schemas from an ONNX OpSchema."""
schemas = []
for input_ in op_schema.inputs:
param_schema = ParamSchema(
name=input_.name,
is_input=True,
required=(input_.option != onnx.defs.OpSchema.FormalParameterOption.Optional),
is_variadic_input=(
input_.option == onnx.defs.OpSchema.FormalParameterOption.Variadic
),
)
schemas.append(param_schema)
for attr_name, attribute in op_schema.attributes.items():
default_attr_proto = attribute.default_value
param_schema = ParamSchema(
name=attr_name,
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type],
default=_get_attribute_value(default_attr_proto),
is_input=False,
required=attribute.required,
)
schemas.append(param_schema)

return tuple(schemas)


def param_schemas_from_function_ir(
function_ir: irbuilder.IRFunction,
) -> tuple[ParamSchema, ...]:
"""Get the parameter schemas from a FunctionIR."""
# The first len(func_ir.inputs) arguments are onnx inputs
# The rest is onnx attributes

schemas = []
for arg in function_ir.inputs:
if isinstance(arg.typeinfo, onnx.TypeProto.Optional):
required = False
else:
required = True
schemas.append(
ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required)
)

for attr_parameter in function_ir.attrs:
schemas.append(
ParamSchema(
name=attr_parameter.name,
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get(
onnx.defs.OpSchema.AttrType(attr_parameter.type) # type: ignore[call-arg]
),
default=_EmptyDefault
if attr_parameter.default_value is None
else attr_parameter.default_value,
is_input=False,
required=not attr_parameter.has_default,
)
)

return tuple(schemas)


@typing.runtime_checkable
class OpLike(Protocol):
"""A protocol for objects that have an ONNX OpSchema."""

@property
def name(self) -> str:
...

@property
def opset(self) -> Opset:
...

@property
def opschema(self) -> Optional[onnx.defs.OpSchema]:
...

def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]:
...


class Op(OpLike):
"""Represents an ONNX op instance (for example, the MatMul op from ONNX opset version 13).
It belongs to a particular Opset and has a name.

Attributes:
opset: The Opset that this op belongs to.
opname: The name of the op.
name: The name of the op.
opschema: The ONNX OpSchema for the op.
"""

def __init__(
self, opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None
self, opset: Opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None
) -> None:
self.opset = opset
self.opname = opname
self._opset = opset
self._name = opname
self._opschema = opschema
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None

Expand All @@ -187,10 +272,18 @@ def __call__(self, *args, **kwargs):
schema = self.opschema
if schema is None:
raise RuntimeError(
f"Op '{self.opname}' does not have an OpSchema and cannot be evaluated."
f"Op '{self.name}' does not have an OpSchema and cannot be evaluated."
)
return evaluator.default().eval(schema, args, kwargs)

@property
def name(self) -> str:
return self._name

@property
def opset(self) -> Opset:
return self._opset

@property
def opschema(self) -> Optional[onnx.defs.OpSchema]:
return self._opschema
Expand All @@ -207,30 +300,9 @@ def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]:
op_schema = self.opschema
if op_schema is None:
return None
schemas = []
for input_ in op_schema.inputs:
param_schema = ParamSchema(
name=input_.name,
is_input=True,
required=(input_.option != onnx.defs.OpSchema.FormalParameterOption.Optional),
is_variadic_input=(
input_.option == onnx.defs.OpSchema.FormalParameterOption.Variadic
),
)
schemas.append(param_schema)
for attr_name, attribute in op_schema.attributes.items():
default_attr_proto = attribute.default_value
param_schema = ParamSchema(
name=attr_name,
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type],
default=_get_attribute_value(default_attr_proto),
is_input=False,
required=attribute.required,
)
schemas.append(param_schema)

self._param_schemas = tuple(schemas)
return self._param_schemas # type: ignore[return-value]
self._param_schemas = param_schemas_from_op_schema(op_schema)
return self._param_schemas


@dataclasses.dataclass(repr=False, eq=False)
Expand Down Expand Up @@ -345,13 +417,14 @@ def op_schema_from_function_ir(
class OnnxFunction(Op):
"""Represents an ONNX op for which a function-body has been defined in onnxscript.

Args:
opset: opset the function belongs to
pyfun: python function
irfun: python code parsed by class
:class:`onnxscript.converter.Converter`
source: source code used to generate the function
kwargs: additional properties used to construct a ModelProto
Attributes:
opset: Opset the function belongs to.
name: Name of the function.
function: Python function.
function_ir: Python code parsed as an :class:`irbuilder.IRFunction`.
source: Source code used to generate the function.
kwargs: Additional properties used to construct a ModelProto.
opschema: Generated ONNX OpSchema for this op.
"""

def __init__(
Expand All @@ -362,6 +435,16 @@ def __init__(
source: str,
kwargs: dict[str, Any],
):
"""Constructs an OnnxFunction.

Args:
opset: opset the function belongs to
pyfun: python function
irfun: python code parsed by class
:class:`onnxscript.converter.Converter`
source: source code used to generate the function
kwargs: additional properties used to construct a ModelProto
"""
opset = opset or Opset(irfun.domain, 1)
super().__init__(opset, irfun.name)
self.function = pyfun
Expand All @@ -373,11 +456,6 @@ def __init__(
# Set the signature of the class to function's
self.__signature__ = inspect.signature(pyfun)

@property
def name(self):
"""Returns the function name."""
return self.opname

@property
def opschema(self) -> Optional[onnx.defs.OpSchema]:
"""Construct an OpSchema from function_ir."""
Expand Down Expand Up @@ -423,38 +501,8 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
# NOTE: We generate the parameter schemas from the function_ir instead
# of relying on the auto generated OpSchema because we need to preserve the keyword
# argument order from the Python function definition, which is lost in OpSchema.
function_ir = self.function_ir
# The first len(func_ir.inputs) arguments are onnx inputs
inputs = function_ir.inputs
# The rest is onnx attributes

schemas = []
for arg in inputs:
if isinstance(arg.typeinfo, onnx.TypeProto.Optional):
required = False
else:
required = True
schemas.append(
ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required)
)

for attr_parameter in function_ir.attrs:
schemas.append(
ParamSchema(
name=attr_parameter.name,
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get(
onnx.defs.OpSchema.AttrType(attr_parameter.type) # type: ignore[call-arg]
),
default=_EmptyDefault
if attr_parameter.default_value is None
else attr_parameter.default_value,
is_input=False,
required=not attr_parameter.has_default,
)
)

self._param_schemas = tuple(schemas)
return self._param_schemas # type: ignore[return-value]
self._param_schemas = param_schemas_from_function_ir(self.function_ir)
return self._param_schemas

def to_function_proto(self):
"""Converts the function into :class:`onnx.FunctionProto`."""
Expand Down