Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions onnxscript/_internal/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,13 +871,13 @@ def _translate_call_expr(
) -> tuple[values.Op, list[ir.Value | None], list[ir.Attr]]:
"""Translates a call-expression."""
callee = self._translate_callee_expr(node.func)
param_schemas = callee.param_schemas()
op_signature = callee.op_signature
# If the callee's schema is available, we use it to determine the inputs and attributes.
# Otherwise, we map named arguments to attributes and positional arguments to inputs.
if param_schemas:
if op_signature:
kwargs = {x.arg: x.value for x in node.keywords}
args, attrs = param_manipulation.separate_input_attributes_from_arguments(
param_schemas, node.args, kwargs, fill_defaults=False
op_signature, node.args, kwargs, fill_defaults=False
)
args = [self._translate_opt_expr(x) for x in args]
attrs = [
Expand Down
15 changes: 8 additions & 7 deletions onnxscript/_internal/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from onnxscript import onnx_opset, tensor
from onnxscript._internal import autocast, param_manipulation, utils, values
from onnxscript.ir import _schemas

UserModeValue: TypeAlias = Union[Optional[np.ndarray], Sequence["UserModeValue"]]

Expand Down Expand Up @@ -273,11 +274,11 @@ def eval_function(
args: The positional arguments to the function.
kwargs: The keyword arguments to the function.
"""
param_schemas = function.param_schemas()
op_signature = function.op_signature
# Split happens in the evaluator instead of the OnnxFunction __call__ method
# so that evaluators can control behaviors like whether to fill in default values for attributes.
tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_param_schemas(
param_schemas,
tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_signature(
op_signature,
args,
kwargs,
fill_defaults=False,
Expand All @@ -287,16 +288,16 @@ def eval_function(
adapted_args: list[ExtendedModeValue] = []
adapted_kwargs: dict[str, ExtendedModeValue] = {}
has_array = False
for arg, param_schema in tagged_args:
if param_schema.is_input:
for arg, param in tagged_args:
if isinstance(param, _schemas.Parameter):
Comment thread
titaiwangms marked this conversation as resolved.
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
has_array = has_array or has_array_
adapted_args.append(adapted_arg)
else:
adapted_args.append(arg)

for key, (arg, param_schema) in tagged_kwargs.items():
if param_schema.is_input:
for key, (arg, param) in tagged_kwargs.items():
if isinstance(param, _schemas.Parameter):
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
has_array = has_array or has_array_
adapted_kwargs[key] = adapted_arg
Expand Down
76 changes: 44 additions & 32 deletions onnxscript/_internal/param_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from __future__ import annotations

import collections
from typing import Any, OrderedDict, Sequence
from typing import Any, OrderedDict

from onnxscript._internal import values
from onnxscript.ir import _schemas


def separate_input_attributes_from_arguments(
param_schemas: Sequence[values.ParamSchema],
op_signature: _schemas.OpSignature,
args,
kwargs,
fill_defaults: bool = True,
Expand All @@ -20,7 +20,7 @@ def separate_input_attributes_from_arguments(
"""Separate Python args and kwargs into ONNX inputs and attributes.

Args:
param_schemas: The parameter schemas of an Op or a OnnxFunction.
op_signature: The operator signature containing parameter information.
args: The Python positional arguments supplied by the caller.
kwargs: The Python keyword arguments supplied by the caller.
fill_defaults: Whether to fill the default values for attributes.
Expand All @@ -36,56 +36,61 @@ def separate_input_attributes_from_arguments(
TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
TypeError: When a required input is not provided.
"""
# args, kwargs and param_schemas should be all in order
# args, kwargs and op_signature.params should be all in order
# user may not specify all inputs or attributes

all_param_names = {param.name for param in param_schemas}
all_param_names = {param.name for param in op_signature.params}
extra_kwargs = set(kwargs).difference(all_param_names)
if extra_kwargs and not allow_extra_kwargs:
raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'")

onnx_inputs = []
onnx_attributes = collections.OrderedDict()

for i, param in enumerate(param_schemas):
if param.is_variadic_input:
for i, param in enumerate(op_signature.params):
is_input = isinstance(param, _schemas.Parameter)
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic

if is_variadic:
# Exhaust all remaining args
onnx_inputs.extend(args[i:])
args = []
continue
if i < len(args):
if param.is_input:
if is_input:
onnx_inputs.append(args[i])
else:
onnx_attributes[param.name] = args[i]
elif param.name in kwargs:
if param.is_input:
if is_input:
onnx_inputs.append(kwargs[param.name])
else:
onnx_attributes[param.name] = kwargs[param.name]
elif (
param.is_attribute and param.default is not values._EmptyDefault # pylint: disable=protected-access
):
elif isinstance(param, _schemas.AttributeParameter) and param.has_default():
# User did not provide the attribute
if fill_defaults:
onnx_attributes[param.name] = param.default
# Extract the value from the Attr object
onnx_attributes[param.name] = param.default.value
elif param.required:
raise TypeError(f"Required input '{param}' was not provided")

return onnx_inputs, onnx_attributes


def tag_arguments_with_param_schemas(
param_schemas: Sequence[values.ParamSchema],
def tag_arguments_with_signature(
op_signature: _schemas.OpSignature,
args,
kwargs,
fill_defaults: bool = True,
allow_extra_kwargs: bool = False,
) -> tuple[list[tuple[Any, values.ParamSchema]], dict[str, tuple[Any, values.ParamSchema]]]:
"""Tag Python args and kwargs with matching ONNX ParamSchema.
) -> tuple[
list[tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]],
dict[str, tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]],
]:
"""Tag Python args and kwargs with matching ONNX Parameter/AttributeParameter.

Args:
param_schemas: The parameter schemas of an Op or a OnnxFunction.
op_signature: The operator signature containing parameter information.
args: The Python positional arguments supplied by the caller.
kwargs: The Python keyword arguments supplied by the caller.
fill_defaults: Whether to fill the default values for attributes.
Expand All @@ -94,27 +99,29 @@ def tag_arguments_with_param_schemas(

Returns:
A tuple of two elements:
- A list of tuple of Python positional argument and ParamSchema.
- A list of tuple of Python positional argument and Parameter/AttributeParameter.
- An ordered dictionary of Python keyword argument names and tuple of argument
value and ParamSchema.
value and Parameter/AttributeParameter.

Raises:
TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
TypeError: When a required input is not provided.
"""
# args, kwargs and param_schemas should be all in order
# args, kwargs and op_signature.params should be all in order
# user may not specify all inputs or attributes

all_param_names = {param.name for param in param_schemas}
all_param_names = {param.name for param in op_signature.params}
extra_kwargs = set(kwargs).difference(all_param_names)
if extra_kwargs and not allow_extra_kwargs:
raise TypeError(f"Unexpected keyword arguments '{extra_kwargs}'")

tagged_args: list[tuple[Any, values.ParamSchema]] = []
tagged_kwargs: dict[str, tuple[Any, values.ParamSchema]] = {}
tagged_args: list[tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]] = []
tagged_kwargs: dict[str, tuple[Any, _schemas.Parameter | _schemas.AttributeParameter]] = {}

for i, param in enumerate(op_signature.params):
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic

for i, param in enumerate(param_schemas):
if param.is_variadic_input:
if is_variadic:
# Exhaust all remaining args
tagged_args.extend((arg, param) for arg in args[i:])
args = []
Expand All @@ -123,25 +130,30 @@ def tag_arguments_with_param_schemas(
tagged_args.append((args[i], param))
elif param.name in kwargs:
tagged_kwargs[param.name] = (kwargs[param.name], param)
elif param.default is not values._EmptyDefault: # pylint: disable=protected-access
elif param.has_default():
# User did not provide the input/attribute
if fill_defaults:
tagged_kwargs[param.name] = (param.default, param)
default_value = param.default
# Extract value from Attr object if it's an AttributeParameter
if isinstance(param, _schemas.AttributeParameter):
default_value = param.default.value
tagged_kwargs[param.name] = (default_value, param)
elif param.required:
raise TypeError(f"Required input/attribute '{param}' was not provided")

return tagged_args, tagged_kwargs


def turn_to_kwargs_to_avoid_ordering(
param_schemas: Sequence[values.ParamSchema],
op_signature: _schemas.OpSignature,
inputs: list[Any],
attributes: dict[str, Any],
) -> dict[str, Any]:
"""Return the inputs and attributes to the order of the function signature."""
for idx, param in enumerate(param_schemas):
for idx, param in enumerate(op_signature.params):
if param.name not in attributes:
if param.is_variadic_input:
is_variadic = isinstance(param, _schemas.Parameter) and param.variadic
if is_variadic:
attributes[param.name] = inputs[idx:]
elif inputs:
attributes[param.name] = inputs.pop(0)
Expand Down
Loading
Loading