-
Notifications
You must be signed in to change notification settings - Fork 103
Implement schema-based input/attribute partitioning in GraphBuilder #2837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,16 +12,16 @@ | |
| import onnxscript._internal.builder as builder | ||
| import onnxscript.testing | ||
| from onnxscript import script | ||
| from onnxscript.onnx_types import DOUBLE, FLOAT | ||
| from onnxscript.onnx_types import DOUBLE, FLOAT, INT64 | ||
|
|
||
| _default_opset_version = 23 | ||
|
|
||
|
|
||
| def _build( | ||
| trace_function, | ||
| input_types: Sequence[ir.TypeAndShape], | ||
| output_types: Sequence[ir.TypeAndShape], | ||
| ) -> ir.Model: | ||
| input_types: Sequence[builder.TypeSpec], | ||
| trace_function=None, | ||
| output_types: Sequence[builder.TypeSpec] | None = None, | ||
| ) -> ir.Graph: | ||
| graph = ir.Graph( | ||
| name="test_model", | ||
| inputs=[], | ||
|
|
@@ -30,25 +30,29 @@ | |
| opset_imports={"": _default_opset_version}, | ||
| ) | ||
|
|
||
| onnx_model = ir.Model(graph=graph, ir_version=10) | ||
| resolved_inputs = [builder._resolve_type_spec(t) for t in input_types] | ||
Check warning — Code scanning / lintrunner PYLINT/W0212 Warning
Access to a protected member _resolve_type_spec of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access |
||
| for i, ts in enumerate(resolved_inputs): | ||
| graph.inputs.append(ir.Value(name=f"input_{i}", type=ts.type, shape=ts.shape)) | ||
|
|
||
| for i, input_type in enumerate(input_types): | ||
| input_name = f"input_{i}" | ||
| graph.inputs.append(ir.Value(name=input_name, type=input_type)) | ||
| if trace_function is not None: | ||
| graph_builder = builder.GraphBuilder(graph) | ||
| outputs = trace_function(graph_builder.op, *graph.inputs) | ||
| if not isinstance(outputs, Sequence): | ||
| outputs = [outputs] | ||
|
|
||
| graph_builder = builder.GraphBuilder(graph) | ||
| outputs = trace_function(graph_builder.op, *graph.inputs) | ||
| if not isinstance(outputs, Sequence): | ||
| outputs = [outputs] | ||
| if len(outputs) != len(output_types): | ||
| raise ValueError(f"Expected {len(output_types)} outputs, but got {len(outputs)}.") | ||
| for output, output_type in zip(outputs, output_types): | ||
| output.type = output_type.type # TODO: need merge_type method in ir.Value | ||
| output.merge_shapes(output_type.shape) | ||
| if output_types is not None: | ||
| resolved_outputs = [builder._resolve_type_spec(t) for t in output_types] | ||
Check warning — Code scanning / lintrunner PYLINT/W0212 Warning
Access to a protected member _resolve_type_spec of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access |
||
| if len(outputs) != len(resolved_outputs): | ||
| raise ValueError( | ||
| f"Expected {len(resolved_outputs)} outputs, but got {len(outputs)}." | ||
| ) | ||
| for output, ts in zip(outputs, resolved_outputs): | ||
| output.type = ts.type | ||
| output.merge_shapes(ts.shape) | ||
|
|
||
| graph.outputs.extend(outputs) | ||
| graph.outputs.extend(outputs) | ||
|
|
||
| return onnx_model | ||
| return graph | ||
|
|
||
|
|
||
| def _create_builder_with_inputs() -> tuple[builder.OpBuilder, ir.Value, ir.Value]: | ||
|
|
@@ -57,24 +61,7 @@ | |
| Returns: | ||
| A tuple of (op_builder, input_x, input_y). | ||
| """ | ||
| graph = ir.Graph( | ||
| name="test_model", | ||
| inputs=[], | ||
| outputs=[], | ||
| nodes=[], | ||
| opset_imports={"": 23}, | ||
| ) | ||
|
|
||
| for i in range(2): | ||
| input_name = f"input_{i}" | ||
| graph.inputs.append( | ||
| ir.Value( | ||
| name=input_name, | ||
| type=ir.TensorType(ir.DataType.FLOAT), | ||
| shape=ir.Shape([2, 3, 4]), | ||
| ) | ||
| ) | ||
|
|
||
| graph = _build(input_types=[FLOAT[2, 3, 4], FLOAT[2, 3, 4]]) | ||
| graph_builder = builder.GraphBuilder(graph) | ||
| x, y = graph.inputs | ||
| return graph_builder.op, x, y | ||
|
|
@@ -89,12 +76,11 @@ | |
| return z | ||
|
|
||
| float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4])) | ||
| model = _build( | ||
| _add_mul_add, | ||
| graph = _build( | ||
| input_types=[float_2d, float_2d], | ||
| trace_function=_add_mul_add, | ||
| output_types=[float_2d], | ||
| ) | ||
| graph = model.graph | ||
| # Expect exactly 3 nodes: Add, Mul, Add | ||
| op_types = [node.op_type for node in graph] | ||
| self.assertEqual(op_types, ["Add", "Mul", "Add"]) | ||
|
|
@@ -121,12 +107,11 @@ | |
| return z | ||
|
|
||
| float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4])) | ||
| model = _build( | ||
| _add_with_custom_names, | ||
| graph = _build( | ||
| input_types=[float_2d, float_2d], | ||
| trace_function=_add_with_custom_names, | ||
| output_types=[float_2d], | ||
| ) | ||
| graph = model.graph | ||
|
|
||
| # Verify that the nodes have outputs with the specified names | ||
| nodes = list(graph) | ||
|
|
@@ -207,12 +192,11 @@ | |
| return z | ||
|
|
||
| float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4])) | ||
| model = _build( | ||
| _ops_with_default_names, | ||
| graph = _build( | ||
| input_types=[float_2d, float_2d], | ||
| trace_function=_ops_with_default_names, | ||
| output_types=[float_2d], | ||
| ) | ||
| graph = model.graph | ||
|
|
||
| # Verify the nodes use the new naming strategy | ||
| nodes = list(graph) | ||
|
|
@@ -964,5 +948,146 @@ | |
| ) | ||
|
|
||
|
|
||
| class PartitionInputsAttributesTest(unittest.TestCase): | ||
| """Tests for GraphBuilder._partition_inputs_attributes.""" | ||
|
|
||
| def test_unknown_op_passes_inputs_and_kwargs_through(self): | ||
| """An unknown op has no schema, so inputs and kwargs pass through unchanged.""" | ||
|
|
||
| def _dummy(op, x, y): | ||
| return op.DummyOp(x, y, alpha=1.0) | ||
|
|
||
| graph = _build( | ||
| input_types=[FLOAT[3, 4], FLOAT[3, 4]], | ||
| trace_function=_dummy, | ||
| ) | ||
| x, y = graph.inputs | ||
| node = list(graph)[0] | ||
Check warning — Code scanning / lintrunner RUFF/RUF015 Warning
Prefer next(iter(graph)) over single element slice.
See https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element |
||
| self.assertEqual(node.op_type, "DummyOp") | ||
| self.assertEqual(list(node.inputs), [x, y]) | ||
| self.assertEqual(node.attributes["alpha"].as_float(), 1.0) | ||
|
|
||
| def test_op_with_only_inputs(self): | ||
| """Add has two inputs and no attributes.""" | ||
|
|
||
| def _add(op, x, y): | ||
| return op.Add(x, y) | ||
|
|
||
| graph = _build( | ||
| input_types=[FLOAT[3, 4], FLOAT[3, 4]], | ||
| trace_function=_add, | ||
| ) | ||
| x, y = graph.inputs | ||
| node = list(graph)[0] | ||
Check warning — Code scanning / lintrunner RUFF/RUF015 Warning
Prefer next(iter(graph)) over single element slice.
See https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element |
||
| self.assertEqual(node.op_type, "Add") | ||
| self.assertEqual(list(node.inputs), [x, y]) | ||
| self.assertEqual(len(node.attributes), 0) | ||
|
|
||
| def test_op_with_inputs_and_attributes_in_kwargs(self): | ||
| """Gemm has 3 inputs (A, B, C) and attributes (alpha, beta, transA, transB).""" | ||
|
|
||
| def _gemm(op, a, b, c): | ||
| return op.Gemm(a, b, c, alpha=2.0, transB=1) | ||
|
|
||
| graph = _build( | ||
| input_types=[FLOAT[3, 4], FLOAT[4, 5], FLOAT[3, 5]], | ||
| trace_function=_gemm, | ||
| ) | ||
| a, b, c = graph.inputs | ||
| node = list(graph)[0] | ||
Check warning — Code scanning / lintrunner RUFF/RUF015 Warning
Prefer next(iter(graph)) over single element slice.
See https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element |
||
| self.assertEqual(node.op_type, "Gemm") | ||
| self.assertEqual(list(node.inputs), [a, b, c]) | ||
| self.assertEqual(node.attributes["alpha"].as_float(), 2.0) | ||
| self.assertEqual(node.attributes["transB"].as_int(), 1) | ||
|
|
||
| def test_op_with_optional_input_omitted(self): | ||
| """Gemm's third input (C) is optional. Omitting it should work.""" | ||
|
|
||
| def _gemm_no_c(op, a, b): | ||
| return op.Gemm(a, b, alpha=2.0) | ||
|
|
||
| graph = _build( | ||
| input_types=[FLOAT[3, 4], FLOAT[4, 5]], | ||
| trace_function=_gemm_no_c, | ||
| ) | ||
| a, b = graph.inputs | ||
| node = list(graph)[0] | ||
Check warning — Code scanning / lintrunner RUFF/RUF015 Warning
Prefer next(iter(graph)) over single element slice.
See https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element |
||
| self.assertEqual(node.op_type, "Gemm") | ||
| self.assertEqual(list(node.inputs), [a, b]) | ||
| self.assertEqual(node.attributes["alpha"].as_float(), 2.0) | ||
|
|
||
| def test_does_not_fill_attribute_defaults(self): | ||
| """Attribute defaults should not be filled in (fill_defaults=False).""" | ||
|
|
||
| def _gemm_no_attrs(op, a, b): | ||
| return op.Gemm(a, b) | ||
|
|
||
| graph = _build( | ||
| input_types=[FLOAT[3, 4], FLOAT[4, 5]], | ||
| trace_function=_gemm_no_attrs, | ||
| ) | ||
| node = list(graph)[0] | ||
Check warning — Code scanning / lintrunner RUFF/RUF015 Warning
Prefer next(iter(graph)) over single element slice.
See https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element |
||
| # alpha, beta, transA, transB all have defaults but should NOT appear | ||
| self.assertFalse(node.attributes) | ||
|
|
||
| def test_variadic_inputs_with_attribute(self): | ||
| """Concat has variadic inputs and an axis attribute.""" | ||
|
|
||
| def _concat(op, x, y, z): | ||
| return op.Concat(x, y, z, axis=0) | ||
|
|
||
| graph = _build( | ||
| input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]], | ||
| trace_function=_concat, | ||
| ) | ||
| x, y, z = graph.inputs | ||
| node = list(graph)[0] | ||
Check warning — Code scanning / lintrunner RUFF/RUF015 Warning
Prefer next(iter(graph)) over single element slice.
See https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element |
||
| self.assertEqual(node.op_type, "Concat") | ||
| self.assertEqual(list(node.inputs), [x, y, z]) | ||
| self.assertEqual(node.attributes["axis"].as_int(), 0) | ||
|
|
||
| def test_slice_kwargs_are_correctly_ordered_as_inputs(self): | ||
| """Calling op.Slice with keyword arguments should place them in schema order.""" | ||
|
|
||
| def _slice(op, data, starts, ends, axes, steps): | ||
| # Pass optional inputs as kwargs in non-schema order | ||
| return op.Slice(data, ends=ends, steps=steps, starts=starts, axes=axes) | ||
|
|
||
| graph = _build( | ||
| input_types=[FLOAT[20, 10], INT64[2], INT64[2], INT64[2], INT64[2]], | ||
| trace_function=_slice, | ||
| ) | ||
| data, starts, ends, axes, steps = graph.inputs | ||
|
|
||
| slice_node = list(graph)[0] | ||
Check warning — Code scanning / lintrunner RUFF/RUF015 Warning
Prefer next(iter(graph)) over single element slice.
See https://docs.astral.sh/ruff/rules/unnecessary-iterable-allocation-for-first-element |
||
| self.assertEqual(slice_node.op_type, "Slice") | ||
| # Schema order: data, starts, ends, axes, steps | ||
| self.assertEqual(list(slice_node.inputs), [data, starts, ends, axes, steps]) | ||
|
|
||
| def test_omitting_required_input_raises(self): | ||
| """Omitting a required input should raise TypeError.""" | ||
|
|
||
| def _add_missing_input(op, x): | ||
| return op.Add(x) | ||
|
|
||
| with self.assertRaises(TypeError): | ||
| _build( | ||
| input_types=[FLOAT[3, 4]], | ||
| trace_function=_add_missing_input, | ||
| ) | ||
|
|
||
| def test_extra_inputs_raises(self): | ||
| """Extra positional inputs beyond the schema should raise TypeError.""" | ||
|
|
||
| def _add_extra_input(op, x, y, z): | ||
| return op.Add(x, y, z) | ||
|
|
||
| with self.assertRaises(TypeError): | ||
| _build( | ||
| input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]], | ||
| trace_function=_add_extra_input, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Op object has the signature cached. Would it be possible to use that?