Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import onnxscript._internal._inference as inference
import onnxscript.optimizer
from onnxscript._internal import _inliner
from onnxscript._internal import _inliner, param_manipulation

# A permissible value for an op input, which can be converted to an ir.Value.
VALUE_LIKE = Union[
Expand Down Expand Up @@ -255,9 +255,16 @@ def _partition_inputs_attributes(
inputs: Sequence[ir.Value | ir.TensorProtocol],
kwargs: dict[str, Any],
) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]:
# Not implemented yet
del schema
return inputs, kwargs
if schema is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Op object has the signature cached. Would it be possible to use that?

return inputs, kwargs
op_signature = ir.schemas.OpSignature.from_op_schema(schema)
return param_manipulation.separate_input_attributes_from_arguments(
op_signature,
list(inputs),
kwargs,
fill_defaults=False,
allow_extra_args=False,
)

def _cast_inputs(
self,
Expand Down
219 changes: 172 additions & 47 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
import onnxscript._internal.builder as builder
import onnxscript.testing
from onnxscript import script
from onnxscript.onnx_types import DOUBLE, FLOAT
from onnxscript.onnx_types import DOUBLE, FLOAT, INT64

_default_opset_version = 23


def _build(
trace_function,
input_types: Sequence[ir.TypeAndShape],
output_types: Sequence[ir.TypeAndShape],
) -> ir.Model:
input_types: Sequence[builder.TypeSpec],
trace_function=None,
output_types: Sequence[builder.TypeSpec] | None = None,
) -> ir.Graph:
graph = ir.Graph(
name="test_model",
inputs=[],
Expand All @@ -30,25 +30,29 @@
opset_imports={"": _default_opset_version},
)

onnx_model = ir.Model(graph=graph, ir_version=10)
resolved_inputs = [builder._resolve_type_spec(t) for t in input_types]

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _resolve_type_spec of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access
for i, ts in enumerate(resolved_inputs):
graph.inputs.append(ir.Value(name=f"input_{i}", type=ts.type, shape=ts.shape))

for i, input_type in enumerate(input_types):
input_name = f"input_{i}"
graph.inputs.append(ir.Value(name=input_name, type=input_type))
if trace_function is not None:
graph_builder = builder.GraphBuilder(graph)
outputs = trace_function(graph_builder.op, *graph.inputs)
if not isinstance(outputs, Sequence):
outputs = [outputs]

graph_builder = builder.GraphBuilder(graph)
outputs = trace_function(graph_builder.op, *graph.inputs)
if not isinstance(outputs, Sequence):
outputs = [outputs]
if len(outputs) != len(output_types):
raise ValueError(f"Expected {len(output_types)} outputs, but got {len(outputs)}.")
for output, output_type in zip(outputs, output_types):
output.type = output_type.type # TODO: need merge_type method in ir.Value
output.merge_shapes(output_type.shape)
if output_types is not None:
resolved_outputs = [builder._resolve_type_spec(t) for t in output_types]

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _resolve_type_spec of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access
if len(outputs) != len(resolved_outputs):
raise ValueError(
f"Expected {len(resolved_outputs)} outputs, but got {len(outputs)}."
)
for output, ts in zip(outputs, resolved_outputs):
output.type = ts.type
output.merge_shapes(ts.shape)

graph.outputs.extend(outputs)
graph.outputs.extend(outputs)

return onnx_model
return graph


def _create_builder_with_inputs() -> tuple[builder.OpBuilder, ir.Value, ir.Value]:
Expand All @@ -57,24 +61,7 @@
Returns:
A tuple of (op_builder, input_x, input_y).
"""
graph = ir.Graph(
name="test_model",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)

for i in range(2):
input_name = f"input_{i}"
graph.inputs.append(
ir.Value(
name=input_name,
type=ir.TensorType(ir.DataType.FLOAT),
shape=ir.Shape([2, 3, 4]),
)
)

graph = _build(input_types=[FLOAT[2, 3, 4], FLOAT[2, 3, 4]])
graph_builder = builder.GraphBuilder(graph)
x, y = graph.inputs
return graph_builder.op, x, y
Expand All @@ -89,12 +76,11 @@
return z

float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
model = _build(
_add_mul_add,
graph = _build(
input_types=[float_2d, float_2d],
trace_function=_add_mul_add,
output_types=[float_2d],
)
graph = model.graph
# Expect exactly 3 nodes: Add, Mul, Add
op_types = [node.op_type for node in graph]
self.assertEqual(op_types, ["Add", "Mul", "Add"])
Expand All @@ -121,12 +107,11 @@
return z

float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
model = _build(
_add_with_custom_names,
graph = _build(
input_types=[float_2d, float_2d],
trace_function=_add_with_custom_names,
output_types=[float_2d],
)
graph = model.graph

# Verify that the nodes have outputs with the specified names
nodes = list(graph)
Expand Down Expand Up @@ -207,12 +192,11 @@
return z

float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
model = _build(
_ops_with_default_names,
graph = _build(
input_types=[float_2d, float_2d],
trace_function=_ops_with_default_names,
output_types=[float_2d],
)
graph = model.graph

# Verify the nodes use the new naming strategy
nodes = list(graph)
Expand Down Expand Up @@ -964,5 +948,146 @@
)


class PartitionInputsAttributesTest(unittest.TestCase):
    """Tests for GraphBuilder._partition_inputs_attributes.

    Each test traces a small function through ``_build`` and inspects the single
    node it produces. The first node is obtained with ``next(iter(graph))``
    rather than ``list(graph)[0]`` to avoid materializing the whole node list
    (RUFF rule RUF015).
    """

    def test_unknown_op_passes_inputs_and_kwargs_through(self):
        """An unknown op has no schema, so inputs and kwargs pass through unchanged."""

        def _dummy(op, x, y):
            return op.DummyOp(x, y, alpha=1.0)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[3, 4]],
            trace_function=_dummy,
        )
        x, y = graph.inputs
        node = next(iter(graph))
        self.assertEqual(node.op_type, "DummyOp")
        self.assertEqual(list(node.inputs), [x, y])
        self.assertEqual(node.attributes["alpha"].as_float(), 1.0)

    def test_op_with_only_inputs(self):
        """Add has two inputs and no attributes."""

        def _add(op, x, y):
            return op.Add(x, y)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[3, 4]],
            trace_function=_add,
        )
        x, y = graph.inputs
        node = next(iter(graph))
        self.assertEqual(node.op_type, "Add")
        self.assertEqual(list(node.inputs), [x, y])
        self.assertEqual(len(node.attributes), 0)

    def test_op_with_inputs_and_attributes_in_kwargs(self):
        """Gemm has 3 inputs (A, B, C) and attributes (alpha, beta, transA, transB)."""

        def _gemm(op, a, b, c):
            return op.Gemm(a, b, c, alpha=2.0, transB=1)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[4, 5], FLOAT[3, 5]],
            trace_function=_gemm,
        )
        a, b, c = graph.inputs
        node = next(iter(graph))
        self.assertEqual(node.op_type, "Gemm")
        self.assertEqual(list(node.inputs), [a, b, c])
        self.assertEqual(node.attributes["alpha"].as_float(), 2.0)
        self.assertEqual(node.attributes["transB"].as_int(), 1)

    def test_op_with_optional_input_omitted(self):
        """Gemm's third input (C) is optional. Omitting it should work."""

        def _gemm_no_c(op, a, b):
            return op.Gemm(a, b, alpha=2.0)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[4, 5]],
            trace_function=_gemm_no_c,
        )
        a, b = graph.inputs
        node = next(iter(graph))
        self.assertEqual(node.op_type, "Gemm")
        self.assertEqual(list(node.inputs), [a, b])
        self.assertEqual(node.attributes["alpha"].as_float(), 2.0)

    def test_does_not_fill_attribute_defaults(self):
        """Attribute defaults should not be filled in (fill_defaults=False)."""

        def _gemm_no_attrs(op, a, b):
            return op.Gemm(a, b)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[4, 5]],
            trace_function=_gemm_no_attrs,
        )
        node = next(iter(graph))
        # alpha, beta, transA, transB all have defaults but should NOT appear
        self.assertFalse(node.attributes)

    def test_variadic_inputs_with_attribute(self):
        """Concat has variadic inputs and an axis attribute."""

        def _concat(op, x, y, z):
            return op.Concat(x, y, z, axis=0)

        graph = _build(
            input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]],
            trace_function=_concat,
        )
        x, y, z = graph.inputs
        node = next(iter(graph))
        self.assertEqual(node.op_type, "Concat")
        self.assertEqual(list(node.inputs), [x, y, z])
        self.assertEqual(node.attributes["axis"].as_int(), 0)

    def test_slice_kwargs_are_correctly_ordered_as_inputs(self):
        """Calling op.Slice with keyword arguments should place them in schema order."""

        def _slice(op, data, starts, ends, axes, steps):
            # Pass optional inputs as kwargs in non-schema order
            return op.Slice(data, ends=ends, steps=steps, starts=starts, axes=axes)

        graph = _build(
            input_types=[FLOAT[20, 10], INT64[2], INT64[2], INT64[2], INT64[2]],
            trace_function=_slice,
        )
        data, starts, ends, axes, steps = graph.inputs

        slice_node = next(iter(graph))
        self.assertEqual(slice_node.op_type, "Slice")
        # Schema order: data, starts, ends, axes, steps
        self.assertEqual(list(slice_node.inputs), [data, starts, ends, axes, steps])

    def test_omitting_required_input_raises(self):
        """Omitting a required input should raise TypeError."""

        def _add_missing_input(op, x):
            return op.Add(x)

        with self.assertRaises(TypeError):
            _build(
                input_types=[FLOAT[3, 4]],
                trace_function=_add_missing_input,
            )

    def test_extra_inputs_raises(self):
        """Extra positional inputs beyond the schema should raise TypeError."""

        def _add_extra_input(op, x, y, z):
            return op.Add(x, y, z)

        with self.assertRaises(TypeError):
            _build(
                input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]],
                trace_function=_add_extra_input,
            )

# Allow running this test module directly (e.g. `python builder_test.py`).
if __name__ == "__main__":
    unittest.main()
13 changes: 13 additions & 0 deletions onnxscript/_internal/param_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def separate_input_attributes_from_arguments(
kwargs,
fill_defaults: bool = True,
allow_extra_kwargs: bool = False,
allow_extra_args: bool = True,
) -> tuple[list[Any], OrderedDict[str, Any]]:
"""Separate Python args and kwargs into ONNX inputs and attributes.

Expand All @@ -26,6 +27,9 @@ def separate_input_attributes_from_arguments(
fill_defaults: Whether to fill the default values for attributes.
allow_extra_kwargs: Whether to allow extra keyword arguments.
When set to True, extra/unknown arguments will be ignored.
allow_extra_args: Whether to allow extra positional arguments beyond
what the schema declares (when no variadic parameter exists).
When set to False, a TypeError is raised for extra args.

Returns:
A tuple of two elements:
Expand All @@ -34,6 +38,7 @@ def separate_input_attributes_from_arguments(

Raises:
TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
TypeError: When allow_extra_args is False and there are extra positional args.
TypeError: When a required input is not provided.
"""
# args, kwargs and op_signature.params should be all in order
Expand All @@ -46,12 +51,14 @@ def separate_input_attributes_from_arguments(

onnx_inputs = []
onnx_attributes = collections.OrderedDict()
has_variadic = False

for i, param in enumerate(op_signature.params):
is_input = param.is_param()
is_variadic = is_input and param.variadic

if is_variadic:
has_variadic = True
# Exhaust all remaining args
onnx_inputs.extend(args[i:])
args = []
Expand All @@ -74,6 +81,12 @@ def separate_input_attributes_from_arguments(
elif param.required:
raise TypeError(f"Required input '{param}' was not provided")

if not allow_extra_args and not has_variadic and len(args) > len(op_signature.params):
raise TypeError(
f"Too many positional arguments: expected {len(op_signature.params)}, "
f"got {len(args)}"
)

return onnx_inputs, onnx_attributes


Expand Down
Loading