Skip to content

Commit 8f96dc9

Browse files
authored
Replace ast.NameConstant with ast.Constant and remove duplicates (#2188)
This pull request includes several changes to the `onnxscript/converter.py` and `onnxscript/converter_test.py` files to improve compatibility with different Python versions and simplify the code. The most important changes include removing deprecated AST node types and updating test cases to reflect these changes.
1 parent 4633a3a commit 8f96dc9

File tree

3 files changed

+7
-26
lines changed

3 files changed

+7
-26
lines changed

onnxscript/_internal/ast_utils.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@
66

77
import ast
88
import inspect
9-
import sys
109
import textwrap
1110
from typing import Callable
1211

13-
PY_VERSION_GE_39 = sys.version_info >= (3, 9)
14-
1512

1613
def get_src_and_ast(func: Callable, /) -> tuple[str, ast.FunctionDef]:
1714
try:
@@ -35,17 +32,10 @@ def normalize_subscript_expr(expr: ast.Subscript):
3532
# Returns a list of expressions, denoting the indices, after stripping the extraneous "Index"
3633
# wrapper present in python versions before 3.9
3734
index_expr = expr.slice
38-
if PY_VERSION_GE_39:
39-
if isinstance(index_expr, ast.Tuple):
40-
return index_expr.elts # multiple indices
41-
else:
42-
return [index_expr] # single index
35+
if isinstance(index_expr, ast.Tuple):
36+
return index_expr.elts # multiple indices
4337
else:
44-
if isinstance(index_expr, ast.ExtSlice):
45-
indices = index_expr.dims # type: ignore[attr-defined]
46-
else:
47-
indices = [index_expr] # single slice-index
48-
return [x.value if isinstance(x, ast.Index) else x for x in indices] # type: ignore[attr-defined]
38+
return [index_expr] # single index
4939

5040

5141
def is_print_call(stmt: ast.stmt) -> bool:

onnxscript/converter.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323
from onnxscript import type_annotation as ta
2424
from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation
2525

26-
PY_VERSION_GE_39 = ast_utils.PY_VERSION_GE_39
27-
28-
2926
logger = logging.getLogger("onnxscript")
3027

3128

@@ -435,14 +432,10 @@ def _is_constant_expr(self, node: ast.AST) -> None:
435432
ast.BinOp,
436433
ast.UnaryOp,
437434
ast.Compare,
438-
ast.Num,
439-
ast.Str,
440435
ast.Attribute,
441436
ast.List,
442437
ast.Load,
443-
ast.NameConstant,
444438
ast.Constant,
445-
ast.Str,
446439
),
447440
):
448441
return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node))
@@ -578,9 +571,9 @@ def _translate_expr(
578571

579572
def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]:
580573
"""Translation of an expression where "None" is permitted (eg., for an optional argument).
581-
None is represented as a NameConstant in Python 3.7 and Constant in Python 3.9.
574+
None is represented as a Constant in Python 3.9+.
582575
"""
583-
if isinstance(node, (ast.NameConstant, ast.Constant)) and (node.value is None):
576+
if isinstance(node, ast.Constant) and (node.value is None):
584577
return None
585578
return self._translate_expr(node)
586579

@@ -629,7 +622,7 @@ def _translate_subscript_expr(
629622
target = f"{var_name}_subscripted"
630623
target = self.generate_unique_name(target)
631624
indices = ast_utils.normalize_subscript_expr(node)
632-
info = self._source_of(node.slice if PY_VERSION_GE_39 else node)
625+
info = self._source_of(node.slice)
633626

634627
# Create cached int constants:
635628
# TODO: Do this at a graph-scope level.

onnxscript/converter_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import inspect
66
import os
77
import pathlib
8-
import sys
98
import textwrap
109
import types
1110
import typing
@@ -437,8 +436,7 @@ def f1(A: FLOAT[...]) -> FLOAT[...]:
437436
r = A[index]
438437
return r
439438

440-
ast_name = "_ast" if sys.version_info[:2] < (3, 9) else "ast"
441-
self.check_failure(f1, f"Left term must be a tuple not '<class '{ast_name}.Name'>'")
439+
self.check_failure(f1, "Left term must be a tuple not '<class 'ast.Name'>'")
442440

443441
def check_run(self, onnxfn, inputs, expected_output):
444442
# Test by converting to model and running with ORT

0 commit comments

Comments
 (0)