Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,54 @@ def initializer(
self._graph.register_initializer(value)
return value

def input(
self,
name: str,
dtype: ir.DataType | None = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about this too. I think if we could accommodate something like FLOAT['N', 1024] as a way of compactly specifying type and shape, it would help. Like it is done here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was hesitant to pass in generic-looking type objects around. The behavior of generic type classes tend to be a bit unstable across different python versions (where things are stored, how data can be accessed, when something is evaluated, etc.). So I am not preferring it for now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is unstable? Maybe we can fix it. Or support something like (dtype, ('N', 1024))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In summary: the suggestions are

  • a combined TypeAndShape would be useful and more compact (in some settings, not necessarily all).
  • we could support the legacy onnxscript notation by adding a method to convert it to an ir.TypeAndShape(), and allowing objects that support a method called "toTypeAndShape" ... I currently have a to_ir method, but that name is too generic to be used safely for this purpose, but a name like "toTypeAndShape" should be reasonable.

shape: ir.Shape | Sequence[int | str | None] | None = None,
*,
type: ir.TypeProtocol | None = None,
const_value: ir.TensorProtocol | None = None,
metadata_props: dict[str, str] | None = None,
) -> ir.Value:
"""Create an input to the graph and return the corresponding ir.Value.

Args:
name: The name of the value.
dtype: The data type of the TensorType of the value. This is used only when type is None.
shape: The shape of the value.
type: The type of the value. Only one of dtype and type can be specified.
const_value: The constant tensor that initializes the value. Supply this argument
when you want to create an initializer. The type and shape can be obtained from the tensor.
metadata_props: The metadata properties that will be serialized to the ONNX proto.

Returns:
A Value object.
"""
value = ir.val(
name=name,
dtype=dtype,
shape=shape,
type=type,
const_value=const_value,
metadata_props=metadata_props,
)
self._graph.inputs.append(value)
if const_value is not None:
self._graph.register_initializer(value)
return value

def add_output(self, value: ir.Value, name: str | None) -> None:
"""Add an output to the graph.

Args:
value: The ir.Value to add as an output.
name: The name to assign to the output value. If None, no renaming is done.
"""
if name:
value.name = name
self._graph.outputs.append(value)

def _input_to_ir_value(
self, value: VALUE_LIKE, like_type: ir.Value | None = None
) -> ir.Value:
Expand Down
96 changes: 96 additions & 0 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,102 @@ def test_output_names_are_unique_for_same_op_type(self):
names = [t1.name, t2.name, t3.name]
self.assertEqual(len(set(names)), 3)

def test_input_creates_and_registers_graph_input(self):
"""Test that GraphBuilder.input creates and appends a graph input value."""
graph = ir.Graph(
name="test_model",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": _default_opset_version},
)
graph_builder = builder.GraphBuilder(graph)

value = graph_builder.input("data", dtype=ir.DataType.FLOAT, shape=[2, 3])

self.assertEqual(value.name, "data")
self.assertEqual(value.type.dtype, ir.DataType.FLOAT)
self.assertEqual(list(value.shape), [2, 3])
self.assertEqual(len(graph.inputs), 1)
self.assertIs(graph.inputs[0], value)

def test_input_with_const_value_registers_initializer(self):
"""Test that GraphBuilder.input registers initializer when const_value is provided."""
graph = ir.Graph(
name="test_model",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": _default_opset_version},
)
graph_builder = builder.GraphBuilder(graph)

const_tensor = ir.tensor([1.0, 2.0], dtype=ir.DataType.FLOAT, name="const_data")
value = graph_builder.input("const_input", const_value=const_tensor)

self.assertEqual(len(graph.inputs), 1)
self.assertIs(graph.inputs[0], value)
self.assertIn("const_input", graph.initializers)
self.assertIs(graph.initializers["const_input"], value)
self.assertIs(value.const_value, const_tensor)

def test_input_without_const_value_does_not_register_initializer(self):
"""Test that GraphBuilder.input does not register initializer without const_value."""
graph = ir.Graph(
name="test_model",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": _default_opset_version},
)
graph_builder = builder.GraphBuilder(graph)

value = graph_builder.input("regular_input", dtype=ir.DataType.FLOAT, shape=[2])

self.assertEqual(len(graph.inputs), 1)
self.assertIs(graph.inputs[0], value)
self.assertNotIn("regular_input", graph.initializers)

def test_add_output_renames_and_registers_output(self):
"""Test that GraphBuilder.add_output renames (optionally) and appends outputs."""
graph = ir.Graph(
name="test_model",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": _default_opset_version},
)
graph_builder = builder.GraphBuilder(graph)

output = ir.Value(name="old_name")
graph_builder.add_output(output, "new_name")

self.assertEqual(output.name, "new_name")
self.assertEqual(len(graph.outputs), 1)
self.assertIs(graph.outputs[0], output)
Comment on lines +600 to +616
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test only verifies the case where a name is provided. Consider adding a test case where name=None is passed to ensure the method correctly handles the case where no renaming is needed, as documented in the method's docstring.

Copilot uses AI. Check for mistakes.

def test_initializer_qualification_behavior(self):
"""Test that GraphBuilder.initializer qualifies names unless explicitly disabled."""
graph = ir.Graph(
name="test_model",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": _default_opset_version},
)
graph_builder = builder.GraphBuilder(graph)

graph_builder.push_module("layer1")
qualified = graph_builder.initializer(ir.tensor([1.0], name="w"), name="weight")
unqualified = graph_builder.initializer(
ir.tensor([2.0], name="b"), name="bias", qualify=False
)

self.assertEqual(qualified.name, "layer1.weight")
self.assertEqual(unqualified.name, "bias")
self.assertIn("layer1.weight", graph.initializers)
self.assertIn("bias", graph.initializers)

def test_multi_output_names_are_unique(self):
"""Test that multi-output ops produce unique names with counter suffix."""
op, x, y = _create_builder_with_inputs()
Expand Down
Loading