Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions onnxscript/_internal/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,13 +871,13 @@ def _translate_call_expr(
) -> tuple[values.Op, list[ir.Value | None], list[ir.Attr]]:
"""Translates a call-expression."""
callee = self._translate_callee_expr(node.func)
param_schemas = callee.param_schemas()
op_signature = callee.op_signature
# If the callee's schema is available, we use it to determine the inputs and attributes.
# Otherwise, we map named arguments to attributes and positional arguments to inputs.
if param_schemas:
if op_signature:
kwargs = {x.arg: x.value for x in node.keywords}
args, attrs = param_manipulation.separate_input_attributes_from_arguments(
param_schemas, node.args, kwargs, fill_defaults=False
op_signature, node.args, kwargs, fill_defaults=False
)
args = [self._translate_opt_expr(x) for x in args]
attrs = [
Expand Down
15 changes: 8 additions & 7 deletions onnxscript/_internal/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from onnxscript import onnx_opset, tensor
from onnxscript._internal import autocast, param_manipulation, utils, values
from onnxscript.ir import _schemas

UserModeValue: TypeAlias = Union[Optional[np.ndarray], Sequence["UserModeValue"]]

Expand Down Expand Up @@ -273,11 +274,11 @@ def eval_function(
args: The positional arguments to the function.
kwargs: The keyword arguments to the function.
"""
param_schemas = function.param_schemas()
op_signature = function.op_signature
# Split happens in the evaluator instead of the OnnxFunction __call__ method
# so that evaluators can control behaviors like whether to fill in default values for attributes.
tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_param_schemas(
param_schemas,
tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_signature(
op_signature,
args,
kwargs,
fill_defaults=False,
Expand All @@ -287,16 +288,16 @@ def eval_function(
adapted_args: list[ExtendedModeValue] = []
adapted_kwargs: dict[str, ExtendedModeValue] = {}
has_array = False
for arg, param_schema in tagged_args:
if param_schema.is_input:
for arg, param in tagged_args:
if isinstance(param, _schemas.Parameter):
Comment thread
titaiwangms marked this conversation as resolved.
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
has_array = has_array or has_array_
adapted_args.append(adapted_arg)
else:
adapted_args.append(arg)

for key, (arg, param_schema) in tagged_kwargs.items():
if param_schema.is_input:
for key, (arg, param) in tagged_kwargs.items():
if isinstance(param, _schemas.Parameter):
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
has_array = has_array or has_array_
adapted_kwargs[key] = adapted_arg
Expand Down
76 changes: 44 additions & 32 deletions onnxscript/_internal/param_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from __future__ import annotations

import collections
from typing import Any, OrderedDict, Sequence
from typing import Any, OrderedDict

from onnxscript._internal import values
from onnxscript.ir import _schemas


def separate_input_attributes_from_arguments(
param_schemas: Sequence[values.ParamSchema],
op_signature: _schemas.OpSignature,
args,
kwargs,
fill_defaults: bool = True,
Expand All @@ -20,7 +20,7 @@ def separate_input_attributes_from_arguments(
"""Separate Python args and kwargs into ONNX inputs and attributes.

Args:
param_schemas: The parameter schemas of an Op or a OnnxFunction.
op_signature: The operator signature containing parameter information.
args: The Python positional arguments supplied by the caller.
kwargs: The Python keyword arguments supplied by the caller.
fill_defaults: Whether to fill the default values for attributes.
Expand All @@ -36,56 +36,61 @@ def separate_input_attributes_from_arguments(
TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
TypeError: When a required input is not provided.
"""
# args, kwargs and param_schemas should be all in order
# args, kwargs and op_signature.params should be all in order
# user may not specify all inputs or attributes

all_param_names = {param.name for param in param_schemas}
all_param_names = {param.name for param in op_signature.params}
extra_kwargs = set(kwargs).difference(all_param_names)
if extra_kwargs and not allow_extra_kwargs:
raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'")

onnx_inputs = []
onnx_attributes = collections.OrderedDict()

for i, param in enumerate(param_schemas):
if param.is_variadic_input:
for i, param in enumerate(op_signature.params):
is_input = isinstance(param, _schemas.Parameter)
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic

if is_variadic:
# Exhaust all remaining args
onnx_inputs.extend(args[i:])
args = []
continue
if i < len(args):
if param.is_input:
if is_input:
onnx_inputs.append(args[i])
else:
onnx_attributes[param.name] = args[i]
elif param.name in kwargs:
if param.is_input:
if is_input:
onnx_inputs.append(kwargs[param.name])
else:
onnx_attributes[param.name] = kwargs[param.name]
elif (
param.is_attribute and param.default is not values._EmptyDefault # pylint: disable=protected-access
):
elif isinstance(param, _schemas.AttributeParameter) and param.has_default():
# User did not provide the attribute
if fill_defaults:
onnx_attributes[param.name] = param.default
# Extract the value from the Attr object
onnx_attributes[param.name] = param.default.value if param.default else None
Comment thread
justinchuby marked this conversation as resolved.
Outdated
elif param.required:
raise TypeError(f"Required input '{param}' was not provided")

return onnx_inputs, onnx_attributes


def tag_arguments_with_param_schemas(
param_schemas: Sequence[values.ParamSchema],
def tag_arguments_with_signature(
op_signature: _schemas.OpSignature,
args,
kwargs,
fill_defaults: bool = True,
allow_extra_kwargs: bool = False,
) -> tuple[list[tuple[Any, values.ParamSchema]], dict[str, tuple[Any, values.ParamSchema]]]:
"""Tag Python args and kwargs with matching ONNX ParamSchema.
) -> tuple[
list[tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]],
dict[str, tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]],
]:
"""Tag Python args and kwargs with matching ONNX Parameter/AttributeParameter.

Args:
param_schemas: The parameter schemas of an Op or a OnnxFunction.
op_signature: The operator signature containing parameter information.
args: The Python positional arguments supplied by the caller.
kwargs: The Python keyword arguments supplied by the caller.
fill_defaults: Whether to fill the default values for attributes.
Expand All @@ -94,27 +99,29 @@ def tag_arguments_with_param_schemas(

Returns:
A tuple of two elements:
- A list of tuple of Python positional argument and ParamSchema.
- A list of tuple of Python positional argument and Parameter/AttributeParameter.
- An ordered dictionary of Python keyword argument names and tuple of argument
value and ParamSchema.
value and Parameter/AttributeParameter.

Raises:
TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
TypeError: When a required input is not provided.
"""
# args, kwargs and param_schemas should be all in order
# args, kwargs and op_signature.params should be all in order
# user may not specify all inputs or attributes

all_param_names = {param.name for param in param_schemas}
all_param_names = {param.name for param in op_signature.params}
extra_kwargs = set(kwargs).difference(all_param_names)
if extra_kwargs and not allow_extra_kwargs:
raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'")

tagged_args: list[tuple[Any, values.ParamSchema]] = []
tagged_kwargs: dict[str, tuple[Any, values.ParamSchema]] = {}
tagged_args: list[tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]] = []
tagged_kwargs: dict[str, tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]] = {}

for i, param in enumerate(op_signature.params):
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic

for i, param in enumerate(param_schemas):
if param.is_variadic_input:
if is_variadic:
# Exhaust all remaining args
tagged_args.extend((arg, param) for arg in args[i:])
args = []
Expand All @@ -123,25 +130,30 @@ def tag_arguments_with_param_schemas(
tagged_args.append((args[i], param))
elif param.name in kwargs:
tagged_kwargs[param.name] = (kwargs[param.name], param)
elif param.default is not values._EmptyDefault: # pylint: disable=protected-access
elif param.has_default():
# User did not provide the input/attribute
if fill_defaults:
tagged_kwargs[param.name] = (param.default, param)
default_value = param.default
# Extract value from Attr object if it's an AttributeParameter
if isinstance(param, _schemas.AttributeParameter) and param.default:
Comment thread
justinchuby marked this conversation as resolved.
Outdated
default_value = param.default.value
tagged_kwargs[param.name] = (default_value, param)
elif param.required:
raise TypeError(f"Required input/attribute '{param}' was not provided")

return tagged_args, tagged_kwargs


def turn_to_kwargs_to_avoid_ordering(
param_schemas: Sequence[values.ParamSchema],
op_signature: _schemas.OpSignature,
inputs: list[Any],
attributes: dict[str, Any],
) -> dict[str, Any]:
"""Return the inputs and attributes to the order of the function signature."""
for idx, param in enumerate(param_schemas):
for idx, param in enumerate(op_signature.params):
if param.name not in attributes:
if param.is_variadic_input:
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic
if is_variadic:
attributes[param.name] = inputs[idx:]
elif inputs:
attributes[param.name] = inputs.pop(0)
Expand Down
Loading
Loading