[Migration][DO NOT MERGE] Separate old ir into _legacy_ir folder #1332


Closed
wants to merge 7 commits into from
297 changes: 297 additions & 0 deletions onnxscript/_legacy_ir/__init__.py
@@ -0,0 +1,297 @@
from __future__ import annotations

import dataclasses
from collections import deque
from typing import List, Tuple, Union

import numpy as np
import onnx


class Unknown:
"""A special value used to indicate that a value is not a statically known constant.

We use this instead of None because None is a valid constant value (since ONNX
supports the Optional type).
"""

instance = None

def __init__(self) -> None:
if Unknown.instance is not None:
raise ValueError("Unknown.instance is already set")
Unknown.instance = self


# Singleton instance of Unknown
unknown = Unknown()
NotConstant = unknown

# ConcreteValue: This type represents constant values that an ONNX variable can take.
# TODO: Extend this to a recursive type to handle lists of tensors, etc., support optionals,
# maps, etc.
# TODO (rama): The value is sometimes stored as a numpy array, and sometimes as an ONNX TensorProto.
# A uniform representation would be helpful, but we should avoid unnecessary conversions for
# large tensors. Should be cleaned up in the new IR.
ConcreteValue = Union[onnx.TensorProto, np.ndarray, Unknown, None]
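# Illustrative example (added for exposition; the names below are hypothetical):
# a constant may show up either as the TensorProto stored in the model, e.g.
#   onnx.helper.make_tensor("w", onnx.TensorProto.FLOAT, [2], [1.0, 2.0])
# or as an already-materialized numpy array, e.g.
#   np.array([1.0, 2.0], dtype=np.float32)
# while a value that cannot be resolved statically is represented by `unknown`.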

# SymbolicValue: This information is used to enable partial-evaluation and specialization
# of sequence operations, as well as elimination of redundant Identity ops.
# The symbolic value of a variable X can be:
# - a string with the value "Y", indicating that "X" is a copy of "Y"
# - a list of strings, indicating that "X" is a list of tensors, with their symbolic values
# E.g., the symbolic value ["A", "B", "C"] indicates that the value of X is equal to
# "SequenceConstruct(A, B, C)".
# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of
# tensors, etc. However, we currently only handle lists of tensors.

SymbolicValue = Union[str, List[str]]
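# Illustrative example (added for exposition; the node and value names are hypothetical):
# given the nodes
#   B = Identity(A)
#   S = SequenceConstruct(X, Y, Z)
# the symbolic value of "B" is the string "A" (B is a copy of A), and the symbolic
# value of "S" is the list ["X", "Y", "Z"].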

FunctionId = Tuple[str, str, str]


def get_function_id(function: onnx.FunctionProto) -> FunctionId:
return (function.domain, function.name, getattr(function, "overload", ""))


def get_function_id_from_node(node: onnx.NodeProto) -> FunctionId:
return (node.domain, node.op_type, getattr(node, "overload", ""))


@dataclasses.dataclass
class StaticValueInfo:
name: str
value: ConcreteValue = NotConstant
type: onnx.TypeProto | None = None
symbolic_value: SymbolicValue | None = None

def is_copy(self) -> bool:
return isinstance(self.symbolic_value, str)

def tensor_shape_proto(self) -> onnx.TensorShapeProto | None:
"""Returns the shape of a tensor or None.

        A return value of None can mean that the type is unknown, that the type is not a
        tensor, or that the tensor shape (even the rank) is unknown.
"""
type = self.type
if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
return type.tensor_type.shape
return None

@property
def shape(self) -> list[str | int | None] | None:
"""Returns the shape in a list.

Str means that the shape is dynamic.
"""
type = self.type
if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
dims = []
for dim in type.tensor_type.shape.dim:
if dim.HasField("dim_param"):
dims.append(dim.dim_param)
elif dim.HasField("dim_value"):
dims.append(dim.dim_value)
else:
dims.append(None)
return dims
if self.value_as_np_array is not None:
return list(self.value_as_np_array.shape)
return None

@property
def element_type(self) -> int | None:
"""Returns the element type of a tensor, or None if type is not known or is not a tensor."""
type = self.type
if type and type.HasField("tensor_type"):
return type.tensor_type.elem_type
return None

def identity_merge_from(self, other: StaticValueInfo) -> None:
"""Merge the value of other into self.

This models the effect of an identity (copy) operation.
        This will update static-analysis information based on the incoming value.
"""
if not isinstance(other, StaticValueInfo):
raise TypeError(f"Cannot merge {other} into {self}.")
if other.value is not NotConstant:
self.value = other.value
# TODO: merge and combine best shape information from both types.
if other.tensor_shape_proto() is not None and other.element_type is not None:
self.type = other.type
# We cannot copy symbolic value across different scopes.

# WIP: Extensions towards new IR: Note that the default construction of StaticValueInfo
# does not fill in the following fields. These fields are filled in by the IRBuilder
# which constructs the IR from the ONNX model.
node: Node | None = None
uses: list[Node] = dataclasses.field(default_factory=list)
output_index: int | None = None
is_output: bool = False
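    # Illustrative example (added for exposition; the names are hypothetical): after
    # the IRBuilder processes a node `y = Add(a, b)`, the Value for "y" is expected
    # to have `node` set to that Add Node and `output_index` 0, while the Values for
    # "a" and "b" would each list the Add Node in `uses`.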

@property
def const_value(self) -> ConcreteValue:
return self.value

@property
def value_as_np_array(self) -> np.ndarray | None:
if isinstance(self.value, np.ndarray):
return self.value
if isinstance(self.value, onnx.TensorProto):
return onnx.numpy_helper.to_array(self.value)
return None

def def_node(self) -> Node | None:
return self.node

    def def_index(self) -> int | None:
        return self.output_index

def is_same_as(self, other: StaticValueInfo) -> bool:
"""Returns true if this value represents the same IR object as the other value.

This is *not* value-equality, but rather object-equality.
"""
return self is other

def __str__(self) -> str:
shape = self.shape
if shape is not None:
            shape_str = f"[{', '.join(str(dim) for dim in shape)}]"
        else:
            shape_str = "None"
return (
f"StaticValueInfo({self.name}, shape:{shape_str}, dtype:{self.element_type}, "
f"{'has const value' if self.value is not unknown else 'no const value'}.)"
)


Value = StaticValueInfo


class Model:
def __init__(self) -> None:
self.gen_var_counter: int = 0

def set(
self,
model_proto: onnx.ModelProto,
graph: Graph,
functions: list[Function],
version_map: dict[str, int],
) -> None:
"""TODO. This is a temporary patch."""
self.original_model_proto = model_proto
self.graph = graph
self.functions = functions
self.version_map = version_map

    def make_new_name(self) -> str:
# Temporary hack.
self.gen_var_counter += 1
return f"_gen_{self.gen_var_counter}"

def __str__(self) -> str:
# TODO: Naive string representation for debugging. Need to improve this.
return "\n".join(
[
f"ModelGraph: {self.graph}",
f"Functions: {self.functions}",
f"VersionMap: {self.version_map}",
]
)


class Graph:
def __init__(self, graph_proto: onnx.GraphProto):
self.original_graph_proto = graph_proto
self.nodes: deque[Node] = deque()
self.values: dict[str, Value] = {}

@property
def name(self) -> str:
return self.original_graph_proto.name

def __str__(self) -> str:
return "\n".join(
[
"Graph",
f"Nodes: {[str(n) for n in self.nodes]}",
f"Values: {[str(v) for v in self.values]}",
]
)


class Function:
def __init__(self, function_proto: onnx.FunctionProto):
self.original_function_proto = function_proto
        self.nodes: deque[Node] = deque()
        self.values: dict[str, Value] = {}

@property
def id(self) -> FunctionId:
return (self.domain, self.name, self.overload)

@property
def domain(self) -> str:
return self.original_function_proto.domain

@property
def name(self) -> str:
return self.original_function_proto.name

@property
def overload(self) -> str:
return getattr(self.original_function_proto, "overload", "")

def __str__(self) -> str:
return "\n".join(
[
"Function",
f"Nodes: {[str(n) for n in self.nodes]}",
f"Values: {[str(v) for v in self.values]}",
]
)


class RefAttr:
def __init__(self, name: str, ref_attr_name: str, type) -> None:
self.name = name
self.ref_attr_name = ref_attr_name
self.type = type

def to_proto(self) -> onnx.AttributeProto:
attr_proto = onnx.AttributeProto()
attr_proto.name = self.name
attr_proto.ref_attr_name = self.ref_attr_name
attr_proto.type = self.type
return attr_proto
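
# Illustrative usage (added for exposition; the attribute names are hypothetical):
#   ref = RefAttr("alpha", "outer_alpha", onnx.AttributeProto.FLOAT)
#   attr_proto = ref.to_proto()
# The resulting AttributeProto has ref_attr_name pointing at the enclosing
# function's "outer_alpha" attribute.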


class Node:
def __init__(self, node_proto: onnx.NodeProto) -> None:
self.original_node_proto = node_proto
self.domain: str = node_proto.domain
self.version: int | None = None
self.op_type: str = node_proto.op_type
self.inputs: list[Value | None] = []
self.outputs: list[Value | None] = []
self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {}

    def get_attribute(self, name: str) -> int | float | RefAttr | Graph | list[Graph] | None:
        return self.attributes.get(name, None)

def __str__(self) -> str:
return "\n".join(
[
"Node",
f"OpType: {self.op_type}",
f"Inputs: {self.inputs}",
f"Outputs: {self.outputs}",
f"Attributes: {self.attributes}",
]
)
@@ -5,8 +5,8 @@

import onnx

from onnxscript import ir
from onnxscript.ir import visitor
import onnxscript._legacy_ir as ir
from onnxscript._legacy_ir import visitor
from onnxscript.utils import utils

""" NOTE: IRBuilder and function visiting
@@ -2,7 +2,7 @@

import onnx.parser

from onnxscript.ir import irbuilder
from onnxscript._legacy_ir import irbuilder


class IRBuilderTest(unittest.TestCase):
@@ -4,7 +4,7 @@
import onnx.helper
from onnx.helper import make_attribute

from onnxscript import ir
import onnxscript._legacy_ir as ir


class ModelProtoBuilder:
@@ -4,7 +4,7 @@
import onnx.checker
import onnx.parser

from onnxscript.ir import irbuilder, protobuilder
from onnxscript._legacy_ir import irbuilder, protobuilder
from onnxscript.rewriter import pattern
from onnxscript.rewriter.onnxruntime import instance_to_group_normalization

@@ -7,7 +7,7 @@
import numpy as np
import onnx

from onnxscript import ir
import onnxscript._legacy_ir as ir
from onnxscript.utils.utils import (
get_initializer_type,
is_control_flow_op,
@@ -2,7 +2,7 @@

import onnx

from onnxscript.ir import visitor
from onnxscript._legacy_ir import visitor


class FunctionCallsiteProtoTransformerTest(unittest.TestCase):