From eae0114a0f971f2c3e1e692d0fd1f6faeb75d413 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 26 Jul 2023 14:22:46 -0700 Subject: [PATCH 1/4] Handle loops in defs computation --- onnxscript/analysis.py | 10 +++++-- onnxscript/analysis_test.py | 52 +++++++++++++++++++++++++++++++++++++ onnxscript/converter.py | 6 ++--- 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/onnxscript/analysis.py b/onnxscript/analysis.py index 0998663423..85f50140a9 100644 --- a/onnxscript/analysis.py +++ b/onnxscript/analysis.py @@ -51,7 +51,7 @@ def get_id(e): return {get_id(lhs)} -def defs(stmt: ast.stmt) -> Set[str]: +def defs(stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter) -> Set[str]: """Return the set of all variables that may be defined (assigned to) in an execution of input stmt. """ @@ -59,7 +59,7 @@ def defs(stmt: ast.stmt) -> Set[str]: def block_defs(block: Sequence[ast.stmt]) -> Set[str]: result: set[Any] = set() for s in block: - result = result | defs(s) + result = result | defs(s, formatter) return result if isinstance(stmt, ast.Assign): @@ -70,12 +70,18 @@ def block_defs(block: Sequence[ast.stmt]) -> Set[str]: return set() if isinstance(stmt, ast.If): return block_defs(stmt.body) | block_defs(stmt.orelse) + if isinstance(stmt, ast.For): + return block_defs(stmt.body) | {get_loop_var(stmt, formatter)} + if isinstance(stmt, ast.While): + return block_defs(stmt.body) if isinstance(stmt, list): return block_defs(stmt) if isinstance(stmt, ast.Break): return set() if ast_utils.is_print_call(stmt): return set() + if isinstance(stmt, list): + return block_defs(stmt) raise ValueError(f"Unsupported statement type {type(stmt)!r}.") diff --git a/onnxscript/analysis_test.py b/onnxscript/analysis_test.py index 362d44dab9..61a00c82d2 100644 --- a/onnxscript/analysis_test.py +++ b/onnxscript/analysis_test.py @@ -163,5 +163,57 @@ def nested(): # pylint: disable=unused-variable self.assertUses(f, {"x"}) +class TestDefAnalysis(unittest.TestCase): + def assertDefs(self, f, expected: set[str]): + source, parse_tree = ast_utils.get_src_and_ast(f) + result = analysis.defs(parse_tree.body, formatter(source)) + self.assertEqual(result, expected) + + def test_basic_defs(self): + def f(x): + x = x + 1 + y = x + 2 + return y + + self.assertDefs(f, {"x", "y"}) + + def test_if_defs(self): + def f(x): + if x > 1: + y = x + 1 + z = 2 * y + else: + t = x + 2 + z = 3 * t + return z + + self.assertDefs(f, {"z", "y", "t"}) + + def test_loop_defs(self): + def f(x): + sum = 0 + while x > 0: + x = x - 1 + square = x * x + sum = sum + square + return sum + + self.assertDefs(f, {"sum", "x", "square"}) + + def test_if_loop_defs(self): + def f(x): + if x > 0: + sum = 0 + while x > 0: + x = x - 1 + square = x * x + sum = sum + square + else: + sum = 0 + return sum + + self.assertDefs(f, {"sum", "x", "square"}) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 9b121a5b38..e92c6b60dd 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1102,9 +1102,9 @@ def ret(exp, i, suffix): def translate_if_stmt(self, stmt: ast.If) -> None: if hasattr(stmt, "live_out"): - live_defs = list(stmt.live_out.intersection(analysis.defs(stmt))) + live_defs = list(stmt.live_out.intersection(analysis.defs(stmt, self.message))) else: - live_defs = list(analysis.defs(stmt)) + live_defs = list(analysis.defs(stmt, self.message)) test = self.translate_expr(stmt.test, "cond").name lineno = self.source_of(stmt).lineno thenGraph, sub_fct_then = self.translate_block( @@ -1183,7 +1183,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") # analyze loop body exposed_uses = analysis.exposed_uses(loop_stmt.body, self.message) - vars_def_in_loop = analysis.defs(loop_stmt.body) + vars_def_in_loop = analysis.defs(loop_stmt.body, self.message) loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out) scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) From 3026408f6a72155627685017991bf8fd0839fc57 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 26 Jul 2023 15:05:21 -0700 Subject: [PATCH 2/4] Add integration test --- onnxscript/analysis.py | 40 ++++++++++++++++++------------------ onnxscript/analysis_test.py | 4 ++-- onnxscript/converter.py | 8 +++++--- onnxscript/converter_test.py | 12 +++++++++++ 4 files changed, 39 insertions(+), 25 deletions(-) diff --git a/onnxscript/analysis.py b/onnxscript/analysis.py index 85f50140a9..11d35723c8 100644 --- a/onnxscript/analysis.py +++ b/onnxscript/analysis.py @@ -37,10 +37,8 @@ def used_vars(expr: Optional[ast.expr]) -> Set[str]: return result -def local_defs(lhs: ast.expr) -> Set[str]: - """Utility function to return set of assigned/defined - variables in the lhs of an assignment statement. - """ +def lhs_vars(lhs: ast.expr) -> Set[str]: + """Return set of assigned variables in the lhs of an assignment statement.""" def get_id(e): assert isinstance(e, ast.Name), "Only simple assignments supported." @@ -51,37 +49,39 @@ def get_id(e): return {get_id(lhs)} -def defs(stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter) -> Set[str]: - """Return the set of all variables that may be defined (assigned to) in an - execution of input stmt. +def assigned_vars( + stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter +) -> Set[str]: + """Return the set of all variables that may be assigned to in an execution of input stmt + or sequence of statements. """ - def block_defs(block: Sequence[ast.stmt]) -> Set[str]: + def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: result: set[Any] = set() for s in block: - result = result | defs(s, formatter) + result = result | assigned_vars(s, formatter) return result if isinstance(stmt, ast.Assign): - return local_defs(stmt.targets[0]) + return lhs_vars(stmt.targets[0]) if isinstance(stmt, ast.AnnAssign): - return local_defs(stmt.target) + return lhs_vars(stmt.target) if isinstance(stmt, ast.Return): return set() if isinstance(stmt, ast.If): - return block_defs(stmt.body) | block_defs(stmt.orelse) + return assigned_in_block(stmt.body) | assigned_in_block(stmt.orelse) if isinstance(stmt, ast.For): - return block_defs(stmt.body) | {get_loop_var(stmt, formatter)} + return assigned_in_block(stmt.body) | {get_loop_var(stmt, formatter)} if isinstance(stmt, ast.While): - return block_defs(stmt.body) + return assigned_in_block(stmt.body) if isinstance(stmt, list): - return block_defs(stmt) + return assigned_in_block(stmt) if isinstance(stmt, ast.Break): return set() if ast_utils.is_print_call(stmt): return set() if isinstance(stmt, list): - return block_defs(stmt) + return assigned_in_block(stmt) raise ValueError(f"Unsupported statement type {type(stmt)!r}.") @@ -104,9 +104,9 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: return live_out if isinstance(stmt, ast.Assign): - return live_out.difference(local_defs(stmt.targets[0])) | used_vars(stmt.value) + return live_out.difference(lhs_vars(stmt.targets[0])) | used_vars(stmt.value) if isinstance(stmt, ast.AnnAssign): - return live_out.difference(local_defs(stmt.target)) | used_vars(stmt.value) + return live_out.difference(lhs_vars(stmt.target)) | used_vars(stmt.value) if isinstance(stmt, ast.Return): return used_vars(stmt.value) if isinstance(stmt, ast.If): @@ -176,9 +176,9 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: if isinstance(stmt, ast.Assign): - return live_out.difference(local_defs(stmt.targets[0])) | used_vars(stmt.value) + return live_out.difference(lhs_vars(stmt.targets[0])) | used_vars(stmt.value) if isinstance(stmt, ast.AnnAssign): - return live_out.difference(local_defs(stmt.target)) | used_vars(stmt.value) + return live_out.difference(lhs_vars(stmt.target)) | used_vars(stmt.value) if isinstance(stmt, ast.Return): return used_vars(stmt.value) if isinstance(stmt, ast.If): diff --git a/onnxscript/analysis_test.py b/onnxscript/analysis_test.py index 61a00c82d2..6496e4c670 100644 --- a/onnxscript/analysis_test.py +++ b/onnxscript/analysis_test.py @@ -163,10 +163,10 @@ def nested(): # pylint: disable=unused-variable self.assertUses(f, {"x"}) -class TestDefAnalysis(unittest.TestCase): +class TestAssignedVarAnalysis(unittest.TestCase): def assertDefs(self, f, expected: set[str]): source, parse_tree = ast_utils.get_src_and_ast(f) - result = analysis.defs(parse_tree.body, formatter(source)) + result = analysis.assigned_vars(parse_tree.body, formatter(source)) self.assertEqual(result, expected) def test_basic_defs(self): diff --git a/onnxscript/converter.py b/onnxscript/converter.py index e92c6b60dd..4c2f9c859d 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1102,9 +1102,11 @@ def ret(exp, i, suffix): def translate_if_stmt(self, stmt: ast.If) -> None: if hasattr(stmt, "live_out"): - live_defs = list(stmt.live_out.intersection(analysis.defs(stmt, self.message))) + live_defs = list( + stmt.live_out.intersection(analysis.assigned_vars(stmt, self.message)) + ) else: - live_defs = list(analysis.defs(stmt, self.message)) + live_defs = list(analysis.assigned_vars(stmt, self.message)) test = self.translate_expr(stmt.test, "cond").name lineno = self.source_of(stmt).lineno thenGraph, sub_fct_then = self.translate_block( @@ -1183,7 +1185,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") # analyze loop body exposed_uses = analysis.exposed_uses(loop_stmt.body, self.message) - vars_def_in_loop = analysis.defs(loop_stmt.body, self.message) + vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self.message) loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out) scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 74b68bf2a9..ce955b7d29 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -655,6 +655,18 @@ def float_list_as_tensor(): expected = np.array([13, 17], dtype=np.float32).reshape((2,)) self.check_run(float_list_as_tensor, [], expected) + def test_loop_inside_if(self): + @script(default_opset=op) + def sum(n: INT64) -> INT64: + sum = op.Constant(value=0) + if n > 0: + for i in range(n): + sum = sum + i + return sum + + self.check_run(sum, [np.array(5, dtype=np.int64)], np.array(10, dtype=np.int64)) + self.check_run(sum, [np.array(-5, dtype=np.int64)], np.array(0, dtype=np.int64)) + if __name__ == "__main__": unittest.main(verbosity=2) From 4c8da73246eb32ce5f3864f09dafb9b9c4ba60d1 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 27 Jul 2023 11:16:54 -0700 Subject: [PATCH 3/4] Rename method --- onnxscript/analysis_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/analysis_test.py b/onnxscript/analysis_test.py index 6496e4c670..3ec14ad628 100644 --- a/onnxscript/analysis_test.py +++ b/onnxscript/analysis_test.py @@ -164,7 +164,7 @@ def nested(): # pylint: disable=unused-variable class TestAssignedVarAnalysis(unittest.TestCase): - def assertDefs(self, f, expected: set[str]): + def assert_assigned_vars(self, f, expected: set[str]): source, parse_tree = ast_utils.get_src_and_ast(f) result = analysis.assigned_vars(parse_tree.body, formatter(source)) self.assertEqual(result, expected) @@ -175,7 +175,7 @@ def f(x): y = x + 2 return y - self.assertDefs(f, {"x", "y"}) + self.assert_assigned_vars(f, {"x", "y"}) def test_if_defs(self): def f(x): @@ -187,7 +187,7 @@ def f(x): z = 3 * t return z - self.assertDefs(f, {"z", "y", "t"}) + self.assert_assigned_vars(f, {"z", "y", "t"}) def test_loop_defs(self): def f(x): @@ -198,7 +198,7 @@ def f(x): sum = sum + square return sum - self.assertDefs(f, {"sum", "x", "square"}) + self.assert_assigned_vars(f, {"sum", "x", "square"}) def test_if_loop_defs(self): def f(x): @@ -212,7 +212,7 @@ def f(x): sum = 0 return sum - self.assertDefs(f, {"sum", "x", "square"}) + self.assert_assigned_vars(f, {"sum", "x", "square"}) if __name__ == "__main__": From 45e9a3d7bc65c39cc3dafa09be0ad13e87786c89 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 27 Jul 2023 13:10:10 -0700 Subject: [PATCH 4/4] Remove duplicate line --- onnxscript/analysis.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxscript/analysis.py b/onnxscript/analysis.py index 11d35723c8..30ca18e055 100644 --- a/onnxscript/analysis.py +++ b/onnxscript/analysis.py @@ -80,8 +80,6 @@ def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: return set() if ast_utils.is_print_call(stmt): return set() - if isinstance(stmt, list): - return assigned_in_block(stmt) raise ValueError(f"Unsupported statement type {type(stmt)!r}.")