diff --git a/onnxscript/_internal/ast_utils.py b/onnxscript/_internal/ast_utils.py index c7250e1268..4146f38e2f 100644 --- a/onnxscript/_internal/ast_utils.py +++ b/onnxscript/_internal/ast_utils.py @@ -6,12 +6,9 @@ import ast import inspect -import sys import textwrap from typing import Callable -PY_VERSION_GE_39 = sys.version_info >= (3, 9) - def get_src_and_ast(func: Callable, /) -> tuple[str, ast.FunctionDef]: try: @@ -35,17 +32,10 @@ def normalize_subscript_expr(expr: ast.Subscript): # Returns a list of expressions, denoting the indices, after stripping the extraneous "Index" # wrapper present in python versions before 3.9 index_expr = expr.slice - if PY_VERSION_GE_39: - if isinstance(index_expr, ast.Tuple): - return index_expr.elts # multiple indices - else: - return [index_expr] # single index + if isinstance(index_expr, ast.Tuple): + return index_expr.elts # multiple indices else: - if isinstance(index_expr, ast.ExtSlice): - indices = index_expr.dims # type: ignore[attr-defined] - else: - indices = [index_expr] # single slice-index - return [x.value if isinstance(x, ast.Index) else x for x in indices] # type: ignore[attr-defined] + return [index_expr] # single index def is_print_call(stmt: ast.stmt) -> bool: diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 2d10a73764..1ee6e0ecd0 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -23,9 +23,6 @@ from onnxscript import type_annotation as ta from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation -PY_VERSION_GE_39 = ast_utils.PY_VERSION_GE_39 - - logger = logging.getLogger("onnxscript") @@ -435,14 +432,10 @@ def _is_constant_expr(self, node: ast.AST) -> None: ast.BinOp, ast.UnaryOp, ast.Compare, - ast.Num, - ast.Str, ast.Attribute, ast.List, ast.Load, - ast.NameConstant, ast.Constant, - ast.Str, ), ): return all(self._is_constant_expr(c) for c in ast.iter_child_nodes(node)) @@ -578,9 +571,9 @@ def _translate_expr( def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]: """Translation of an expression where "None" is permitted (eg., for an optional argument). - None is represented as a NameConstant in Python 3.7 and Constant in Python 3.9. + None is represented as a Constant in Python 3.9+. """ - if isinstance(node, (ast.NameConstant, ast.Constant)) and (node.value is None): + if isinstance(node, ast.Constant) and (node.value is None): return None return self._translate_expr(node) @@ -629,7 +622,7 @@ def _translate_subscript_expr( target = f"{var_name}_subscripted" target = self.generate_unique_name(target) indices = ast_utils.normalize_subscript_expr(node) - info = self._source_of(node.slice if PY_VERSION_GE_39 else node) + info = self._source_of(node.slice) # Create cached int constants: # TODO: Do this at a graph-scope level. diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 46d88f9f12..6305bddf70 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -5,7 +5,6 @@ import inspect import os import pathlib -import sys import textwrap import types import typing @@ -437,8 +436,7 @@ def f1(A: FLOAT[...]) -> FLOAT[...]: r = A[index] return r - ast_name = "_ast" if sys.version_info[:2] < (3, 9) else "ast" - self.check_failure(f1, f"Left term must be a tuple not ''") + self.check_failure(f1, "Left term must be a tuple not ''") def check_run(self, onnxfn, inputs, expected_output): # Test by converting to model and running with ORT