Skip to content

Commit 92d5d71

Browse files
authored
Retire get_schema in Op | chore!(api) (#698)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #698 Remove `get_schema` in Op because (1) if an op does not have schema, it will not be in its opset either. (2) If an op does have schema, we don't need to consult with the opset. This change ensures a single correct way of accessing OpSchema. Rename `opschema` to `op_schema` for naming consistency.
1 parent dd5a5cb commit 92d5d71

File tree

3 files changed

+36
-41
lines changed

3 files changed

+36
-41
lines changed

onnxscript/converter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def __init__(self, name: Optional[Union[str, List[str]]], kind: ConverterExpress
135135
def is_const(self):
136136
return self.kind == ConverterExpressionKind.CONST
137137

138-
def __str__(self):
138+
def __str__(self) -> str:
139+
assert isinstance(self.name, str), "`name` is not a string. This is likely a bug."
139140
return self.name
140141

141142

@@ -777,7 +778,7 @@ def translate_call_expr(self, node):
777778
else:
778779
args = [self.translate_opt_expr(x) for x in node.args]
779780
attrs = [self.translate_attr(x.arg, x.value) for x in node.keywords]
780-
args = autocast.static_cast_inputs(self, callee.get_schema(), args)
781+
args = autocast.static_cast_inputs(self, callee.op_schema, args)
781782

782783
# In ONNX, there is no way to explicitly specify a None value for an attribute.
783784
# Instead, the attribute must be omitted from the attribute list.
@@ -786,7 +787,7 @@ def translate_call_expr(self, node):
786787
return callee, args, attrs
787788

788789
def _cast_like_binary_expression(self, op, left, right):
789-
schema = op.get_schema()
790+
schema = op.op_schema
790791
return autocast.static_cast_inputs(self, schema, (left, right))
791792

792793
def translate_bool_op_expr(self, node: ast.BoolOp) -> ConverterExpression:

onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_script_function_passes_checker(self, _, func_with_wrangler):
115115
)
116116
def test_script_function_has_op_schema(self, _, func_with_wrangler):
117117
func, _ = _split_function_and_wrangler(func_with_wrangler)
118-
schema = func.opschema
118+
schema = func.op_schema
119119
self.assertIsNotNone(schema)
120120
self.assertEqual(schema.name, func.name)
121121

@@ -128,7 +128,7 @@ def test_script_function_has_op_schema(self, _, func_with_wrangler):
128128
)
129129
def test_trace_only_function_has_op_schema(self, _, func_with_wrangler):
130130
func, _ = _split_function_and_wrangler(func_with_wrangler)
131-
schema = func.opschema
131+
schema = func.op_schema
132132
self.assertIsNotNone(schema)
133133
self.assertEqual(schema.name, func.name)
134134

onnxscript/values.py

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def opset(self) -> Opset:
243243
...
244244

245245
@property
246-
def opschema(self) -> Optional[onnx.defs.OpSchema]:
246+
def op_schema(self) -> Optional[onnx.defs.OpSchema]:
247247
...
248248

249249
def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]:
@@ -258,31 +258,36 @@ class Op(OpLike):
258258
Attributes:
259259
opset: The Opset that this op belongs to.
260260
name: The name of the op.
261-
opschema: The ONNX OpSchema for the op.
261+
op_schema: The ONNX OpSchema for the op.
262262
"""
263263

264264
def __init__(
265-
self, opset: Opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None
265+
self, opset: Opset, opname: str, op_schema: Optional[onnx.defs.OpSchema] = None
266266
) -> None:
267267
self._opset = opset
268268
self._name = opname
269-
self._opschema = opschema
269+
self._op_schema = op_schema or opset[opname]
270270
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
271271

272+
if self._op_schema is None:
273+
logging.debug(
274+
"An OpSchema was not provided for Op '%s' and "
275+
"there is not one found in opset '%s'.",
276+
opname,
277+
opset,
278+
)
279+
272280
def __call__(self, *args, **kwargs):
273281
# FIXME(after #225): Move import to the top of the file.
274282
from onnxscript import evaluator # pylint: disable=import-outside-toplevel
275283

276-
schema = self.get_schema()
284+
schema = self.op_schema
277285
if schema is None:
278286
raise RuntimeError(
279287
f"Op '{self.name}' does not have an OpSchema and cannot be evaluated."
280288
)
281289
return evaluator.default().eval(schema, args, kwargs)
282290

283-
def is_single_op(self) -> bool:
284-
return isinstance(self.name, str)
285-
286291
@property
287292
def name(self) -> str:
288293
return self._name
@@ -292,25 +297,19 @@ def opset(self) -> Opset:
292297
return self._opset
293298

294299
@property
295-
def opschema(self) -> Optional[onnx.defs.OpSchema]:
296-
return self._opschema
297-
298-
def get_schema(self) -> Optional[onnx.defs.OpSchema]:
299-
"""Returns the ONNX OpSchema for this op."""
300-
if self.opschema is not None:
301-
return self.opschema
302-
return self.opset[self.name]
300+
def op_schema(self) -> Optional[onnx.defs.OpSchema]:
301+
return self._op_schema
303302

304303
def has_schema(self) -> bool:
305304
"""Returns True if this op has an OpSchema."""
306-
return self.get_schema() is not None
305+
return self.op_schema is not None
307306

308307
def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]:
309308
"""Returns the parameter schemas for this op, if it has one."""
310309
if self._param_schemas is not None:
311310
return self._param_schemas
312311

313-
op_schema = self.get_schema()
312+
op_schema = self.op_schema
314313
if op_schema is None:
315314
return None
316315

@@ -437,7 +436,7 @@ class OnnxFunction(Op):
437436
function_ir: Python code parsed as an :class:`irbuilder.IRFunction`.
438437
source: Source code used to generate the function.
439438
kwargs: Additional properties used to construct a ModelProto.
440-
opschema: Generated ONNX OpSchema for this op.
439+
op_schema: Generated ONNX OpSchema for this op.
441440
"""
442441

443442
def __init__(
@@ -465,20 +464,20 @@ def __init__(
465464
self.source = source
466465
self.kwargs = kwargs
467466
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
468-
self._opschema: Optional[onnx.defs.OpSchema] = None
467+
self._op_schema: Optional[onnx.defs.OpSchema] = None
469468

470469
@property
471-
def opschema(self) -> Optional[onnx.defs.OpSchema]:
470+
def op_schema(self) -> Optional[onnx.defs.OpSchema]:
472471
"""Construct an OpSchema from function_ir."""
473-
if self._opschema is not None:
474-
return self._opschema
472+
if self._op_schema is not None:
473+
return self._op_schema
475474

476475
if not _ONNX_OP_SCHEMA_WRITABLE:
477476
return None
478477

479-
self._opschema = op_schema_from_function_ir(self.function_ir, self.opset)
478+
self._op_schema = op_schema_from_function_ir(self.function_ir, self.opset)
480479

481-
return self._opschema
480+
return self._op_schema
482481

483482
def __getitem__(self, instance):
484483
"""Returns a lambda to evaluate function using given evaluator instance.
@@ -555,11 +554,6 @@ def __call__(self, *args, **kwargs):
555554
def __repr__(self):
556555
return f"{self.__class__.__name__}({self.func!r})"
557556

558-
@property
559-
def name(self) -> str:
560-
"""Return the name of the op."""
561-
return self.func.__name__
562-
563557
@property
564558
def function_ir(self) -> irbuilder.IRFunction:
565559
"""Return the function_ir.
@@ -580,19 +574,19 @@ def function_ir(self) -> irbuilder.IRFunction:
580574
return converter.translate_function_signature(func_ast)
581575

582576
@property
583-
def opschema(self) -> Optional[onnx.defs.OpSchema]:
584-
"""Return the opschema."""
577+
def op_schema(self) -> Optional[onnx.defs.OpSchema]:
578+
"""Return the OpSchema."""
585579

586-
if self._opschema is not None:
587-
return self._opschema
580+
if self._op_schema is not None:
581+
return self._op_schema
588582

589583
if not _ONNX_OP_SCHEMA_WRITABLE:
590584
return None
591585

592586
# FIXME(justinchuby): outputs are empty. Need to fix.
593-
self._opschema = op_schema_from_function_ir(self.function_ir, self._opset)
587+
self._op_schema = op_schema_from_function_ir(self.function_ir, self._opset)
594588

595-
return self._opschema
589+
return self._op_schema
596590

597591
def param_schemas(self) -> tuple[ParamSchema, ...]:
598592
"""Returns the parameter schemas of this function."""

0 commit comments

Comments
 (0)