Skip to content

Commit 4c3dea1

Browse files
committed
Create the OpLike protocol and refactor Op | feat(values)
- Removes `is_single_op` because it is unused. Signed-off-by: Justin Chu <justinchumicrosoft.com> ghstack-source-id: 2df23c2 Pull Request resolved: #692
1 parent 25ea7dd commit 4c3dea1

File tree

3 files changed

+153
-89
lines changed

3 files changed

+153
-89
lines changed

onnxscript/function_libs/torch_lib/tracing.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
import inspect
66
import textwrap
77
import types
8-
from typing import Optional
8+
import typing
9+
from typing import Optional, Tuple
910

1011
import onnx
1112

1213
import onnxscript
1314
from onnxscript import converter as ons_converter
1415
from onnxscript._internal import version_utils
1516

17+
if typing.TYPE_CHECKING:
18+
from onnxscript import irbuilder
19+
1620
_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14")
1721

1822

@@ -33,7 +37,7 @@ def _get_src_and_ast(f: types.FunctionType) -> tuple[str, ast.FunctionDef]:
3337
return src, f_ast
3438

3539

36-
class TraceOnlyFunction:
40+
class TraceOnlyFunction(onnxscript.values.OpLike):
3741
"""TraceOnlyFunction.
3842
3943
Attributes:
@@ -44,9 +48,11 @@ class TraceOnlyFunction:
4448
def __init__(self, opset: onnxscript.values.Opset, func: types.FunctionType):
4549
self._opset = opset
4650
self._func = func
47-
self._opschema: Optional[onnx.defs.OpSchema] = None
4851
# Set the signature of the class to function's
4952
self.__signature__ = inspect.signature(func)
53+
# Cached computed fields
54+
self._opschema: Optional[onnx.defs.OpSchema] = None
55+
self._param_schemas: Optional[Tuple[onnxscript.values.ParamSchema, ...]] = None
5056

5157
def __call__(self, *args, **kwargs):
5258
return self._func(*args, **kwargs)
@@ -72,13 +78,32 @@ def opset(self) -> onnxscript.values.Opset:
7278
@property
7379
def opschema(self) -> Optional[onnx.defs.OpSchema]:
7480
"""Return the opschema."""
75-
7681
if self._opschema is not None:
7782
return self._opschema
78-
7983
if not _ONNX_OP_SCHEMA_WRITABLE:
8084
return None
8185

86+
# FIXME(justinchuby): outputs are empty. Need to fix.
87+
self._opschema = onnxscript.values.op_schema_from_function_ir(
88+
self._function_ir(), self._opset
89+
)
90+
91+
return self._opschema
92+
93+
def param_schemas(self) -> tuple[onnxscript.values.ParamSchema, ...]:
94+
"""Generate param_schemas for the TraceOnlyFunction."""
95+
if self._param_schemas is None:
96+
self._param_schemas = onnxscript.values.param_schemas_from_function_ir(
97+
self._function_ir()
98+
)
99+
100+
return self._param_schemas
101+
102+
def _function_ir(self) -> irbuilder.IRFunction:
103+
"""Return the IRFunction of the function.
104+
105+
This IRFunction contains only the function signature.
106+
"""
82107
src, func_ast = _get_src_and_ast(self._func)
83108
module = inspect.getmodule(self._func)
84109
closure = inspect.getclosurevars(self._func)
@@ -90,9 +115,4 @@ def opschema(self) -> Optional[onnx.defs.OpSchema]:
90115
source=src,
91116
)
92117

93-
function_ir = converter.translate_function_signature(func_ast)
94-
95-
# FIXME(justinchuby): outputs are empty. Need to fix.
96-
self._opschema = onnxscript.values.op_schema_from_function_ir(function_ir, self._opset)
97-
98-
return self._opschema
118+
return converter.translate_function_signature(func_ast)

onnxscript/irbuilder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def __str__(self):
202202

203203
args = _format(self.args, "(", ", ", ")", _opt_var_to_str)
204204
domain = self.callee.opset.domain
205-
opname = self.callee.opname
205+
opname = self.callee.name
206206
callee = f"{domain}.{opname}" if (domain != "") else opname
207207
return f"{lhs} = {callee} {attrs}{args}"
208208

@@ -212,7 +212,7 @@ def debug_print(self):
212212

213213
def to_node_proto(self, node_name: str) -> onnx.NodeProto:
214214
n = helper.make_node(
215-
self.callee.opname,
215+
self.callee.name,
216216
[_opt_var_to_str(x) for x in self.args],
217217
[str(x) for x in self.result],
218218
domain=self.callee.opset.domain,

onnxscript/values.py

Lines changed: 120 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -163,21 +163,105 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any:
163163
return onnx.helper.get_attribute_value(attr_proto)
164164

165165

166-
class Op:
166+
def param_schemas_from_op_schema(
167+
op_schema: onnx.defs.OpSchema,
168+
) -> tuple[ParamSchema, ...]:
169+
"""Get the parameter schemas from an ONNX OpSchema."""
170+
schemas = []
171+
for input_ in op_schema.inputs:
172+
param_schema = ParamSchema(
173+
name=input_.name,
174+
is_input=True,
175+
required=(input_.option != onnx.defs.OpSchema.FormalParameterOption.Optional),
176+
is_variadic_input=(
177+
input_.option == onnx.defs.OpSchema.FormalParameterOption.Variadic
178+
),
179+
)
180+
schemas.append(param_schema)
181+
for attr_name, attribute in op_schema.attributes.items():
182+
default_attr_proto = attribute.default_value
183+
param_schema = ParamSchema(
184+
name=attr_name,
185+
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type],
186+
default=_get_attribute_value(default_attr_proto),
187+
is_input=False,
188+
required=attribute.required,
189+
)
190+
schemas.append(param_schema)
191+
192+
return tuple(schemas)
193+
194+
195+
def param_schemas_from_function_ir(
196+
function_ir: irbuilder.IRFunction,
197+
) -> tuple[ParamSchema, ...]:
198+
"""Get the parameter schemas from a FunctionIR."""
199+
# The first len(func_ir.inputs) arguments are onnx inputs
200+
# The rest is onnx attributes
201+
202+
schemas = []
203+
for arg in function_ir.inputs:
204+
if isinstance(arg.typeinfo, onnx.TypeProto.Optional):
205+
required = False
206+
else:
207+
required = True
208+
schemas.append(
209+
ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required)
210+
)
211+
212+
for attr_parameter in function_ir.attrs:
213+
schemas.append(
214+
ParamSchema(
215+
name=attr_parameter.name,
216+
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get(
217+
onnx.defs.OpSchema.AttrType(attr_parameter.type) # type: ignore[call-arg]
218+
),
219+
default=_EmptyDefault
220+
if attr_parameter.default_value is None
221+
else attr_parameter.default_value,
222+
is_input=False,
223+
required=not attr_parameter.has_default,
224+
)
225+
)
226+
227+
return tuple(schemas)
228+
229+
230+
@typing.runtime_checkable
231+
class OpLike(Protocol):
232+
"""A protocol for objects that have an ONNX OpSchema."""
233+
234+
@property
235+
def name(self) -> str:
236+
...
237+
238+
@property
239+
def opset(self) -> Opset:
240+
...
241+
242+
@property
243+
def opschema(self) -> onnx.defs.OpSchema:
244+
...
245+
246+
def param_schemas(self) -> tuple[ParamSchema, ...]:
247+
...
248+
249+
250+
class Op(OpLike):
167251
"""Represents an ONNX op instance (for example, the MatMul op from ONNX opset version 13).
168252
It belongs to a particular Opset and has a name.
169253
170254
Attributes:
171255
opset: The Opset that this op belongs to.
172-
opname: The name of the op.
256+
name: The name of the op.
173257
opschema: The ONNX OpSchema for the op.
174258
"""
175259

176260
def __init__(
177-
self, opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None
261+
self, opset: Opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None
178262
) -> None:
179-
self.opset = opset
180-
self.opname = opname
263+
self._opset = opset
264+
self._name = opname
181265
self._opschema = opschema
182266
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
183267

@@ -188,12 +272,17 @@ def __call__(self, *args, **kwargs):
188272
schema = self.get_schema()
189273
if schema is None:
190274
raise RuntimeError(
191-
f"Op '{self.opname}' does not have an OpSchema and cannot be evaluated."
275+
f"Op '{self.name}' does not have an OpSchema and cannot be evaluated."
192276
)
193277
return evaluator.default().eval(schema, args, kwargs)
194278

195-
def is_single_op(self) -> bool:
196-
return isinstance(self.opname, str)
279+
@property
280+
def name(self) -> str:
281+
return self._name
282+
283+
@property
284+
def opset(self) -> Opset:
285+
return self._opset
197286

198287
@property
199288
def opschema(self) -> Optional[onnx.defs.OpSchema]:
@@ -203,7 +292,7 @@ def get_schema(self) -> Optional[onnx.defs.OpSchema]:
203292
"""Returns the ONNX OpSchema for this op."""
204293
if self.opschema is not None:
205294
return self.opschema
206-
return self.opset[self.opname]
295+
return self.opset[self.name]
207296

208297
def has_schema(self) -> bool:
209298
"""Returns True if this op has an OpSchema."""
@@ -217,30 +306,9 @@ def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]:
217306
op_schema = self.get_schema()
218307
if op_schema is None:
219308
return None
220-
schemas = []
221-
for input_ in op_schema.inputs:
222-
param_schema = ParamSchema(
223-
name=input_.name,
224-
is_input=True,
225-
required=(input_.option != onnx.defs.OpSchema.FormalParameterOption.Optional),
226-
is_variadic_input=(
227-
input_.option == onnx.defs.OpSchema.FormalParameterOption.Variadic
228-
),
229-
)
230-
schemas.append(param_schema)
231-
for attr_name, attribute in op_schema.attributes.items():
232-
default_attr_proto = attribute.default_value
233-
param_schema = ParamSchema(
234-
name=attr_name,
235-
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE[attribute.type],
236-
default=_get_attribute_value(default_attr_proto),
237-
is_input=False,
238-
required=attribute.required,
239-
)
240-
schemas.append(param_schema)
241309

242-
self._param_schemas = tuple(schemas)
243-
return self._param_schemas # type: ignore[return-value]
310+
self._param_schemas = param_schemas_from_op_schema(op_schema)
311+
return self._param_schemas
244312

245313

246314
@dataclasses.dataclass(repr=False, eq=False)
@@ -355,13 +423,14 @@ def op_schema_from_function_ir(
355423
class OnnxFunction(Op):
356424
"""Represents an ONNX op for which a function-body has been defined in onnxscript.
357425
358-
Args:
359-
opset: opset the function belongs to
360-
pyfun: python function
361-
irfun: python code parsed by class
362-
:class:`onnxscript.converter.Converter`
363-
source: source code used to generate the function
364-
kwargs: additional properties used to construct a ModelProto
426+
Attributes:
427+
opset: Opset the function belongs to.
428+
name: Name of the function.
429+
function: Python function.
430+
function_ir: Python code parsed as an :class:`irbuilder.IRFunction`.
431+
source: Source code used to generate the function.
432+
kwargs: Additional properties used to construct a ModelProto.
433+
opschema: Generated ONNX OpSchema for this op.
365434
"""
366435

367436
def __init__(
@@ -372,6 +441,16 @@ def __init__(
372441
source: str,
373442
kwargs: dict[str, Any],
374443
):
444+
"""Constructs an OnnxFunction.
445+
446+
Args:
447+
opset: opset the function belongs to
448+
pyfun: python function
449+
irfun: python code parsed by class
450+
:class:`onnxscript.converter.Converter`
451+
source: source code used to generate the function
452+
kwargs: additional properties used to construct a ModelProto
453+
"""
375454
opset = opset or Opset(irfun.domain, 1)
376455
super().__init__(opset, irfun.name)
377456
self.function = pyfun
@@ -383,11 +462,6 @@ def __init__(
383462
# Set the signature of the class to function's
384463
self.__signature__ = inspect.signature(pyfun)
385464

386-
@property
387-
def name(self):
388-
"""Returns the function name."""
389-
return self.opname
390-
391465
@property
392466
def opschema(self) -> Optional[onnx.defs.OpSchema]:
393467
"""Construct an OpSchema from function_ir."""
@@ -433,38 +507,8 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
433507
# NOTE: We generate the parameter schemas from the function_ir instead
434508
# of relying on the auto generated OpSchema because we need to preserve the keyword
435509
# argument order from the Python function definition, which is lost in OpSchema.
436-
function_ir = self.function_ir
437-
# The first len(func_ir.inputs) arguments are onnx inputs
438-
inputs = function_ir.inputs
439-
# The rest is onnx attributes
440-
441-
schemas = []
442-
for arg in inputs:
443-
if isinstance(arg.typeinfo, onnx.TypeProto.Optional):
444-
required = False
445-
else:
446-
required = True
447-
schemas.append(
448-
ParamSchema(name=arg.name, type=arg.typeinfo, is_input=True, required=required)
449-
)
450-
451-
for attr_parameter in function_ir.attrs:
452-
schemas.append(
453-
ParamSchema(
454-
name=attr_parameter.name,
455-
type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get(
456-
onnx.defs.OpSchema.AttrType(attr_parameter.type) # type: ignore[call-arg]
457-
),
458-
default=_EmptyDefault
459-
if attr_parameter.default_value is None
460-
else attr_parameter.default_value,
461-
is_input=False,
462-
required=not attr_parameter.has_default,
463-
)
464-
)
465-
466-
self._param_schemas = tuple(schemas)
467-
return self._param_schemas # type: ignore[return-value]
510+
self._param_schemas = param_schemas_from_function_ir(self.function_ir)
511+
return self._param_schemas
468512

469513
def to_function_proto(self):
470514
"""Converts the function into :class:`onnx.FunctionProto`."""

0 commit comments

Comments
 (0)