diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py
index 00fab2432d..836bafff97 100644
--- a/onnxscript/_internal/autocast.py
+++ b/onnxscript/_internal/autocast.py
@@ -7,10 +7,9 @@
 import numpy as np
 import onnx
-from onnx import helper, numpy_helper
 from onnx.defs import OpSchema
 
-from onnxscript import tensor
+from onnxscript import ir, tensor
 
 if TYPE_CHECKING:
     from onnxscript import converter
@@ -24,42 +23,8 @@
 # Utilities to convert a python value to TensorProto (for use by the script converter)
 
 
-def _py_type_to_onnx_type(pytype: type):
-    if pytype is bool:
-        return onnx.TensorProto.BOOL
-    if pytype is int:
-        return onnx.TensorProto.INT64
-    if pytype is float:
-        return onnx.TensorProto.FLOAT
-    if pytype is str:
-        return onnx.TensorProto.STRING
-    raise ValueError(f"Tensor element of type {pytype} not supported")
-
-
 def pyvalue_to_onnx_tensor(tensor_name: str, pyvalue):
-    if isinstance(pyvalue, np.ndarray):
-        return numpy_helper.from_array(pyvalue, tensor_name)
-    if isinstance(pyvalue, list):
-        if len(pyvalue) == 0:
-            raise ValueError("Cannot convert an empty list to tensor")
-        pytype = type(pyvalue[0])
-        if not all(isinstance(e, pytype) for e in pyvalue):
-            raise ValueError(
-                "Cannot convert an list with elements of different types to tensor"
-            )
-        return helper.make_tensor(
-            tensor_name,
-            _py_type_to_onnx_type(pytype),
-            [len(pyvalue)],
-            pyvalue,
-        )
-    onnx_type = _py_type_to_onnx_type(type(pyvalue))
-    if onnx_type is onnx.TensorProto.BOOL:
-        return helper.make_tensor(tensor_name, onnx_type, [], [int(pyvalue)])
-    if onnx_type is onnx.TensorProto.STRING:
-        return helper.make_tensor(tensor_name, onnx_type, [], vals=[pyvalue.encode("utf-8")])
-
-    return helper.make_tensor(tensor_name, onnx_type, [], [pyvalue])
+    return ir.serde.serialize_tensor(ir.tensor(pyvalue, name=tensor_name))
 
 
 _REPEATED_ATTRIBUTE_TYPES = frozenset(
@@ -103,7 +68,13 @@ def pyvalue_to_onnx_attribute(
             name=key, type=attr_type, t=pyvalue_to_onnx_tensor(name_generator(), value)
         )
     else:
-        return onnx.helper.make_attribute(key, value)
+        attr = ir.convenience.convert_attribute(
+            key,
+            value,
+            attr_type=ir.AttributeType(attr_type) if attr_type is not None else None,
+        )
+        assert isinstance(attr, ir.Attr)
+        return ir.serde.serialize_attribute(attr)
 
 
 # Utilities to convert python values into onnxscript tensors.
diff --git a/onnxscript/_internal/utils.py b/onnxscript/_internal/utils.py
index e081bb34a2..ce2b657cfd 100644
--- a/onnxscript/_internal/utils.py
+++ b/onnxscript/_internal/utils.py
@@ -7,7 +7,6 @@
 import numpy as np
 import onnx
-import onnx.helper
 
 from onnxscript import tensor
 
@@ -65,26 +64,26 @@ def add(k, v):
 def value_to_type_proto(val):
     """Return the ONNX type of a python-value."""
     if isinstance(val, (np.ndarray, tensor.Tensor)):
-        elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)
+        elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)  # noqa: TID251
         shape = val.shape
-        return onnx.helper.make_tensor_type_proto(elem_type, shape)
+        return onnx.helper.make_tensor_type_proto(elem_type, shape)  # noqa: TID251
     if isinstance(val, int):
-        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, [])
+        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, [])  # noqa: TID251
     if isinstance(val, (float, np.float32)):
-        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [])
+        return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [])  # noqa: TID251
     if isinstance(val, list):
         if len(val) > 0:
-            return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))
+            return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))  # noqa: TID251
         # Edge-case. Cannot determine a suitable ONNX type for an empty list.
         # Should be using a typed-value instead.
         # Treated as a sequence of tensors of float-type.
-        return onnx.helper.make_sequence_type_proto(
-            onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)
+        return onnx.helper.make_sequence_type_proto(  # noqa: TID251
+            onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, None)  # noqa: TID251
         )
     if isinstance(val, numbers.Number):
         nparray = np.array(val)
-        elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype)
-        return onnx.helper.make_tensor_type_proto(elem_type, [])
+        elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype)  # noqa: TID251
+        return onnx.helper.make_tensor_type_proto(elem_type, [])  # noqa: TID251
     raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
@@ -93,7 +92,7 @@ def values_to_value_infos(name_values):
     skipping any None values.
     """
     return [
-        onnx.helper.make_value_info(name, value_to_type_proto(val))
+        onnx.helper.make_value_info(name, value_to_type_proto(val))  # noqa: TID251
         for (name, val) in name_values
         if val is not None
     ]
diff --git a/onnxscript/_legacy_ir/__init__.py b/onnxscript/_legacy_ir/__init__.py
index 6c4e0c07ec..29bba54586 100644
--- a/onnxscript/_legacy_ir/__init__.py
+++ b/onnxscript/_legacy_ir/__init__.py
@@ -142,7 +142,7 @@ def value_as_np_array(self) -> np.ndarray | None:
        if isinstance(self.value, np.ndarray):
            return self.value
        if isinstance(self.value, onnx.TensorProto):
-            return onnx.numpy_helper.to_array(self.value)
+            return onnx.numpy_helper.to_array(self.value)  # noqa: TID251
        return None
 
    def def_node(self) -> Node | None:
diff --git a/onnxscript/_legacy_ir/visitor.py b/onnxscript/_legacy_ir/visitor.py
index 8dcc3893ab..6adfeab6d3 100644
--- a/onnxscript/_legacy_ir/visitor.py
+++ b/onnxscript/_legacy_ir/visitor.py
@@ -1,5 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+# ruff: noqa: TID251
 from __future__ import annotations
 
 import dataclasses
diff --git a/onnxscript/backend/onnx_backend.py b/onnxscript/backend/onnx_backend.py
index 78089ebe6a..ef93bb50b7 100644
--- a/onnxscript/backend/onnx_backend.py
+++ b/onnxscript/backend/onnx_backend.py
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-
+# ruff: noqa: TID251
 import os
 import textwrap
diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py
index b3f695d700..04c4639ea8 100644
--- a/onnxscript/backend/onnx_export.py
+++ b/onnxscript/backend/onnx_export.py
@@ -7,7 +7,6 @@
 import numpy
 import onnx
 from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto
-from onnx.helper import make_node
 
 import onnxscript.onnx_types
 import onnxscript.type_annotation
@@ -68,10 +67,10 @@ def _get_const_repr(const_node):
     if tensor_proto.data_type in {TensorProto.FLOAT, TensorProto.INT64}:
         rank = len(tensor_proto.dims)
         if rank == 0:
-            array = onnx.numpy_helper.to_array(tensor_proto).reshape(1)
+            array = onnx.numpy_helper.to_array(tensor_proto).reshape(1)  # noqa: TID251
             return repr(array[0])
         if rank == 1 and tensor_proto.dims[0] < 5:
-            return repr(list(onnx.numpy_helper.to_array(tensor_proto)))
+            return repr(list(onnx.numpy_helper.to_array(tensor_proto)))  # noqa: TID251
     return None
@@ -161,7 +160,7 @@ def _attribute_value(attr: onnx.AttributeProto):
         if onnx.external_data_helper.uses_external_data(tensor_proto):
             return tensor_proto
         else:
-            return onnx.numpy_helper.to_array(tensor_proto)
+            return onnx.numpy_helper.to_array(tensor_proto)  # noqa: TID251
     # TODO:
     # - onnx.AttributeProto.GRAPH
     # - onnx.AttributeProto.SPARSE_TENSOR
@@ -348,7 +347,7 @@ def _translate_graph_body(self, graph, opsets, indent=0):
                 )
                 self.skipped_initializers[init_py_name] = init
                 continue
-            node = make_node(
+            node = onnx.helper.make_node(  # noqa: TID251
                 "Constant",
                 [],
                 [self._translate_onnx_var(init.name)],  # type: ignore[list-item]
diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py
index 97551567bb..38784ca7f8 100644
--- a/onnxscript/evaluator.py
+++ b/onnxscript/evaluator.py
@@ -20,7 +20,6 @@
 import numpy as np
 import onnx
 import onnx.defs
-import onnx.helper
 import onnx.reference
 from typing_extensions import TypeAlias
@@ -430,21 +429,22 @@ def make_tensor_name() -> str:
     num_outputs = compute_num_outputs(schema, args, kwargs)
     outputs = [f"output{i}" for i in range(num_outputs)]
 
-    node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain)
+    node = onnx.helper.make_node(schema.name, inputs, outputs, domain=schema.domain)  # noqa: TID251
     node.attribute.extend(
         make_attr(key, value) for key, value in kwargs.items() if value is not None
     )
     input_value_infos = utils.values_to_value_infos(zip(inputs, args))
     implicit_value_infos = utils.values_to_value_infos(implicit_args.items())
     output_value_infos = [
-        onnx.helper.make_value_info(name, onnx.TypeProto()) for name in outputs
+        onnx.helper.make_value_info(name, onnx.TypeProto())  # noqa: TID251
+        for name in outputs
     ]
 
-    graph = onnx.helper.make_graph(
+    graph = onnx.helper.make_graph(  # noqa: TID251
         [node], "node_graph", input_value_infos + implicit_value_infos, output_value_infos
     )
-    opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version)
-    model = onnx.helper.make_model(
+    opset_id = onnx.helper.make_opsetid(schema.domain, schema.since_version)  # noqa: TID251
+    model = onnx.helper.make_model(  # noqa: TID251
         graph,
         opset_imports=[opset_id],
         ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain),
diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
index 8d0aab509e..b5c1456c12 100644
--- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
+++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
@@ -1,5 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+# ruff: noqa: TID251
 """Graph building functions for torchscript graph backend."""
 
 from __future__ import annotations
diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 34f143b4ee..4a607e75bd 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -17,8 +17,6 @@
 import math
 from typing import Optional, Sequence, Tuple, TypeVar, Union
 
-import onnx
-
 from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
@@ -1798,15 +1796,11 @@ def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(
         op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3])
     )
     logsumexp = op.Expand(0.0, query_first_three_dims)
-    # TODO: Eliminate `make_tensor` usage when ORT supports empty tensor.
-    empty_tensor_int = op.Cast(
-        op.ConstantOfShape(
-            op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
-        ),
-        to=INT64.dtype,
+    empty_tensor_int = op.ConstantOfShape(
+        op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
     )
     empty_tensor_float = op.ConstantOfShape(
-        op.Constant(value=onnx.helper.make_tensor("Empty_FLOATS", INT64.dtype, [0], []))
+        op.Constant(value=ir.tensor([], dtype=ir.DataType.FLOAT))
     )
     empty_int = op.Constant(value_int=0)
@@ -1881,11 +1875,8 @@ def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
     logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, [0], axis=0))
 
     # See Note [Seed and Offset]:
-    empty_tensor_int = op.Cast(
-        op.ConstantOfShape(
-            op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
-        ),
-        to=INT64.dtype,
+    empty_tensor_int = op.ConstantOfShape(
+        op.Constant(value=ir.tensor([], dtype=ir.DataType.INT64))
     )
     return logsum_exp, empty_tensor_int
diff --git a/onnxscript/ir/_convenience/__init__.py b/onnxscript/ir/_convenience/__init__.py
index 0addc9da2f..f43685a6f0 100644
--- a/onnxscript/ir/_convenience/__init__.py
+++ b/onnxscript/ir/_convenience/__init__.py
@@ -35,6 +35,7 @@
     _core.RefAttr,
     _protocols.GraphProtocol,
     Sequence[_protocols.GraphProtocol],
+    onnx.GraphProto,
     _protocols.TypeProtocol,
     Sequence[_protocols.TypeProtocol],
     None,
@@ -60,10 +61,15 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
     if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
         # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
         return _enums.AttributeType.TENSOR
-    if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)):
+    if isinstance(attr, Sequence) and all(
+        isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
+        for x in attr
+    ):
+        return _enums.AttributeType.TENSORS
+    if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
         return _enums.AttributeType.GRAPH
     if isinstance(attr, Sequence) and all(
-        isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr
+        isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
     ):
         return _enums.AttributeType.GRAPHS
     if isinstance(
@@ -145,11 +151,27 @@ def convert_attribute(
     if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
         return _core.AttrTensor(name, attr)
     if isinstance(attr, onnx.TensorProto):
-        return _core.AttrTensor(name, serde.TensorProtoTensor(attr))
+        return _core.AttrTensor(name, serde.deserialize_tensor(attr))
+    if attr_type == _enums.AttributeType.TENSORS:
+        tensors = []
+        for t in attr:  # type: ignore[union-attr]
+            if isinstance(t, onnx.TensorProto):
+                tensors.append(serde.deserialize_tensor(t))
+            else:
+                tensors.append(t)  # type: ignore[arg-type]
+        return _core.AttrTensors(name, tensors)  # type: ignore[arg-type]
     if attr_type == _enums.AttributeType.GRAPH:
+        if isinstance(attr, onnx.GraphProto):
+            attr = serde.deserialize_graph(attr)
         return _core.AttrGraph(name, attr)  # type: ignore[arg-type]
     if attr_type == _enums.AttributeType.GRAPHS:
-        return _core.AttrGraphs(name, attr)  # type: ignore[arg-type]
+        graphs = []
+        for graph in attr:  # type: ignore[union-attr]
+            if isinstance(graph, onnx.GraphProto):
+                graphs.append(serde.deserialize_graph(graph))
+            else:
+                graphs.append(graph)  # type: ignore[arg-type]
+        return _core.AttrGraphs(name, graphs)  # type: ignore[arg-type]
     if attr_type == _enums.AttributeType.TYPE_PROTO:
         return _core.AttrTypeProto(name, attr)  # type: ignore[arg-type]
     if attr_type == _enums.AttributeType.TYPE_PROTOS:
diff --git a/onnxscript/ir/_convenience/_constructors.py b/onnxscript/ir/_convenience/_constructors.py
index 3c6137f8cc..86477bcf7a 100644
--- a/onnxscript/ir/_convenience/_constructors.py
+++ b/onnxscript/ir/_convenience/_constructors.py
@@ -95,9 +95,35 @@ def tensor(
         # Plain Python object
         if dtype is not None:
             numpy_dtype = dtype.numpy()
+        elif isinstance(value, int) and not isinstance(value, bool):
+            # Specify int64 for ints because on Windows this may be int32
+            numpy_dtype = np.dtype(np.int64)
+        elif isinstance(value, float):
+            # If the value is a single float, we use np.float32 as the default dtype
+            numpy_dtype = np.dtype(np.float32)
+        elif isinstance(value, Sequence) and all(
+            (isinstance(elem, int) and not isinstance(elem, bool)) for elem in value
+        ):
+            numpy_dtype = np.dtype(np.int64)
+        elif isinstance(value, Sequence) and all(isinstance(elem, float) for elem in value):
+            # If the value is a sequence of floats, we use np.float32 as the default dtype
+            numpy_dtype = np.dtype(np.float32)
         else:
             numpy_dtype = None
         array = np.array(value, dtype=numpy_dtype)
+
+        # Handle string tensors by encoding them
+        if isinstance(value, str) or (
+            isinstance(value, Sequence) and value and all(isinstance(elem, str) for elem in value)
+        ):
+            array = np.strings.encode(array, encoding="utf-8")
+            return _core.StringTensor(
+                array,
+                shape=_core.Shape(array.shape),
+                name=name,
+                doc_string=doc_string,
+            )
+
         return _core.Tensor(
             array,
             dtype=dtype,
diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py
index 407a1ccdb1..a845dcbc53 100644
--- a/onnxscript/irbuilder.py
+++ b/onnxscript/irbuilder.py
@@ -1,5 +1,6 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+# ruff: noqa: TID251
 from __future__ import annotations
 
 import dataclasses
diff --git a/onnxscript/main.py b/onnxscript/main.py
index 7407baedd1..3ea3e50f90 100644
--- a/onnxscript/main.py
+++ b/onnxscript/main.py
@@ -8,11 +8,10 @@
 import sys
 from typing import Any, Callable, Optional, Sequence, TypeVar
 
-import onnx.helper
 from typing_extensions import ParamSpec
 
 import onnxscript
-from onnxscript import converter, irbuilder, values
+from onnxscript import converter, ir, irbuilder, values
 from onnxscript._internal import ast_utils
 
 _R = TypeVar("_R")
@@ -161,11 +160,17 @@ def export_onnx_lib(functions: Sequence[values.OnnxFunction], filename: str) ->
     # Since we don't yet have LibProto defined, we use a ModelProto as a temporary
     # container for the list of functions exported as a library, with an empty graph
     # and dummy opset_imports.
-    model = onnx.helper.make_model(
-        onnx.GraphProto(),
-        functions=[f.to_function_proto() for f in functions],
+
+    # TODO(justinchuby): This function is not well supported. We should consider removing it
+    model = ir.Model(
+        ir.Graph(
+            inputs=[],
+            outputs=[],
+            nodes=[],
+            opset_imports={"": 15},
+        ),
+        functions=[ir.serde.deserialize_function(f.to_function_proto()) for f in functions],
+        ir_version=10,
         producer_name="p2o",
-        opset_imports=[onnx.helper.make_opsetid("", 15)],
     )
-
-    onnx.save(model, filename)
+    ir.save(model, filename)
diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py
index e83e5ac825..af1d5b4918 100644
--- a/onnxscript/onnx_types.py
+++ b/onnxscript/onnx_types.py
@@ -7,7 +7,6 @@
 from typing import ClassVar, Optional, Tuple, Union
 
 import onnx
-import onnx.helper
 
 import onnxscript.ir
@@ -99,7 +98,7 @@ def to_type_proto(cls) -> onnx.TypeProto:
             shape = cls.shape  # example: "FLOAT[10,20]"
         else:
             shape = [cls.shape]  # example: "FLOAT[10]"
-        return onnx.helper.make_tensor_type_proto(cls.dtype, shape)
+        return onnx.helper.make_tensor_type_proto(cls.dtype, shape)  # noqa: TID251
 
     @classmethod
     def to_string(cls) -> str:
diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/cast_constant_of_shape.py
index 34656ff190..f81cf4820f 100644
--- a/onnxscript/rewriter/cast_constant_of_shape.py
+++ b/onnxscript/rewriter/cast_constant_of_shape.py
@@ -4,8 +4,6 @@
 
 import logging
 
-import onnx.helper
-
 from onnxscript import ir
 from onnxscript.rewriter import pattern
@@ -20,7 +18,7 @@ def cast_constant_of_shape(op, shape, scalar, dtype):
 def fused_cast_constant_of_shape(op, shape: ir.Value, scalar: ir.Attr, dtype: ir.Attr, **_):
     # Cast scalar (a TensorProto attribute) to the specified dtype
     scalar_value = scalar.value.numpy().item()
-    cast_value = onnx.helper.make_tensor("value", dtype.value, (1,), [scalar_value])
+    cast_value = ir.tensor([scalar_value], dtype=ir.DataType(dtype.as_int()))
     return op.ConstantOfShape(shape, value=cast_value)
@@ -30,7 +28,7 @@ def cast_constant_of_shape_without_value(op, shape, dtype):
 def fused_cast_constant_of_shape_without_value(op, shape, dtype, **_):
-    zero = onnx.helper.make_tensor("value", dtype.value, (1,), [0])
+    zero = ir.tensor([0], dtype=ir.DataType(dtype.as_int()))
     return op.ConstantOfShape(shape, value=zero)
diff --git a/onnxscript/rewriter/llama_rule_sets.py b/onnxscript/rewriter/llama_rule_sets.py
index f721bf5c9e..7342063f30 100644
--- a/onnxscript/rewriter/llama_rule_sets.py
+++ b/onnxscript/rewriter/llama_rule_sets.py
@@ -4,8 +4,6 @@
 
 from typing import ClassVar
 
-import onnx.numpy_helper
-
 from onnxscript import ir
 from onnxscript.rewriter import _ir_utils as ir_utils
 from onnxscript.rewriter import pattern as orp
@@ -57,10 +55,10 @@ class CastCast(orp.RewriteRuleAsClass):
     """Replaces ``Cast(Cast(X, ...), to=to)`` by ``Cast(X, to=to)``."""
 
     _allowed_tensor_types: ClassVar = {
-        onnx.TensorProto.FLOAT,
-        onnx.TensorProto.FLOAT16,
-        onnx.TensorProto.BFLOAT16,
-        onnx.TensorProto.DOUBLE,
+        ir.DataType.FLOAT,
+        ir.DataType.FLOAT16,
+        ir.DataType.BFLOAT16,
+        ir.DataType.DOUBLE,
     }
 
     @classmethod
@@ -72,7 +70,7 @@ def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
         check_result = orp.MatchResult()
         if to.value not in cls._allowed_tensor_types:
             return check_result.fail(f"Output type {to.value} is not allowed")
-        if to_ignored.value not in cls._allowed_tensor_types:
+        if to_ignored.as_int() not in cls._allowed_tensor_types:
             return check_result.fail(f"Ignored type {to_ignored.value} is not allowed")
         return check_result
diff --git a/onnxscript/rewriter/ort_fusions/models/_smollm_1.py b/onnxscript/rewriter/ort_fusions/models/_smollm_1.py
index dfff60db5c..c461c2b048 100644
--- a/onnxscript/rewriter/ort_fusions/models/_smollm_1.py
+++ b/onnxscript/rewriter/ort_fusions/models/_smollm_1.py
@@ -6,8 +6,7 @@
 This is an onnxscript version of the model.
 """
 
-import numpy
-from onnx.helper import make_tensor
+import numpy as np
 
 import onnxscript.ir as ir
 from onnxscript import script
@@ -73,44 +72,44 @@ def main_graph(
         unsqueeze_6 = opset18.Unsqueeze(input2, 1)
         to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
         view_1 = opset18.Constant(
-            value=make_tensor(
-                "value",
-                1,
-                dims=[1, 32, 1],
-                vals=[
-                    1.0,
-                    0.7498942017555237,
-                    0.5623413324356079,
-                    0.4216965138912201,
-                    0.3162277638912201,
-                    0.23713736236095428,
-                    0.17782793939113617,
-                    0.1333521455526352,
-                    0.10000000149011612,
-                    0.07498941570520401,
-                    0.05623412877321243,
-                    0.04216964915394783,
-                    0.03162277862429619,
-                    0.0237137358635664,
-                    0.017782794311642647,
-                    0.01333521492779255,
-                    0.009999999776482582,
-                    0.007498942315578461,
-                    0.005623413249850273,
-                    0.0042169648222625256,
-                    0.003162277862429619,
-                    0.0023713738191872835,
-                    0.0017782794311642647,
-                    0.0013335214462131262,
-                    0.0010000000474974513,
-                    0.0007498941849917173,
-                    0.000562341301701963,
-                    0.00042169648804701865,
-                    0.0003162277862429619,
-                    0.0002371373848291114,
-                    0.00017782794020604342,
-                    0.0001333521504420787,
-                ],
+            value=ir.tensor(
+                np.array(
+                    [
+                        1.0,
+                        0.7498942017555237,
+                        0.5623413324356079,
+                        0.4216965138912201,
+                        0.3162277638912201,
+                        0.23713736236095428,
+                        0.17782793939113617,
+                        0.1333521455526352,
+                        0.10000000149011612,
+                        0.07498941570520401,
+                        0.05623412877321243,
+                        0.04216964915394783,
+                        0.03162277862429619,
+                        0.0237137358635664,
+                        0.017782794311642647,
+                        0.01333521492779255,
+                        0.009999999776482582,
+                        0.007498942315578461,
+                        0.005623413249850273,
+                        0.0042169648222625256,
+                        0.003162277862429619,
+                        0.0023713738191872835,
+                        0.0017782794311642647,
+                        0.0013335214462131262,
+                        0.0010000000474974513,
+                        0.0007498941849917173,
+                        0.000562341301701963,
+                        0.00042169648804701865,
+                        0.0003162277862429619,
+                        0.0002371373848291114,
+                        0.00017782794020604342,
+                        0.0001333521504420787,
+                    ],
+                    dtype=np.float32,
+                ).reshape([1, 32, 1])
             )
         )
         view_2 = opset18.Reshape(to_copy_1, [1, 1, 10], allowzero=0)
@@ -207,29 +206,29 @@ def make_model_with_random_weights():
-    input_layernorm_weight_0 = numpy.random.rand(2048).astype(numpy.float32)
-    post_attention_layernorm_weight0 = numpy.random.rand(2048).astype(numpy.float32)
-    norm_weight = numpy.random.rand(2048).astype(numpy.float32)
-    head_weight = numpy.random.rand(49152, 2048).astype(numpy.float32)
-    self_attn_q_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32)
-    self_attn_k_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32)
-    self_attn_v_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32)
-    self_attn_o_proj_weight0 = numpy.random.rand(2048, 2048).astype(numpy.float32)
-    mlp_gate_proj_weight0 = numpy.random.rand(8192, 2048).astype(numpy.float32)
-    mlp_up_proj_weight0 = numpy.random.rand(8192, 2048).astype(numpy.float32)
-    mlp_down_proj_weight0 = numpy.random.rand(2048, 8192).astype(numpy.float32)
+    input_layernorm_weight_0 = np.random.rand(2048).astype(np.float32)
+    post_attention_layernorm_weight0 = np.random.rand(2048).astype(np.float32)
+    norm_weight = np.random.rand(2048).astype(np.float32)
+    head_weight = np.random.rand(49152, 2048).astype(np.float32)
+    self_attn_q_proj_weight0 = np.random.rand(2048, 2048).astype(np.float32)
+    self_attn_k_proj_weight0 = np.random.rand(2048, 2048).astype(np.float32)
+    self_attn_v_proj_weight0 = np.random.rand(2048, 2048).astype(np.float32)
+    self_attn_o_proj_weight0 = np.random.rand(2048, 2048).astype(np.float32)
+    mlp_gate_proj_weight0 = np.random.rand(8192, 2048).astype(np.float32)
+    mlp_up_proj_weight0 = np.random.rand(8192, 2048).astype(np.float32)
+    mlp_down_proj_weight0 = np.random.rand(2048, 8192).astype(np.float32)
     model = make_model(
-        input_layernorm_weight_0,
-        post_attention_layernorm_weight0,
-        norm_weight,
-        head_weight,
-        self_attn_q_proj_weight0,
-        self_attn_k_proj_weight0,
-        self_attn_v_proj_weight0,
-        self_attn_o_proj_weight0,
-        mlp_gate_proj_weight0,
-        mlp_up_proj_weight0,
-        mlp_down_proj_weight0,
+        ir.tensor(input_layernorm_weight_0),
+        ir.tensor(post_attention_layernorm_weight0),
+        ir.tensor(norm_weight),
+        ir.tensor(head_weight),
+        ir.tensor(self_attn_q_proj_weight0),
+        ir.tensor(self_attn_k_proj_weight0),
+        ir.tensor(self_attn_v_proj_weight0),
+        ir.tensor(self_attn_o_proj_weight0),
+        ir.tensor(mlp_gate_proj_weight0),
+        ir.tensor(mlp_up_proj_weight0),
+        ir.tensor(mlp_down_proj_weight0),
     )
     return model
@@ -245,9 +244,9 @@ def get_onnx_model(self):
     def get_ort_inputs(self):
         if not hasattr(self, "_ort_inputs"):
             inputs = {
-                "input0": numpy.random.randint(0, 49152, (1, 10)).astype(numpy.int64),
-                "input1": numpy.ones((1, 10), dtype=numpy.float32),
-                "input2": numpy.arange(10, dtype=numpy.int64).reshape(1, 10),
+                "input0": np.random.randint(0, 49152, (1, 10)).astype(np.int64),
+                "input1": np.ones((1, 10), dtype=np.float32),
+                "input2": np.arange(10, dtype=np.int64).reshape(1, 10),
             }
             self._ort_inputs = inputs
         return self._ort_inputs
diff --git a/onnxscript/tensor.py b/onnxscript/tensor.py
index 21ca3c4a68..f1d781b808 100644
--- a/onnxscript/tensor.py
+++ b/onnxscript/tensor.py
@@ -6,10 +6,8 @@
 from typing import Any, Optional
 
 import numpy as np
-import onnx.helper
-from onnx import TensorProto
 
-from onnxscript import onnx_opset
+from onnxscript import ir, onnx_opset
 from onnxscript._internal import autocast
@@ -52,7 +50,7 @@ def dtype(self) -> np.dtype:
 
     @property
     def onnx_dtype(self) -> int:
-        return onnx.helper.np_dtype_to_tensor_dtype(self.dtype)
+        return ir.DataType.from_numpy(self.dtype)
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.value!r})"
@@ -160,10 +158,10 @@ def __getitem__(self, index):
 
     def __mod__(self, other):
         if self.onnx_dtype in {
-            TensorProto.FLOAT,
-            TensorProto.DOUBLE,
-            TensorProto.FLOAT16,
-            TensorProto.BFLOAT16,
+            ir.DataType.FLOAT,
+            ir.DataType.DOUBLE,
+            ir.DataType.FLOAT16,
+            ir.DataType.BFLOAT16,
         }:
             return self._opset.Mod(self, other, fmod=1)
         return self._opset.Mod(self, other)
diff --git a/onnxscript/testing/__init__.py b/onnxscript/testing/__init__.py
index a6e8160063..048b45e7e8 100644
--- a/onnxscript/testing/__init__.py
+++ b/onnxscript/testing/__init__.py
@@ -14,10 +14,12 @@
 from typing import Any, Collection, Sequence
 
 import google.protobuf.message
+import numpy as np
 import onnx
 from onnx import parser
 
 import onnxscript
+from onnxscript import ir
 
 
 def assert_isomorphic(graph_or_function_1, graph_or_function_2):
@@ -66,7 +68,7 @@ def to_map(proto):
     return to_map(proto1) == to_map(proto2)
 
 
-def _same_tensor(tp1, tp2):
+def _same_tensor(tp1: onnx.TensorProto, tp2: onnx.TensorProto):
     if tp1.dims != tp2.dims:
         return False
     if not _same_optional("data_type", tp1, tp2):
@@ -74,18 +76,11 @@ def _same_tensor(tp1, tp2):
     # Segmented representation not supported yet
     if tp1.HasField("segment") or tp2.HasField("segment"):
         return False
-    if tp1.float_data != tp2.float_data:
-        return False
-    if tp1.int32_data != tp2.int32_data:
-        return False
-    if tp1.string_data != tp2.string_data:
-        return False
-    if tp1.int64_data != tp2.int64_data:
-        return False
-    if tp1.uint64_data != tp2.uint64_data:
-        return False
-    if tp1.double_data != tp2.double_data:
-        return False
+    if tp1.data_location == tp2.data_location == tp1.DataLocation.DEFAULT:
+        tensor1 = ir.from_proto(tp1)
+        tensor2 = ir.from_proto(tp2)
+        if not np.array_equal(tensor1.numpy(), tensor2.numpy(), equal_nan=True):
+            return False
     # Ignore name for comparison:
     # if not _same_optional("name", tp1, tp2): return False
     if not _same_optional("doc_string", tp1, tp2):
diff --git a/onnxscript/values.py b/onnxscript/values.py
index d748dc6e64..266f7da571 100644
--- a/onnxscript/values.py
+++ b/onnxscript/values.py
@@ -176,7 +176,7 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any:
     """Get the default value of an ONNX attribute."""
     if attr_proto.type == onnx.AttributeProto.UNDEFINED:
         return _EmptyDefault
-    return onnx.helper.get_attribute_value(attr_proto)
+    return onnx.helper.get_attribute_value(attr_proto)  # noqa: TID251
 
 
 def _param_schemas_from_op_schema(
diff --git a/pyproject.toml b/pyproject.toml
index ff873319fb..361ba40aa6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -209,6 +209,8 @@ ignore-init-module-imports = true
 
 [tool.ruff.lint.flake8-tidy-imports.banned-api]
 "pathlib".msg = "Using pathlib can impact performance. Use os.path instead"
+"onnx.helper".msg = "onnx helpers tend to be protobuf-y and slow. Consider using ir.tensor, ir.DataType and related methods instead"
+"onnx.numpy_helper".msg = "onnx numpy helpers tend to be slow. Consider using ir.tensor, ir.DataType and related methods instead"
 
 [tool.ruff.lint.per-file-ignores]
 "__init__.py" = ["TID252"]  # Allow relative imports in init files
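
The hunks above all converge on the same small onnxscript.ir surface. A minimal usage sketch of that surface follows (illustrative values only, not part of the patch; the calls shown — ir.tensor, ir.DataType.from_numpy, ir.serde.serialize_tensor — are exactly the ones adopted in the hunks above):

    import numpy as np
    from onnxscript import ir

    # Instead of onnx.helper.make_tensor(name, elem_type, dims, vals):
    t = ir.tensor(np.array([1.0, 2.0], dtype=np.float32), name="value")

    # Instead of onnx.helper.np_dtype_to_tensor_dtype(np_dtype):
    dt = ir.DataType.from_numpy(np.dtype(np.float32))  # ir.DataType.FLOAT

    # Where a TensorProto is still required at a protobuf boundary, round-trip explicitly:
    tensor_proto = ir.serde.serialize_tensor(t)

The pattern throughout the patch: build constants with ir.tensor/ir.DataType and serialize to protobuf only at the edges; call sites that must keep using onnx.helper are marked with # noqa: TID251 against the new banned-api rules added to pyproject.toml.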