Skip to content

Use ir methods to replace onnx helper #2091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 7 additions & 25 deletions onnxscript/_internal/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnxscript import ir
from onnx.defs import OpSchema

from onnxscript import tensor
Expand Down Expand Up @@ -37,29 +37,7 @@


def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue):
if isinstance(pyvalue, np.ndarray):
return numpy_helper.from_array(pyvalue, tensor_name)
if isinstance(pyvalue, list):
if len(pyvalue) == 0:
raise ValueError("Cannot convert an empty list to tensor")
pytype = type(pyvalue[0])
if not all(isinstance(e, pytype) for e in pyvalue):
raise ValueError(
"Cannot convert an list with elements of different types to tensor"
)
return helper.make_tensor(
tensor_name,
_py_type_to_onnx_type(pytype),
[len(pyvalue)],
pyvalue,
)
onnx_type = _py_type_to_onnx_type(type(pyvalue))
if onnx_type is onnx.TensorProto.BOOL:
return helper.make_tensor(tensor_name, onnx_type, [], [int(pyvalue)])
if onnx_type is onnx.TensorProto.STRING:
return helper.make_tensor(tensor_name, onnx_type, [], vals=[pyvalue.encode("utf-8")])

return helper.make_tensor(tensor_name, onnx_type, [], [pyvalue])
return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name))


_REPEATED_ATTRIBUTE_TYPES = frozenset(
Expand Down Expand Up @@ -103,7 +81,11 @@
name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value)
)
else:
return onnx.helper.make_attribute(key, value)
attr = ir.convenience.convert_attribute(
key, value, attr_type=ir.AttributeType(attr_type)
)
assert isinstance(attr, ir.Attr)
return ir.serde.serialize_attribute(attr)


# Utilities to convert python values into onnxscript tensors.
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/_legacy_ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def value_as_np_array(self) -> np.ndarray | None:
if isinstance(self.value, np.ndarray):
return self.value
if isinstance(self.value, onnx.TensorProto):
return onnx.numpy_helper.to_array(self.value)
return onnx.numpy_helper.to_array(self.value) # noqa: TID251
return None

def def_node(self) -> Node | None:
Expand Down
13 changes: 7 additions & 6 deletions onnxscript/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
import onnx
import onnx.defs
import onnx.helper
import onnx.helper # noqa: TID251
import onnx.reference
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -430,21 +430,22 @@ def make_tensor_name() -> str:
num_outputs = compute_num_outputs(schema, args, kwargs)
outputs = [f"output{i}" for i in range(num_outputs)]

node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain)
node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain) # noqa: TID251
node.attribute.extend(
make_attr(key, value) for key, value in kwargs.items() if value is not None
)
input_value_infos = utils.values_to_value_infos(zip(inputs, args))
implicit_value_infos = utils.values_to_value_infos(implicit_args.items())
output_value_infos = [
onnx.helper.make_value_info(name, onnx.TypeProto()) for name in outputs
onnx.helper.make_value_info(name, onnx.TypeProto()) # noqa: TID251
for name in outputs
]

graph = onnx.helper.make_graph(
graph = onnx.helper.make_graph( # noqa: TID251
[node], "node_graph", input_value_infos + implicit_value_infos, output_value_infos
)
opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version)
model = onnx.helper.make_model(
opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version) # noqa: TID251
model = onnx.helper.make_model( # noqa: TID251
graph,
opset_imports=[opset_id],
ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain),
Expand Down
17 changes: 5 additions & 12 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,15 +1800,11 @@ def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(
op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
)
logsumexp = op.Expand(0.0, query_first_three_dims)
# TODO: Eliminate `make_tensor` usage when ORT supports empty tensor.
empty_tensor_int = op.Cast(
op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
),
to=INT64.dtype,
empty_tensor_int = op.ConstantOfShape(
op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
)
empty_tensor_float = op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_FLOATS", INT64.dtype, [0], []))
op.Constant(value=ir.tensor([], dtype=ir.DataType.FLOAT))
)
empty_int = op.Constant(value_int=0)

Expand Down Expand Up @@ -1883,11 +1879,8 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, [0], axis=0))

# See Note [Seed and Offset]:
empty_tensor_int = op.Cast(
op.ConstantOfShape(
op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
),
to=INT64.dtype,
empty_tensor_int = op.ConstantOfShape(
op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
)

return logsum_exp, empty_tensor_int
Expand Down
3 changes: 1 addition & 2 deletions onnxscript/onnx_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import ClassVar, Optional, Tuple, Union

import onnx
import onnx.helper

import onnxscript.ir

Expand Down Expand Up @@ -99,7 +98,7 @@ def to_type_proto(cls) -> onnx.TypeProto:
shape = cls.shape # example: "FLOAT[10,20]"
else:
shape = [cls.shape] # example: "FLOAT[10]"
return onnx.helper.make_tensor_type_proto(cls.dtype, shape)
return onnx.helper.make_tensor_type_proto(cls.dtype, shape) # noqa: TID251

@classmethod
def to_string(cls) -> str:
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/optimizer/_legacy/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def foldable_value(self, name: str, value):
)
return None

return onnx.numpy_helper.from_array(value, name)
return onnx.numpy_helper.from_array(value, name) # noqa: TID251

def new_constant(self, name, value):
if isinstance(value, (int, float, np.ScalarType)):
Expand Down
6 changes: 2 additions & 4 deletions onnxscript/rewriter/cast_constant_of_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import logging

import onnx.helper

from onnxscript import ir
from onnxscript.rewriter import pattern

Expand All @@ -20,7 +18,7 @@ def cast_constant_of_shape(op, shape, scalar, dtype):
def fused_cast_constant_of_shape(op, shape: ir.Value, scalar: ir.Attr, dtype: ir.Attr, **_):
# Cast scalar (a TensorProto attribute) to the specified dtype
scalar_value = scalar.value.numpy().item()
cast_value = onnx.helper.make_tensor("value", dtype.value, (1,), [scalar_value])
cast_value = ir.tensor([scalar_value], dtype=ir.DataType(dtype.as_int()))
return op.ConstantOfShape(shape, value=cast_value)


Expand All @@ -30,7 +28,7 @@ def cast_constant_of_shape_without_value(op, shape, dtype):


def fused_cast_constant_of_shape_without_value(op, shape, dtype, **_):
zero = onnx.helper.make_tensor("value", dtype.value, (1,), [0])
zero = ir.tensor([0], dtype=ir.DataType(dtype.as_int()))
return op.ConstantOfShape(shape, value=zero)


Expand Down
2 changes: 0 additions & 2 deletions onnxscript/rewriter/llama_rule_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from typing import ClassVar

import onnx.numpy_helper

from onnxscript import ir
from onnxscript.rewriter import _ir_utils as ir_utils
from onnxscript.rewriter import pattern as orp
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Optional

import numpy as np
import onnx.helper
import onnx
from onnx import TensorProto

from onnxscript import onnx_opset
Expand Down Expand Up @@ -52,7 +52,7 @@ def dtype(self) -> np.dtype:

@property
def onnx_dtype(self) -> int:
return onnx.helper.np_dtype_to_tensor_dtype(self.dtype)
return onnx.helper.np_dtype_to_tensor_dtype(self.dtype) # noqa: TID251

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.value!r})"
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ ignore-init-module-imports = true

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"pathlib".msg = "Using pathlib can impact performance. Use os.path instead"
"onnx.helper".msg = "onnx helpers tend to be protobuf-y and slow. Consider using ir.tensor, ir.DataType and related methods instead"
"onnx.numpy_helper".msg = "onnx numpy helpers tend to be slow. Consider using ir.tensor, ir.DataType and related methods instead"

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["TID252"] # Allow relative imports in init files
Expand Down
Loading