diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 57505b0751..880d53e1de 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -135,7 +135,8 @@ def __init__(self, name: Optional[Union[str, List[str]]], kind: ConverterExpress def is_const(self): return self.kind == ConverterExpressionKind.CONST - def __str__(self): + def __str__(self) -> str: + assert isinstance(self.name, str), "`name` is not a string. This is likely a bug." return self.name @@ -777,7 +778,7 @@ def translate_call_expr(self, node): else: args = [self.translate_opt_expr(x) for x in node.args] attrs = [self.translate_attr(x.arg, x.value) for x in node.keywords] - args = autocast.static_cast_inputs(self, callee.get_schema(), args) + args = autocast.static_cast_inputs(self, callee.op_schema, args) # In ONNX, there is no way to explicitly specify a None value for an attribute. # Instead, the attribute must be omitted from the attribute list. @@ -786,7 +787,7 @@ def translate_call_expr(self, node): return callee, args, attrs def _cast_like_binary_expression(self, op, left, right): - schema = op.get_schema() + schema = op.op_schema return autocast.static_cast_inputs(self, schema, (left, right)) def translate_bool_op_expr(self, node: ast.BoolOp) -> ConverterExpression: diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py index ee10037484..b68acb4682 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py @@ -115,7 +115,7 @@ def test_script_function_passes_checker(self, _, func_with_wrangler): ) def test_script_function_has_op_schema(self, _, func_with_wrangler): func, _ = _split_function_and_wrangler(func_with_wrangler) - schema = func.opschema + schema = func.op_schema self.assertIsNotNone(schema) self.assertEqual(schema.name, func.name) @@ -128,7 +128,7 @@ def test_script_function_has_op_schema(self, _, func_with_wrangler): ) def test_trace_only_function_has_op_schema(self, _, func_with_wrangler): func, _ = _split_function_and_wrangler(func_with_wrangler) - schema = func.opschema + schema = func.op_schema self.assertIsNotNone(schema) self.assertEqual(schema.name, func.name) diff --git a/onnxscript/values.py b/onnxscript/values.py index 7d6ac204a0..fa256703b3 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -243,7 +243,7 @@ def opset(self) -> Opset: ... @property - def opschema(self) -> Optional[onnx.defs.OpSchema]: + def op_schema(self) -> Optional[onnx.defs.OpSchema]: ... def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: @@ -258,31 +258,36 @@ class Op(OpLike): Attributes: opset: The Opset that this op belongs to. name: The name of the op. - opschema: The ONNX OpSchema for the op. + op_schema: The ONNX OpSchema for the op. """ def __init__( - self, opset: Opset, opname: str, opschema: Optional[onnx.defs.OpSchema] = None + self, opset: Opset, opname: str, op_schema: Optional[onnx.defs.OpSchema] = None ) -> None: self._opset = opset self._name = opname - self._opschema = opschema + self._op_schema = op_schema or opset[opname] self._param_schemas: Optional[tuple[ParamSchema, ...]] = None + if self._op_schema is None: + logging.debug( + "An OpSchema was not provided for Op '%s' and " + "there is not one found in opset '%s'.", + opname, + opset, + ) + def __call__(self, *args, **kwargs): # FIXME(after #225): Move import to the top of the file. from onnxscript import evaluator # pylint: disable=import-outside-toplevel - schema = self.get_schema() + schema = self.op_schema if schema is None: raise RuntimeError( f"Op '{self.name}' does not have an OpSchema and cannot be evaluated." ) return evaluator.default().eval(schema, args, kwargs) - def is_single_op(self) -> bool: - return isinstance(self.name, str) - @property def name(self) -> str: return self._name @@ -292,25 +297,19 @@ def opset(self) -> Opset: return self._opset @property - def opschema(self) -> Optional[onnx.defs.OpSchema]: - return self._opschema - - def get_schema(self) -> Optional[onnx.defs.OpSchema]: - """Returns the ONNX OpSchema for this op.""" - if self.opschema is not None: - return self.opschema - return self.opset[self.name] + def op_schema(self) -> Optional[onnx.defs.OpSchema]: + return self._op_schema def has_schema(self) -> bool: """Returns True if this op has an OpSchema.""" - return self.get_schema() is not None + return self.op_schema is not None def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: """Returns the parameter schemas for this op, if it has one.""" if self._param_schemas is not None: return self._param_schemas - op_schema = self.get_schema() + op_schema = self.op_schema if op_schema is None: return None @@ -437,7 +436,7 @@ class OnnxFunction(Op): function_ir: Python code parsed as an :class:`irbuilder.IRFunction`. source: Source code used to generate the function. kwargs: Additional properties used to construct a ModelProto. - opschema: Generated ONNX OpSchema for this op. + op_schema: Generated ONNX OpSchema for this op. """ def __init__( @@ -465,20 +464,20 @@ def __init__( self.source = source self.kwargs = kwargs self._param_schemas: Optional[tuple[ParamSchema, ...]] = None - self._opschema: Optional[onnx.defs.OpSchema] = None + self._op_schema: Optional[onnx.defs.OpSchema] = None @property - def opschema(self) -> Optional[onnx.defs.OpSchema]: + def op_schema(self) -> Optional[onnx.defs.OpSchema]: """Construct an OpSchema from function_ir.""" - if self._opschema is not None: - return self._opschema + if self._op_schema is not None: + return self._op_schema if not _ONNX_OP_SCHEMA_WRITABLE: return None - self._opschema = op_schema_from_function_ir(self.function_ir, self.opset) + self._op_schema = op_schema_from_function_ir(self.function_ir, self.opset) - return self._opschema + return self._op_schema def __getitem__(self, instance): """Returns a lambda to evaluate function using given evaluator instance. @@ -555,11 +554,6 @@ def __call__(self, *args, **kwargs): def __repr__(self): return f"{self.__class__.__name__}({self.func!r})" - @property - def name(self) -> str: - """Return the name of the op.""" - return self.func.__name__ - @property def function_ir(self) -> irbuilder.IRFunction: """Return the function_ir. @@ -580,19 +574,19 @@ def function_ir(self) -> irbuilder.IRFunction: return converter.translate_function_signature(func_ast) @property - def opschema(self) -> Optional[onnx.defs.OpSchema]: - """Return the opschema.""" + def op_schema(self) -> Optional[onnx.defs.OpSchema]: + """Return the OpSchema.""" - if self._opschema is not None: - return self._opschema + if self._op_schema is not None: + return self._op_schema if not _ONNX_OP_SCHEMA_WRITABLE: return None # FIXME(justinchuby): outputs are empty. Need to fix. - self._opschema = op_schema_from_function_ir(self.function_ir, self._opset) + self._op_schema = op_schema_from_function_ir(self.function_ir, self._opset) - return self._opschema + return self._op_schema def param_schemas(self) -> tuple[ParamSchema, ...]: """Returns the parameter schemas of this function."""