diff --git a/onnxscript/_legacy_ir/__init__.py b/onnxscript/_legacy_ir/__init__.py
new file mode 100644
index 0000000000..bb4bd65a71
--- /dev/null
+++ b/onnxscript/_legacy_ir/__init__.py
@@ -0,0 +1,297 @@
+from __future__ import annotations
+
+
+import dataclasses
+from collections import deque
+from typing import List, Tuple, Union
+
+import numpy as np
+import onnx
+
+
+class Unknown:
+    """A special value used to indicate that a value is not a statically known constant.
+
+    We use this instead of None because None is a valid constant value (since ONNX
+    supports the Optional type).
+    """
+
+    instance = None
+
+    def __init__(self) -> None:
+        if Unknown.instance is not None:
+            raise ValueError("Unknown.instance is already set")
+        Unknown.instance = self
+
+
+# Singleton instance of Unknown
+unknown = Unknown()
+NotConstant = unknown
+
+# ConcreteValue: This type represents constant values that an ONNX variable can take.
+# TODO: Extend this to a recursive type to handle lists of tensors, etc., support optionals,
+# maps, etc.
+# TODO (rama): The value is sometimes stored as a numpy array, and sometimes as an ONNX TensorProto.
+# A uniform representation would be helpful, but we should avoid unnecessary conversions for
+# large tensors. Should be cleaned up in the new IR.
+ConcreteValue = Union[onnx.TensorProto, np.ndarray, Unknown, None]
+
+# SymbolicValue: This information is used to enable partial-evaluation and specialization
+# of sequence operations, as well as elimination of redundant Identity ops.
+# The symbolic value of a variable X can be:
+# - a string with the value "Y", indicating that "X" is a copy of "Y"
+# - a list of strings, indicating that "X" is a list of tensors, with their symbolic values
+# E.g., the symbolic value ["A", "B", "C"] indicates that the value of X is equal to
+# "SequenceConstruct(A, B, C)".
+# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of
+# tensors, etc. However, we currently only handle lists of tensors.
+
+SymbolicValue = Union[str, List[str]]
+
+FunctionId = Tuple[str, str, str]
+
+
+def get_function_id(function: onnx.FunctionProto) -> FunctionId:
+    return (function.domain, function.name, getattr(function, "overload", ""))
+
+
+def get_function_id_from_node(node: onnx.NodeProto) -> FunctionId:
+    return (node.domain, node.op_type, getattr(node, "overload", ""))
+
+
+@dataclasses.dataclass
+class StaticValueInfo:
+    name: str
+    value: ConcreteValue = NotConstant
+    type: onnx.TypeProto | None = None
+    symbolic_value: SymbolicValue | None = None
+
+    def is_copy(self) -> bool:
+        return isinstance(self.symbolic_value, str)
+
+    def tensor_shape_proto(self) -> onnx.TensorShapeProto | None:
+        """Returns the shape of a tensor, or None.
+
+        None means that the type is unknown, that the type is not a tensor, or that
+        the tensor shape (that is, even the rank) is unknown.
+        """
+        type = self.type
+        if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
+            return type.tensor_type.shape
+        return None
+
+    @property
+    def shape(self) -> list[str | int | None] | None:
+        """Returns the shape as a list.
+
+        A str dimension indicates a dynamic (symbolic) dimension.
+        """
+        type = self.type
+        if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
+            dims = []
+            for dim in type.tensor_type.shape.dim:
+                if dim.HasField("dim_param"):
+                    dims.append(dim.dim_param)
+                elif dim.HasField("dim_value"):
+                    dims.append(dim.dim_value)
+                else:
+                    dims.append(None)
+            return dims
+        if self.value_as_np_array is not None:
+            return list(self.value_as_np_array.shape)
+        return None
+
+    @property
+    def element_type(self) -> int | None:
+        """Returns the element type of a tensor, or None if type is not known or is not a tensor."""
+        type = self.type
+        if type and type.HasField("tensor_type"):
+            return type.tensor_type.elem_type
+        return None
+
+    def identity_merge_from(self, other: StaticValueInfo) -> None:
+        """Merge the value of other into self.
+
+        This models the effect of an identity (copy) operation.
+        This will update static-analysis information based on the incoming value.
+        """
+        if not isinstance(other, StaticValueInfo):
+            raise TypeError(f"Cannot merge {other} into {self}.")
+        if other.value is not NotConstant:
+            self.value = other.value
+        # TODO: merge and combine best shape information from both types.
+        if other.tensor_shape_proto() is not None and other.element_type is not None:
+            self.type = other.type
+        # We cannot copy symbolic value across different scopes.
+
+    # WIP: Extensions towards new IR: Note that the default construction of StaticValueInfo
+    # does not fill in the following fields. These fields are filled in by the IRBuilder
+    # which constructs the IR from the ONNX model.
+    node: Node | None = None
+    uses: list[Node] = dataclasses.field(default_factory=list)
+    output_index: int | None = None
+    is_output: bool = False
+
+    @property
+    def const_value(self) -> ConcreteValue:
+        return self.value
+
+    @property
+    def value_as_np_array(self) -> np.ndarray | None:
+        if isinstance(self.value, np.ndarray):
+            return self.value
+        if isinstance(self.value, onnx.TensorProto):
+            return onnx.numpy_helper.to_array(self.value)
+        return None
+
+    def def_node(self) -> Node | None:
+        return self.node
+
+    def def_index(self) -> int | None:
+        return self.output_index
+
+    def is_same_as(self, other: StaticValueInfo) -> bool:
+        """Returns True if this value represents the same IR object as the other value.
+
+        This is *not* value-equality, but rather object-equality.
+        """
+        return self is other
+
+    def __str__(self) -> str:
+        shape = self.shape
+        if shape is not None:
+            shape = [str(dim) for dim in shape]
+            shape_str = f"[{', '.join(shape)}]"
+        else:
+            shape_str = "None"
+        return (
+            f"StaticValueInfo({self.name}, shape:{shape_str}, dtype:{self.element_type}, "
+            f"{'has const value' if self.value is not unknown else 'no const value'}.)"
+        )
+
+
+Value = StaticValueInfo
+
+
+class Model:
+    def __init__(self) -> None:
+        self.gen_var_counter: int = 0
+
+    def set(
+        self,
+        model_proto: onnx.ModelProto,
+        graph: Graph,
+        functions: list[Function],
+        version_map: dict[str, int],
+    ) -> None:
+        """TODO. This is a temporary patch."""
+        self.original_model_proto = model_proto
+        self.graph = graph
+        self.functions = functions
+        self.version_map = version_map
+
+    def make_new_name(self):
+        # Temporary hack.
+        self.gen_var_counter += 1
+        return f"_gen_{self.gen_var_counter}"
+
+    def __str__(self) -> str:
+        # TODO: Naive string representation for debugging. Need to improve this.
+        return "\n".join(
+            [
+                f"ModelGraph: {self.graph}",
+                f"Functions: {self.functions}",
+                f"VersionMap: {self.version_map}",
+            ]
+        )
+
+
+class Graph:
+    def __init__(self, graph_proto: onnx.GraphProto):
+        self.original_graph_proto = graph_proto
+        self.nodes: deque[Node] = deque()
+        self.values: dict[str, Value] = {}
+
+    @property
+    def name(self) -> str:
+        return self.original_graph_proto.name
+
+    def __str__(self) -> str:
+        return "\n".join(
+            [
+                "Graph",
+                f"Nodes: {[str(n) for n in self.nodes]}",
+                f"Values: {[str(v) for v in self.values]}",
+            ]
+        )
+
+
+class Function:
+    def __init__(self, function_proto: onnx.FunctionProto):
+        self.original_function_proto = function_proto
+        self.nodes: deque[Node] = deque()
+        self.values: dict[str, Value] = {}
+
+    @property
+    def id(self) -> FunctionId:
+        return (self.domain, self.name, self.overload)
+
+    @property
+    def domain(self) -> str:
+        return self.original_function_proto.domain
+
+    @property
+    def name(self) -> str:
+        return self.original_function_proto.name
+
+    @property
+    def overload(self) -> str:
+        return getattr(self.original_function_proto, "overload", "")
+
+    def __str__(self) -> str:
+        return "\n".join(
+            [
+                "Function",
+                f"Nodes: {[str(n) for n in self.nodes]}",
+                f"Values: {[str(v) for v in self.values]}",
+            ]
+        )
+
+
+class RefAttr:
+    def __init__(self, name: str, ref_attr_name: str, type) -> None:
+        self.name = name
+        self.ref_attr_name = ref_attr_name
+        self.type = type
+
+    def to_proto(self) -> onnx.AttributeProto:
+        attr_proto = onnx.AttributeProto()
+        attr_proto.name = self.name
+        attr_proto.ref_attr_name = self.ref_attr_name
+        attr_proto.type = self.type
+        return attr_proto
+
+
+class Node:
+    def __init__(self, node_proto: onnx.NodeProto) -> None:
+        self.original_node_proto = node_proto
+        self.domain: str = node_proto.domain
+        self.version: int | None = None
+        self.op_type: str = node_proto.op_type
+        self.inputs: list[Value | None] = []
+        self.outputs: list[Value | None] = []
+        self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {}
+
+    def get_attribute(self, name: str) -> int | float | RefAttr | Graph | list[Graph] | None:
+        return self.attributes.get(name, None)
+
+    def __str__(self) -> str:
+        return "\n".join(
+            [
+                "Node",
+                f"OpType: {self.op_type}",
+                f"Inputs: {self.inputs}",
+                f"Outputs: {self.outputs}",
+                f"Attributes: {self.attributes}",
+            ]
+        )
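The `Unknown` sentinel and the shape/dtype queries added above are easy to exercise in isolation. A minimal sketch; the import path follows this diff, and everything else uses only names defined in the new file:

    import onnx
    import onnx.helper

    import onnxscript._legacy_ir as ir

    # A fresh StaticValueInfo defaults to the `unknown` sentinel, which is
    # deliberately distinct from None (None is a valid constant value, since
    # ONNX supports the Optional type).
    x = ir.StaticValueInfo("x")
    assert x.value is ir.unknown and x.value is ir.NotConstant

    # Shape and dtype queries read from an onnx.TypeProto; a dim_param
    # surfaces as str, a dim_value as int.
    x.type = onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, ["batch", 3])
    assert x.shape == ["batch", 3]
    assert x.element_type == onnx.TensorProto.FLOAT

    # The class enforces the singleton: constructing a second Unknown raises.
    try:
        ir.Unknown()
    except ValueError:
        pass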
diff --git a/onnxscript/ir/irbuilder.py b/onnxscript/_legacy_ir/irbuilder.py
similarity index 99%
rename from onnxscript/ir/irbuilder.py
rename to onnxscript/_legacy_ir/irbuilder.py
index 6a7f681977..aa440faadd 100644
--- a/onnxscript/ir/irbuilder.py
+++ b/onnxscript/_legacy_ir/irbuilder.py
@@ -5,8 +5,8 @@
 import onnx
 
-from onnxscript import ir
-from onnxscript.ir import visitor
+import onnxscript._legacy_ir as ir
+from onnxscript._legacy_ir import visitor
 from onnxscript.utils import utils
 
 """
 NOTE: IRBuilder and function visiting
diff --git a/onnxscript/ir/irbuilder_test.py b/onnxscript/_legacy_ir/irbuilder_test.py
similarity index 99%
rename from onnxscript/ir/irbuilder_test.py
rename to onnxscript/_legacy_ir/irbuilder_test.py
index 671dc35f6f..5312152582 100644
--- a/onnxscript/ir/irbuilder_test.py
+++ b/onnxscript/_legacy_ir/irbuilder_test.py
@@ -2,7 +2,7 @@
 
 import onnx.parser
 
-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder
 
 
 class IRBuilderTest(unittest.TestCase):
diff --git a/onnxscript/ir/protobuilder.py b/onnxscript/_legacy_ir/protobuilder.py
similarity index 99%
rename from onnxscript/ir/protobuilder.py
rename to onnxscript/_legacy_ir/protobuilder.py
index 8d76c27fb9..87e4de667b 100644
--- a/onnxscript/ir/protobuilder.py
+++ b/onnxscript/_legacy_ir/protobuilder.py
@@ -4,7 +4,7 @@
 import onnx.helper
 from onnx.helper import make_attribute
 
-from onnxscript import ir
+import onnxscript._legacy_ir as ir
 
 
 class ModelProtoBuilder:
diff --git a/onnxscript/ir/protobuilder_test.py b/onnxscript/_legacy_ir/protobuilder_test.py
similarity index 99%
rename from onnxscript/ir/protobuilder_test.py
rename to onnxscript/_legacy_ir/protobuilder_test.py
index 3d9f798447..f20fbaac4e 100644
--- a/onnxscript/ir/protobuilder_test.py
+++ b/onnxscript/_legacy_ir/protobuilder_test.py
@@ -4,7 +4,7 @@
 import onnx.checker
 import onnx.parser
 
-from onnxscript.ir import irbuilder, protobuilder
+from onnxscript._legacy_ir import irbuilder, protobuilder
 from onnxscript.rewriter import pattern
 from onnxscript.rewriter.onnxruntime import instance_to_group_normalization
diff --git a/onnxscript/ir/visitor.py b/onnxscript/_legacy_ir/visitor.py
similarity index 99%
rename from onnxscript/ir/visitor.py
rename to onnxscript/_legacy_ir/visitor.py
index 973001f0e7..03a0835890 100644
--- a/onnxscript/ir/visitor.py
+++ b/onnxscript/_legacy_ir/visitor.py
@@ -7,7 +7,7 @@
 import numpy as np
 import onnx
 
-from onnxscript import ir
+import onnxscript._legacy_ir as ir
 from onnxscript.utils.utils import (
     get_initializer_type,
     is_control_flow_op,
diff --git a/onnxscript/ir/visitor_test.py b/onnxscript/_legacy_ir/visitor_test.py
similarity index 96%
rename from onnxscript/ir/visitor_test.py
rename to onnxscript/_legacy_ir/visitor_test.py
index 814b103458..e4559472e3 100644
--- a/onnxscript/ir/visitor_test.py
+++ b/onnxscript/_legacy_ir/visitor_test.py
@@ -2,7 +2,7 @@
 
 import onnx
 
-from onnxscript.ir import visitor
+from onnxscript._legacy_ir import visitor
 
 
 class FunctionCallsiteProtoTransformerTest(unittest.TestCase):
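The six renames above are content-preserving (99%/96% similarity), so downstream code changes only its import paths. The before/after, assembled from the hunks in this diff:

    # Before: the legacy IR lived under onnxscript.ir
    # from onnxscript import ir
    # from onnxscript.ir import irbuilder, protobuilder, visitor

    # After: the same modules under onnxscript._legacy_ir
    import onnxscript._legacy_ir as ir
    from onnxscript._legacy_ir import irbuilder, protobuilder, visitor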
diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py
index 69794b1c09..3494e5c335 100644
--- a/onnxscript/ir/__init__.py
+++ b/onnxscript/ir/__init__.py
@@ -1,6 +1,4 @@
 """In-memory intermediate representation for ONNX graphs."""
-from __future__ import annotations
-
 __all__ = [
     # Modules
     "serde",
@@ -108,298 +106,3 @@
     TypeProtocol,
     ValueProtocol,
 )
-
-import dataclasses
-from collections import deque
-from typing import List, Tuple, Union
-
-import numpy as np
-import onnx
-
-
-class Unknown:
-    """A special value used to indicate that a value is not a statically known constant.
-
-    We use this instead of None because None is a valid constant value (since ONNX
-    supports the Optional type).
-    """
-
-    instance = None
-
-    def __init__(self) -> None:
-        if Unknown.instance is not None:
-            raise ValueError("Unknown.instance is already set")
-        Unknown.instance = self
-
-
-# Singleton instance of Unknown
-unknown = Unknown()
-NotConstant = unknown
-
-# ConcreteValue: This type represents constant values that an ONNX variable can take.
-# TODO: Extend this to a recursive type to handle lists of tensors, etc., support optionals,
-# maps, etc.
-# TODO (rama): The value is sometimes stored as a numpy array, and sometimes as an ONNX TensorProto.
-# A uniform representation would be helpful, but we should avoid unnecessary conversions for
-# large tensors. Should be cleaned up in the new IR.
-ConcreteValue = Union[onnx.TensorProto, np.ndarray, Unknown, None]
-
-# SymbolicValue: This information is used to enable partial-evaluation and specialization
-# of sequence operations, as well as elimination of redundant Identity ops.
-# The symbolic value of a variable X can be:
-# - a string with the value "Y", indicating that "X" is a copy of "Y"
-# - a list of strings, indicating that "X" is a list of tensors, with their symbolic values
-# Eg., the symbolic value ["A", "B", "C"] indicates that the value of X is equal to
-# "SequenceConstruct(A, B, C)".
-# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of
-# tensors, etc. However, we currently only handle lists of tensors.
-
-SymbolicValue = Union[str, List[str]]
-
-FunctionId = Tuple[str, str, str]
-
-
-def get_function_id(function: onnx.FunctionProto) -> FunctionId:
-    return (function.domain, function.name, getattr(function, "overload", ""))
-
-
-def get_function_id_from_node(node: onnx.NodeProto) -> FunctionId:
-    return (node.domain, node.op_type, getattr(node, "overload", ""))
-
-
-@dataclasses.dataclass
-class StaticValueInfo:
-    name: str
-    value: ConcreteValue = NotConstant
-    type: onnx.TypeProto | None = None
-    symbolic_value: SymbolicValue | None = None
-
-    def is_copy(self) -> bool:
-        return isinstance(self.symbolic_value, str)
-
-    def tensor_shape_proto(self) -> onnx.TensorShapeProto | None:
-        """Returns the shape of a tensor or None.
-
-        A return value of None could mean that the type is unknown or that the type is not a tensor
-        or that the tensor shape (that is, even the rank) is unknown.
-        """
-        type = self.type
-        if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
-            return type.tensor_type.shape
-        return None
-
-    @property
-    def shape(self) -> list[str | int | None] | None:
-        """Returns the shape in a list.
-
-        Str means that the shape is dynamic.
-        """
-        type = self.type
-        if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
-            dims = []
-            for dim in type.tensor_type.shape.dim:
-                if dim.HasField("dim_param"):
-                    dims.append(dim.dim_param)
-                elif dim.HasField("dim_value"):
-                    dims.append(dim.dim_value)
-                else:
-                    dims.append(None)
-            return dims
-        if self.value_as_np_array is not None:
-            return list(self.value_as_np_array.shape)
-        return None
-
-    @property
-    def element_type(self) -> int | None:
-        """Returns the element type of a tensor, or None if type is not known or is not a tensor."""
-        type = self.type
-        if type and type.HasField("tensor_type"):
-            return type.tensor_type.elem_type
-        return None
-
-    def identity_merge_from(self, other: StaticValueInfo) -> None:
-        """Merge the value of other into self.
-
-        This models the effect of an identity (copy) operation.
-        This will update static-analysis information based on incoming value.
-        """
-        if not isinstance(other, StaticValueInfo):
-            raise TypeError(f"Cannot merge {other} into {self}.")
-        if other.value is not NotConstant:
-            self.value = other.value
-        # TODO: merge and combine best shape information from both types.
-        if other.tensor_shape_proto() is not None and other.element_type is not None:
-            self.type = other.type
-        # We cannot copy symbolic value across different scopes.
-
-    # WIP: Extensions towards new IR: Note that the default construction of StaticValueInfo
-    # does not fill in the following fields. These fields are filled in by the IRBuilder
-    # which constructs the IR from the ONNX model.
-    node: Node | None = None
-    uses: list[Node] = dataclasses.field(default_factory=list)
-    output_index: int | None = None
-    is_output: bool = False
-
-    @property
-    def const_value(self) -> ConcreteValue:
-        return self.value
-
-    @property
-    def value_as_np_array(self) -> np.ndarray | None:
-        if isinstance(self.value, np.ndarray):
-            return self.value
-        if isinstance(self.value, onnx.TensorProto):
-            return onnx.numpy_helper.to_array(self.value)
-        return None
-
-    def def_node(self) -> Node | None:
-        return self.node
-
-    def def_index(self) -> int:
-        return self.output_index
-
-    def is_same_as(self, other: StaticValueInfo) -> bool:
-        """Returns true if this value represents the same IR object as the other value.
-
-        This is *not* value-equality, but rather object-equality.
-        """
-        return self is other
-
-    def __str__(self) -> str:
-        shape = self.shape
-        if shape is not None:
-            shape = [str(dim) for dim in shape]
-            shape_str = f"[{', '.join(shape)}]"
-        else:
-            shape_str = "None"
-        return (
-            f"StaticValueInfo({self.name}, shape:{shape_str}, dtype:{self.element_type}, "
-            f"{'has const value' if self.value is not unknown else 'no const value'}.)"
-        )
-
-
-Value = StaticValueInfo
-
-
-class Model:
-    def __init__(self) -> None:
-        self.gen_var_counter: int = 0
-
-    def set(
-        self,
-        model_proto: onnx.ModelProto,
-        graph: Graph,
-        functions: list[Function],
-        version_map: dict[str, int],
-    ) -> None:
-        """TODO. This is a temporary patch."""
-        self.original_model_proto = model_proto
-        self.graph = graph
-        self.functions = functions
-        self.version_map = version_map
-
-    def make_new_name(self):
-        # Temporary hack.
-        self.gen_var_counter += 1
-        return f"_gen_{self.gen_var_counter}"
-
-    def __str__(self) -> str:
-        # TODO: Naive string representation for debugging. Need to improve this.
-        return "\n".join(
-            [
-                f"ModelGraph: {self.graph}",
-                f"Functions: {self.functions}",
-                f"VersionMap: {self.version_map}",
-            ]
-        )
-
-
-class Graph:
-    def __init__(self, graph_proto: onnx.GraphProto):
-        self.original_graph_proto = graph_proto
-        self.nodes: deque[Node] = deque()
-        self.values: dict[str, Value] = {}
-
-    @property
-    def name(self) -> str:
-        return self.original_graph_proto.name
-
-    def __str__(self) -> str:
-        return "\n".join(
-            [
-                "Graph",
-                f"Nodes: {[str(n) for n in self.nodes]}",
-                f"Values: {[str(v) for v in self.values]}",
-            ]
-        )
-
-
-class Function:
-    def __init__(self, function_proto: onnx.FunctionProto):
-        self.original_function_proto = function_proto
-        self.nodes = deque()
-        self.values = {}
-
-    @property
-    def id(self) -> FunctionId:
-        return (self.domain, self.name, self.overload)
-
-    @property
-    def domain(self) -> str:
-        return self.original_function_proto.domain
-
-    @property
-    def name(self) -> str:
-        return self.original_function_proto.name
-
-    @property
-    def overload(self) -> str:
-        return getattr(self.original_function_proto, "overload", "")
-
-    def __str__(self) -> str:
-        return "\n".join(
-            [
-                "Function",
-                f"Nodes: {[str(n) for n in self.nodes]}",
-                f"Values: {[str(v) for v in self.values]}",
-            ]
-        )
-
-
-class RefAttr:
-    def __init__(self, name: str, ref_attr_name: str, type) -> None:
-        self.name = name
-        self.ref_attr_name = ref_attr_name
-        self.type = type
-
-    def to_proto(self) -> onnx.AttributeProto:
-        attr_proto = onnx.AttributeProto()
-        attr_proto.name = self.name
-        attr_proto.ref_attr_name = self.ref_attr_name
-        attr_proto.type = self.type
-        return attr_proto
-
-
-class Node:
-    def __init__(self, node_proto: onnx.NodeProto) -> None:
-        self.original_node_proto = node_proto
-        self.domain: str = node_proto.domain
-        self.version: int | None = None
-        self.op_type: str = node_proto.op_type
-        self.inputs: list[Value | None] = []
-        self.outputs: list[Value | None] = []
-        self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {}
-
-    def get_attribute(self, name: str) -> int | float | None:
-        return self.attributes.get(name, None)
-
-    def __str__(self) -> str:
-        return "\n".join(
-            [
-                "Node",
-                f"OpType: {self.op_type}",
-                f"Inputs: {self.inputs}",
-                f"Outputs: {self.outputs}",
-                f"Attributes: {self.attributes}",
-            ]
-        )
diff --git a/onnxscript/optimizer/constant_folding.py b/onnxscript/optimizer/constant_folding.py
index 914f6efaa2..61214a4ba1 100644
--- a/onnxscript/optimizer/constant_folding.py
+++ b/onnxscript/optimizer/constant_folding.py
@@ -7,8 +7,8 @@
 import onnx
 import onnx.reference.ops
 
-from onnxscript import ir
-from onnxscript.ir import visitor
+import onnxscript._legacy_ir as ir
+from onnxscript._legacy_ir import visitor
 from onnxscript.optimizer import evaluator
 from onnxscript.utils.utils import (
     is_control_flow_op,
diff --git a/onnxscript/optimizer/copy_propagation.py b/onnxscript/optimizer/copy_propagation.py
index dc6613e702..bb19fbdcec 100644
--- a/onnxscript/optimizer/copy_propagation.py
+++ b/onnxscript/optimizer/copy_propagation.py
@@ -5,7 +5,7 @@
 import onnx
 
 import onnxscript.optimizer.remove_unused
-from onnxscript.ir import visitor
+from onnxscript._legacy_ir import visitor
 from onnxscript.utils.utils import is_onnx_op
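Copy propagation is the kind of pass that `StaticValueInfo.identity_merge_from` (moved above) models. A small sketch of what merging across an Identity edge does to the analysis state; the names and values are made up:

    import numpy as np

    import onnxscript._legacy_ir as ir

    # Model y = Identity(x), where x carries a known constant.
    x = ir.StaticValueInfo("x", value=np.array([1.0, 2.0], dtype=np.float32))
    y = ir.StaticValueInfo("y")

    y.identity_merge_from(x)
    assert y.value is x.value  # the constant propagates to the copy
    assert y.shape == [2]      # shape falls back to the numpy value's shape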
diff --git a/onnxscript/optimizer/evaluator.py b/onnxscript/optimizer/evaluator.py
index e8415ae245..9bd2b62841 100644
--- a/onnxscript/optimizer/evaluator.py
+++ b/onnxscript/optimizer/evaluator.py
@@ -14,7 +14,7 @@
 import onnx
 import onnx.reference.ops
 
-from onnxscript import ir
+import onnxscript._legacy_ir as ir
 from onnxscript.utils.utils import (
     get_node_attr_value,
 )
diff --git a/onnxscript/optimizer/simple_function_folding.py b/onnxscript/optimizer/simple_function_folding.py
index 03b86f9a5a..f965edef6e 100644
--- a/onnxscript/optimizer/simple_function_folding.py
+++ b/onnxscript/optimizer/simple_function_folding.py
@@ -7,8 +7,8 @@
 
 import onnx
 
-from onnxscript import ir
-from onnxscript.ir import visitor
+import onnxscript._legacy_ir as ir
+from onnxscript._legacy_ir import visitor
 from onnxscript.optimizer import remove_unused
 
 logger = logging.getLogger(__name__)
diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index fa80841110..3fe036f228 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -14,7 +14,7 @@
 
 import onnx
 
-from onnxscript.ir import irbuilder, protobuilder
+from onnxscript._legacy_ir import irbuilder, protobuilder
 from onnxscript.rewriter import function_rule, pattern
 
 PatternRewriteRule = pattern.RewriteRule
diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py
index 36cb58603c..3b8d59cb50 100644
--- a/onnxscript/rewriter/broadcast_to_matmul.py
+++ b/onnxscript/rewriter/broadcast_to_matmul.py
@@ -5,7 +5,7 @@
 
 import numpy as np
 
-from onnxscript import ir
+import onnxscript._legacy_ir as ir
 from onnxscript.rewriter import pattern
 
 op = pattern.onnxop
diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/broadcast_to_matmul_test.py
index 83fc3cd4e1..73cb59e635 100644
--- a/onnxscript/rewriter/broadcast_to_matmul_test.py
+++ b/onnxscript/rewriter/broadcast_to_matmul_test.py
@@ -2,7 +2,7 @@
 
 import onnx.parser
 
-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder
 from onnxscript.rewriter import broadcast_to_matmul
diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/cast_constant_of_shape.py
index a4b90df42c..b8a64f6d3f 100644
--- a/onnxscript/rewriter/cast_constant_of_shape.py
+++ b/onnxscript/rewriter/cast_constant_of_shape.py
@@ -6,7 +6,7 @@
 import numpy as np
 import onnx
 
-from onnxscript import ir
+import onnxscript._legacy_ir as ir
 from onnxscript.rewriter import pattern
 
 op = pattern.onnxop
diff --git a/onnxscript/rewriter/cast_constant_of_shape_test.py b/onnxscript/rewriter/cast_constant_of_shape_test.py
index e7aa5f515f..c459a40c4c 100644
--- a/onnxscript/rewriter/cast_constant_of_shape_test.py
+++ b/onnxscript/rewriter/cast_constant_of_shape_test.py
@@ -2,7 +2,7 @@
 
 import onnx.parser
 
-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder
 from onnxscript.rewriter import cast_constant_of_shape
diff --git a/onnxscript/rewriter/function_rule.py b/onnxscript/rewriter/function_rule.py
index 1432ee2c2d..f36f790227 100644
--- a/onnxscript/rewriter/function_rule.py
+++ b/onnxscript/rewriter/function_rule.py
@@ -7,8 +7,8 @@
 import onnxscript
 from packaging import version
 
-from onnxscript import ir
-from onnxscript.ir import visitor
+import onnxscript._legacy_ir as ir
+from onnxscript._legacy_ir import visitor
 from onnxscript.rewriter import pattern
 
 logger = logging.getLogger(__name__)
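The function-level rewrites in function_rule.py identify their targets by function identity, which the `FunctionId` helpers in the moved module resolve as a (domain, name, overload) triple. A sketch with a made-up call site:

    import onnx.helper

    import onnxscript._legacy_ir as ir

    # A call site for a function op in a custom domain.
    node = onnx.helper.make_node("GeluFusion", inputs=["x"], outputs=["y"], domain="my.domain")

    # The overload component defaults to "" when the proto does not carry one.
    assert ir.get_function_id_from_node(node) == ("my.domain", "GeluFusion", "")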
diff --git a/onnxscript/rewriter/gemm_to_matmul_add_test.py b/onnxscript/rewriter/gemm_to_matmul_add_test.py
index a2cd339927..615d6311a0 100644
--- a/onnxscript/rewriter/gemm_to_matmul_add_test.py
+++ b/onnxscript/rewriter/gemm_to_matmul_add_test.py
@@ -2,7 +2,7 @@
 
 import onnx.parser
 
-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder
 from onnxscript.rewriter import gemm_to_matmul_add
diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/no_op_test.py
index ece94b9ec0..e38d8b7c6a 100644
--- a/onnxscript/rewriter/no_op_test.py
+++ b/onnxscript/rewriter/no_op_test.py
@@ -3,7 +3,7 @@
 import onnx.parser
 import parameterized
 
-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder
 from onnxscript.rewriter import no_op
diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py
index b7357a1d0e..5a99d03749 100644
--- a/onnxscript/rewriter/onnxruntime/__init__.py
+++ b/onnxscript/rewriter/onnxruntime/__init__.py
@@ -2,7 +2,7 @@
 
 import onnx
 
-from onnxscript.ir import irbuilder, protobuilder
+from onnxscript._legacy_ir import irbuilder, protobuilder
 from onnxscript.optimizer import remove_unused
 from onnxscript.rewriter import function_rule, pattern
 from onnxscript.rewriter.onnxruntime import (
diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
index 3d974bea92..2c7b742719 100644
--- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
+++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py
@@ -6,7 +6,7 @@
 import numpy as np
 import onnx
 
-from onnxscript import ir
+import onnxscript._legacy_ir as ir
 from onnxscript.rewriter import pattern
 
 op = pattern.onnxop
diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py
index dadca0546b..67ae0554f7 100644
--- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py
+++ b/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py
@@ -3,7 +3,7 @@
 import numpy as np
 import onnx.parser
 
-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder
 from onnxscript.rewriter.onnxruntime import instance_to_group_normalization
diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/onnxruntime/softmax.py
index 9ea48cf3e0..5b5a93fe64 100644
--- a/onnxscript/rewriter/onnxruntime/softmax.py
+++ b/onnxscript/rewriter/onnxruntime/softmax.py
@@ -5,7 +5,7 @@
 
 import onnx
 
-from onnxscript import ir
+import onnxscript._legacy_ir as ir
 from onnxscript.rewriter import pattern
 
 op = pattern.onnxop
diff --git a/onnxscript/rewriter/onnxruntime/softmax_test.py b/onnxscript/rewriter/onnxruntime/softmax_test.py
index 2b15a8b755..507c38c149 100644
--- a/onnxscript/rewriter/onnxruntime/softmax_test.py
+++ b/onnxscript/rewriter/onnxruntime/softmax_test.py
@@ -3,7 +3,7 @@
 import onnx.parser
 import parameterized
 
-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder
 from onnxscript.rewriter.onnxruntime import softmax
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 43582db400..8dce188322 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -10,8 +10,8 @@
 import onnx.numpy_helper
 import onnx.printer
 
-from onnxscript import ir
-from onnxscript.ir import irbuilder
+import onnxscript._legacy_ir as ir
+from onnxscript._legacy_ir import irbuilder
 
 # Overview of the pattern module: The classes below are used to define both
 # patterns (that we search for) and replacements for rewrite rules.
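The `op = pattern.onnxop` lines in the rewriter hunks above are the entry point for writing patterns. A sketch of the rule shape those modules follow; the rule itself is hypothetical, and the two-argument `RewriteRule` form is assumed from its usage elsewhere in the repo (this diff only shows that `pattern.RewriteRule` exists):

    from onnxscript.rewriter import pattern

    op = pattern.onnxop  # pattern-building namespace, as in the modules above

    def _two_identities(x):
        # Target pattern: Identity(Identity(x)).
        return op.Identity(op.Identity(x))

    def _one_identity(x):
        # Replacement: a single Identity.
        return op.Identity(x)

    rule = pattern.RewriteRule(_two_identities, _one_identity)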
diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py
index e888391a88..7720d51ffa 100644
--- a/onnxscript/rewriter/pattern_test.py
+++ b/onnxscript/rewriter/pattern_test.py
@@ -4,7 +4,7 @@
 import numpy as np
 import onnx.parser
 
-from onnxscript.ir import irbuilder, protobuilder
+from onnxscript._legacy_ir import irbuilder, protobuilder
 from onnxscript.rewriter import cast_constant_of_shape, pattern
 
 logger = logging.getLogger(__name__)
diff --git a/tests/common/testutils.py b/tests/common/testutils.py
index 119147992b..a54d9fcfd4 100644
--- a/tests/common/testutils.py
+++ b/tests/common/testutils.py
@@ -14,7 +14,7 @@
 import onnxruntime
 
 from onnxscript import optimizer
-from onnxscript.ir import visitor
+from onnxscript._legacy_ir import visitor
 from onnxscript.rewriter import onnxruntime as ort_rewriter
 from onnxscript.utils import evaluation_utils
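For reviewers who want to poke at the moved wrappers directly, a self-contained sketch; the graph text is made up, and note that `ir.Graph` only wraps the proto, with nodes and values populated later by the IRBuilder, per the WIP comments in the module:

    import onnx.parser

    import onnxscript._legacy_ir as ir

    graph_proto = onnx.parser.parse_graph(
        """
        agraph (float[N] X) => (float[N] Y)
        {
            Y = Identity (X)
        }
        """
    )

    g = ir.Graph(graph_proto)
    assert g.name == "agraph"
    assert len(g.nodes) == 0  # filled in by the IRBuilder, not by the constructor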