Skip to content

Use ir methods to replace onnx helper #2091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 9 additions & 38 deletions onnxscript/_internal/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx.defs import OpSchema

from onnxscript import tensor
from onnxscript import ir, tensor

if TYPE_CHECKING:
from onnxscript import converter
Expand All @@ -24,42 +23,8 @@
# Utilities to convert a python value to TensorProto (for use by the script converter)


def _py_type_to_onnx_type(pytype: type):
    """Map a Python scalar type to the matching ``onnx.TensorProto`` element type.

    Args:
        pytype: The Python type object to translate. It is matched by identity
            (``is``), so subclasses of the supported types are rejected.

    Returns:
        ``onnx.TensorProto.BOOL`` for ``bool``, ``INT64`` for ``int``,
        ``FLOAT`` for ``float`` and ``STRING`` for ``str``.

    Raises:
        ValueError: If ``pytype`` is not one of the four supported types.
    """
    if pytype is bool:
        return onnx.TensorProto.BOOL
    if pytype is int:
        return onnx.TensorProto.INT64
    if pytype is float:
        return onnx.TensorProto.FLOAT
    if pytype is str:
        return onnx.TensorProto.STRING
    raise ValueError(f"Tensor element of type {pytype} not supported")


def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue):
if isinstance(pyvalue, np.ndarray):
return numpy_helper.from_array(pyvalue, tensor_name)
if isinstance(pyvalue, list):
if len(pyvalue) == 0:
raise ValueError("Cannot convert an empty list to tensor")
pytype = type(pyvalue[0])
if not all(isinstance(e, pytype) for e in pyvalue):
raise ValueError(
"Cannot convert an list with elements of different types to tensor"
)
return helper.make_tensor(
tensor_name,
_py_type_to_onnx_type(pytype),
[len(pyvalue)],
pyvalue,
)
onnx_type = _py_type_to_onnx_type(type(pyvalue))
if onnx_type is onnx.TensorProto.BOOL:
return helper.make_tensor(tensor_name, onnx_type, [], [int(pyvalue)])
if onnx_type is onnx.TensorProto.STRING:
return helper.make_tensor(tensor_name, onnx_type, [], vals=[pyvalue.encode("utf-8")])

return helper.make_tensor(tensor_name, onnx_type, [], [pyvalue])
return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name))


_REPEATED_ATTRIBUTE_TYPES = frozenset(
Expand Down Expand Up @@ -103,7 +68,13 @@ def pyvalue_to_onnx_attribute(
name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value)
)
else:
return onnx.helper.make_attribute(key, value)
attr = ir.convenience.convert_attribute(
key,
value,
attr_type=ir.AttributeType(attr_type) if attr_type is not None else None,
)
assert isinstance(attr, ir.Attr)
return ir.serde.serialize_attribute(attr)


# Utilities to convert python values into onnxscript tensors.
Expand Down
21 changes: 10 additions & 11 deletions onnxscript/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import numpy as np
import onnx
import onnx.helper

from onnxscript import tensor

Expand Down Expand Up @@ -65,26 +64,26 @@ def add(k, v):
def value_to_type_proto(val):
"""Return the ONNX type of a python-value."""
if isinstance(val, (np.ndarray, tensor.Tensor)):
elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)
elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype) # noqa: TID251
shape = val.shape
return onnx.helper.make_tensor_type_proto(elem_type, shape)
return onnx.helper.make_tensor_type_proto(elem_type, shape) # noqa: TID251
if isinstance(val, int):
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, [])
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, []) # noqa: TID251
if isinstance(val, (float, np.float32)):
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [])
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, []) # noqa: TID251
if isinstance(val, list):
if len(val) > 0:
return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))
return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0])) # noqa: TID251
# Edge-case. Cannot determine a suitable ONNX type for an empty list.
# Should be using a typed-value instead.
# Treated as a sequence of tensors of float-type.
return onnx.helper.make_sequence_type_proto(
onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)
return onnx.helper.make_sequence_type_proto( # noqa: TID251
onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None) # noqa: TID251
)
if isinstance(val, numbers.Number):
nparray = np.array(val)
elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype)
return onnx.helper.make_tensor_type_proto(elem_type, [])
elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype) # noqa: TID251
return onnx.helper.make_tensor_type_proto(elem_type, []) # noqa: TID251
raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")


Expand All @@ -93,7 +92,7 @@ def values_to_value_infos(name_values):
skipping any None values.
"""
return [
onnx.helper.make_value_info(name, value_to_type_proto(val))
onnx.helper.make_value_info(name, value_to_type_proto(val)) # noqa: TID251
for (name, val) in name_values
if val is not None
]
2 changes: 1 addition & 1 deletion onnxscript/_legacy_ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def value_as_np_array(self) -> np.ndarray | None:
if isinstance(self.value, np.ndarray):
return self.value
if isinstance(self.value, onnx.TensorProto):
return onnx.numpy_helper.to_array(self.value)
return onnx.numpy_helper.to_array(self.value) # noqa: TID251
return None

def def_node(self) -> Node | None:
Expand Down
1 change: 1 addition & 0 deletions onnxscript/_legacy_ir/visitor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: TID251
from __future__ import annotations

import dataclasses
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/backend/onnx_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# ruff: noqa: TID251

import os
import textwrap
Expand Down
9 changes: 4 additions & 5 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy
import onnx
from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
from onnx.helper import make_node

import onnxscript.onnx_types
import onnxscript.type_annotation
Expand Down Expand Up @@ -68,10 +67,10 @@ def _get_const_repr(const_node):
if tensor_proto.data_type in {TensorProto.FLOAT, TensorProto.INT64}:
rank = len(tensor_proto.dims)
if rank == 0:
array = onnx.numpy_helper.to_array(tensor_proto).reshape(1)
array = onnx.numpy_helper.to_array(tensor_proto).reshape(1) # noqa: TID251
return repr(array[0])
if rank == 1 and tensor_proto.dims[0] < 5:
return repr(list(onnx.numpy_helper.to_array(tensor_proto)))
return repr(list(onnx.numpy_helper.to_array(tensor_proto))) # noqa: TID251
return None


Expand Down Expand Up @@ -161,7 +160,7 @@ def _attribute_value(attr: onnx.AttributeProto):
if onnx.external_data_helper.uses_external_data(tensor_proto):
return tensor_proto
else:
return onnx.numpy_helper.to_array(tensor_proto)
return onnx.numpy_helper.to_array(tensor_proto) # noqa: TID251
# TODO:
# - onnx.AttributeProto.GRAPH
# - onnx.AttributeProto.SPARSE_TENSOR
Expand Down Expand Up @@ -348,7 +347,7 @@ def _translate_graph_body(self, graph, opsets, indent=0):
)
self.skipped_initializers[init_py_name] = init
continue
node = make_node(
node = onnx.helper.make_node( # noqa: TID251
"Constant",
[],
[self._translate_onnx_var(init.name)], # type: ignore[list-item]
Expand Down
12 changes: 6 additions & 6 deletions onnxscript/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import numpy as np
import onnx
import onnx.defs
import onnx.helper
import onnx.reference
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -430,21 +429,22 @@ def make_tensor_name() -> str:
num_outputs = compute_num_outputs(schema, args, kwargs)
outputs = [f"output{i}" for i in range(num_outputs)]

node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain)
node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain) # noqa: TID251
node.attribute.extend(
make_attr(key, value) for key, value in kwargs.items() if value is not None
)
input_value_infos = utils.values_to_value_infos(zip(inputs, args))
implicit_value_infos = utils.values_to_value_infos(implicit_args.items())
output_value_infos = [
onnx.helper.make_value_info(name, onnx.TypeProto()) for name in outputs
onnx.helper.make_value_info(name, onnx.TypeProto()) # noqa: TID251
for name in outputs
]

graph = onnx.helper.make_graph(
graph = onnx.helper.make_graph( # noqa: TID251
[node], "node_graph", input_value_infos + implicit_value_infos, output_value_infos
)
opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version)
model = onnx.helper.make_model(
opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version) # noqa: TID251
model = onnx.helper.make_model( # noqa: TID251
graph,
opset_imports=[opset_id],
ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: TID251
"""Graph building functions for torchscript graph backend."""

from __future__ import annotations
Expand Down
19 changes: 5 additions & 14 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import math
from typing import Optional, Sequence, Tuple, TypeVar, Union

import onnx

from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
Expand Down Expand Up @@ -1798,15 +1796,11 @@ def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(
op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
)
logsumexp = op.Expand(0.0, query_first_three_dims)
# TODO: Eliminate `make_tensor` usage when ORT supports empty tensor.
empty_tensor_int = op.Cast(
op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
),
to=INT64.dtype,
empty_tensor_int = op.ConstantOfShape(
op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
)
empty_tensor_float = op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_FLOATS", INT64.dtype, [0], []))
op.Constant(value=ir.tensor([], dtype=ir.DataType.FLOAT))
)
empty_int = op.Constant(value_int=0)

Expand Down Expand Up @@ -1881,11 +1875,8 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, [0], axis=0))

# See Note [Seed and Offset]:
empty_tensor_int = op.Cast(
op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
),
to=INT64.dtype,
empty_tensor_int = op.ConstantOfShape(
op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
)

return logsum_exp, empty_tensor_int
Expand Down
30 changes: 26 additions & 4 deletions onnxscript/ir/_convenience/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_core.RefAttr,
_protocols.GraphProtocol,
Sequence[_protocols.GraphProtocol],
onnx.GraphProto,
_protocols.TypeProtocol,
Sequence[_protocols.TypeProtocol],
None,
Expand All @@ -60,10 +61,15 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
# Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
return _enums.AttributeType.TENSOR
if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)):
if isinstance(attr, Sequence) and all(
isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
for x in attr
):
return _enums.AttributeType.TENSORS
if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
return _enums.AttributeType.GRAPH
if isinstance(attr, Sequence) and all(
isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr
isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
):
return _enums.AttributeType.GRAPHS
if isinstance(
Expand Down Expand Up @@ -145,11 +151,27 @@ def convert_attribute(
if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
return _core.AttrTensor(name, attr)
if isinstance(attr, onnx.TensorProto):
return _core.AttrTensor(name, serde.TensorProtoTensor(attr))
return _core.AttrTensor(name, serde.deserialize_tensor(attr))
if attr_type == _enums.AttributeType.TENSORS:
tensors = []
for t in attr: # type: ignore[union-attr]
if isinstance(t, onnx.TensorProto):
tensors.append(_core.AttrTensor(name, serde.deserialize_tensor(t)))
else:
tensors.append(t) # type: ignore[arg-type]
return _core.AttrTensors(name, tensors) # type: ignore[arg-type]
if attr_type == _enums.AttributeType.GRAPH:
if isinstance(attr, onnx.GraphProto):
attr = serde.deserialize_graph(attr)
return _core.AttrGraph(name, attr) # type: ignore[arg-type]
if attr_type == _enums.AttributeType.GRAPHS:
return _core.AttrGraphs(name, attr) # type: ignore[arg-type]
graphs = []
for graph in attr: # type: ignore[union-attr]
if isinstance(graph, onnx.GraphProto):
graphs.append(serde.deserialize_graph(graph))
else:
graphs.append(graph) # type: ignore[arg-type]
return _core.AttrGraphs(name, graphs) # type: ignore[arg-type]
if attr_type == _enums.AttributeType.TYPE_PROTO:
return _core.AttrTypeProto(name, attr) # type: ignore[arg-type]
if attr_type == _enums.AttributeType.TYPE_PROTOS:
Expand Down
26 changes: 26 additions & 0 deletions onnxscript/ir/_convenience/_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,35 @@ def tensor(
# Plain Python object
if dtype is not None:
numpy_dtype = dtype.numpy()
elif isinstance(value, int) and not isinstance(value, bool):
# Specify int64 for ints because on Windows this may be int32
numpy_dtype = np.dtype(np.int64)
elif isinstance(value, float):
# If the value is a single float, we use np.float32 as the default dtype
numpy_dtype = np.dtype(np.float32)
elif isinstance(value, Sequence) and all(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

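Check that the sequence is not empty, and ask for an explicit dtype if it is: `all()` is vacuously true for an empty sequence, so when no dtype is given `[]` would otherwise silently take the int64 branch.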

(isinstance(elem, int) and not isinstance(value, bool)) for elem in value
):
numpy_dtype = np.dtype(np.int64)
elif isinstance(value, Sequence) and all(isinstance(elem, float) for elem in value):
# If the value is a sequence of floats, we use np.float32 as the default dtype
numpy_dtype = np.dtype(np.float32)
else:
numpy_dtype = None
array = np.array(value, dtype=numpy_dtype)

# Handle string tensors by encoding them
if isinstance(value, str) or (
isinstance(value, Sequence) and value and all(isinstance(elem, str) for elem in value)
):
array = np.strings.encode(array, encoding="utf-8")
return _core.StringTensor(
array,
shape=_core.Shape(array.shape),
name=name,
doc_string=doc_string,
)

return _core.Tensor(
array,
dtype=dtype,
Expand Down
1 change: 1 addition & 0 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa: TID251
from __future__ import annotations

import dataclasses
Expand Down
Loading
Loading