diff --git a/onnxscript/analysis.py b/onnxscript/analysis.py index 0998663423..30ca18e055 100644 --- a/onnxscript/analysis.py +++ b/onnxscript/analysis.py @@ -37,10 +37,8 @@ def used_vars(expr: Optional[ast.expr]) -> Set[str]: return result -def local_defs(lhs: ast.expr) -> Set[str]: - """Utility function to return set of assigned/defined - variables in the lhs of an assignment statement. - """ +def lhs_vars(lhs: ast.expr) -> Set[str]: + """Return set of assigned variables in the lhs of an assignment statement.""" def get_id(e): assert isinstance(e, ast.Name), "Only simple assignments supported." @@ -51,27 +49,33 @@ def get_id(e): return {get_id(lhs)} -def defs(stmt: ast.stmt) -> Set[str]: - """Return the set of all variables that may be defined (assigned to) in an - execution of input stmt. +def assigned_vars( + stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter +) -> Set[str]: + """Return the set of all variables that may be assigned to in an execution of input stmt + or sequence of statements. """ - def block_defs(block: Sequence[ast.stmt]) -> Set[str]: + def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: result: set[Any] = set() for s in block: - result = result | defs(s) + result = result | assigned_vars(s, formatter) return result if isinstance(stmt, ast.Assign): - return local_defs(stmt.targets[0]) + return lhs_vars(stmt.targets[0]) if isinstance(stmt, ast.AnnAssign): - return local_defs(stmt.target) + return lhs_vars(stmt.target) if isinstance(stmt, ast.Return): return set() if isinstance(stmt, ast.If): - return block_defs(stmt.body) | block_defs(stmt.orelse) + return assigned_in_block(stmt.body) | assigned_in_block(stmt.orelse) + if isinstance(stmt, ast.For): + return assigned_in_block(stmt.body) | {get_loop_var(stmt, formatter)} + if isinstance(stmt, ast.While): + return assigned_in_block(stmt.body) if isinstance(stmt, list): - return block_defs(stmt) + return assigned_in_block(stmt) if isinstance(stmt, ast.Break): return set() if ast_utils.is_print_call(stmt): @@ -98,9 +102,9 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: return live_out if isinstance(stmt, ast.Assign): - return live_out.difference(local_defs(stmt.targets[0])) | used_vars(stmt.value) + return live_out.difference(lhs_vars(stmt.targets[0])) | used_vars(stmt.value) if isinstance(stmt, ast.AnnAssign): - return live_out.difference(local_defs(stmt.target)) | used_vars(stmt.value) + return live_out.difference(lhs_vars(stmt.target)) | used_vars(stmt.value) if isinstance(stmt, ast.Return): return used_vars(stmt.value) if isinstance(stmt, ast.If): @@ -170,9 +174,9 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: if isinstance(stmt, ast.Assign): - return live_out.difference(local_defs(stmt.targets[0])) | used_vars(stmt.value) + return live_out.difference(lhs_vars(stmt.targets[0])) | used_vars(stmt.value) if isinstance(stmt, ast.AnnAssign): - return live_out.difference(local_defs(stmt.target)) | used_vars(stmt.value) + return live_out.difference(lhs_vars(stmt.target)) | used_vars(stmt.value) if isinstance(stmt, ast.Return): return used_vars(stmt.value) if isinstance(stmt, ast.If): diff --git a/onnxscript/analysis_test.py b/onnxscript/analysis_test.py index 362d44dab9..3ec14ad628 100644 --- a/onnxscript/analysis_test.py +++ b/onnxscript/analysis_test.py @@ -163,5 +163,57 @@ def nested(): # pylint: disable=unused-variable self.assertUses(f, {"x"}) +class TestAssignedVarAnalysis(unittest.TestCase): + def assert_assigned_vars(self, f, expected: set[str]): + source, parse_tree = ast_utils.get_src_and_ast(f) + result = analysis.assigned_vars(parse_tree.body, formatter(source)) + self.assertEqual(result, expected) + + def test_basic_defs(self): + def f(x): + x = x + 1 + y = x + 2 + return y + + self.assert_assigned_vars(f, {"x", "y"}) + + def test_if_defs(self): + def f(x): + if x > 1: + y = x + 1 + z = 2 * y + else: + t = x + 2 + z = 3 * t + return z + + self.assert_assigned_vars(f, {"z", "y", "t"}) + + def test_loop_defs(self): + def f(x): + sum = 0 + while x > 0: + x = x - 1 + square = x * x + sum = sum + square + return sum + + self.assert_assigned_vars(f, {"sum", "x", "square"}) + + def test_if_loop_defs(self): + def f(x): + if x > 0: + sum = 0 + while x > 0: + x = x - 1 + square = x * x + sum = sum + square + else: + sum = 0 + return sum + + self.assert_assigned_vars(f, {"sum", "x", "square"}) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 9b121a5b38..4c2f9c859d 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1102,9 +1102,11 @@ def ret(exp, i, suffix): def translate_if_stmt(self, stmt: ast.If) -> None: if hasattr(stmt, "live_out"): - live_defs = list(stmt.live_out.intersection(analysis.defs(stmt))) + live_defs = list( + stmt.live_out.intersection(analysis.assigned_vars(stmt, self.message)) + ) else: - live_defs = list(analysis.defs(stmt)) + live_defs = list(analysis.assigned_vars(stmt, self.message)) test = self.translate_expr(stmt.test, "cond").name lineno = self.source_of(stmt).lineno thenGraph, sub_fct_then = self.translate_block( @@ -1183,7 +1185,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") # analyze loop body exposed_uses = analysis.exposed_uses(loop_stmt.body, self.message) - vars_def_in_loop = analysis.defs(loop_stmt.body) + vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self.message) loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out) scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 74b68bf2a9..ce955b7d29 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -655,6 +655,18 @@ def float_list_as_tensor(): expected = np.array([13, 17], dtype=np.float32).reshape((2,)) self.check_run(float_list_as_tensor, [], expected) + def test_loop_inside_if(self): + @script(default_opset=op) + def sum(n: INT64) -> INT64: + sum = op.Constant(value=0) + if n > 0: + for i in range(n): + sum = sum + i + return sum + + self.check_run(sum, [np.array(5, dtype=np.int64)], np.array(10, dtype=np.int64)) + self.check_run(sum, [np.array(-5, dtype=np.int64)], np.array(0, dtype=np.int64)) + if __name__ == "__main__": unittest.main(verbosity=2)