Skip to content

Handle loops in analysis of assigned vars #925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions onnxscript/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@ def used_vars(expr: Optional[ast.expr]) -> Set[str]:
return result


def local_defs(lhs: ast.expr) -> Set[str]:
"""Utility function to return set of assigned/defined
variables in the lhs of an assignment statement.
"""
def lhs_vars(lhs: ast.expr) -> Set[str]:
"""Return set of assigned variables in the lhs of an assignment statement."""

def get_id(e):
assert isinstance(e, ast.Name), "Only simple assignments supported."
Expand All @@ -51,27 +49,33 @@ def get_id(e):
return {get_id(lhs)}


def defs(stmt: ast.stmt) -> Set[str]:
"""Return the set of all variables that may be defined (assigned to) in an
execution of input stmt.
def assigned_vars(
stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter
) -> Set[str]:
"""Return the set of all variables that may be assigned to in an execution of input stmt
or sequence of statements.
"""

def block_defs(block: Sequence[ast.stmt]) -> Set[str]:
def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
result: set[Any] = set()
for s in block:
result = result | defs(s)
result = result | assigned_vars(s, formatter)
return result

if isinstance(stmt, ast.Assign):
return local_defs(stmt.targets[0])
return lhs_vars(stmt.targets[0])
if isinstance(stmt, ast.AnnAssign):
return local_defs(stmt.target)
return lhs_vars(stmt.target)
if isinstance(stmt, ast.Return):
return set()
if isinstance(stmt, ast.If):
return block_defs(stmt.body) | block_defs(stmt.orelse)
return assigned_in_block(stmt.body) | assigned_in_block(stmt.orelse)
if isinstance(stmt, ast.For):
return assigned_in_block(stmt.body) | {get_loop_var(stmt, formatter)}
if isinstance(stmt, ast.While):
return assigned_in_block(stmt.body)
if isinstance(stmt, list):
return block_defs(stmt)
return assigned_in_block(stmt)
if isinstance(stmt, ast.Break):
return set()
if ast_utils.is_print_call(stmt):
Expand All @@ -98,9 +102,9 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
return live_out

if isinstance(stmt, ast.Assign):
return live_out.difference(local_defs(stmt.targets[0])) | used_vars(stmt.value)
return live_out.difference(lhs_vars(stmt.targets[0])) | used_vars(stmt.value)
if isinstance(stmt, ast.AnnAssign):
return live_out.difference(local_defs(stmt.target)) | used_vars(stmt.value)
return live_out.difference(lhs_vars(stmt.target)) | used_vars(stmt.value)
if isinstance(stmt, ast.Return):
return used_vars(stmt.value)
if isinstance(stmt, ast.If):
Expand Down Expand Up @@ -170,9 +174,9 @@ def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:

def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
if isinstance(stmt, ast.Assign):
return live_out.difference(local_defs(stmt.targets[0])) | used_vars(stmt.value)
return live_out.difference(lhs_vars(stmt.targets[0])) | used_vars(stmt.value)
if isinstance(stmt, ast.AnnAssign):
return live_out.difference(local_defs(stmt.target)) | used_vars(stmt.value)
return live_out.difference(lhs_vars(stmt.target)) | used_vars(stmt.value)
if isinstance(stmt, ast.Return):
return used_vars(stmt.value)
if isinstance(stmt, ast.If):
Expand Down
52 changes: 52 additions & 0 deletions onnxscript/analysis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,57 @@ def nested(): # pylint: disable=unused-variable
self.assertUses(f, {"x"})


class TestAssignedVarAnalysis(unittest.TestCase):
def assert_assigned_vars(self, f, expected: set[str]):
source, parse_tree = ast_utils.get_src_and_ast(f)
result = analysis.assigned_vars(parse_tree.body, formatter(source))
self.assertEqual(result, expected)

def test_basic_defs(self):
def f(x):
x = x + 1
y = x + 2
return y

self.assert_assigned_vars(f, {"x", "y"})

def test_if_defs(self):
def f(x):
if x > 1:
y = x + 1
z = 2 * y
else:
t = x + 2
z = 3 * t
return z

self.assert_assigned_vars(f, {"z", "y", "t"})

def test_loop_defs(self):
def f(x):
sum = 0
while x > 0:
x = x - 1
square = x * x
sum = sum + square
return sum

self.assert_assigned_vars(f, {"sum", "x", "square"})

def test_if_loop_defs(self):
def f(x):
if x > 0:
sum = 0
while x > 0:
x = x - 1
square = x * x
sum = sum + square
else:
sum = 0
return sum

self.assert_assigned_vars(f, {"sum", "x", "square"})


if __name__ == "__main__":
unittest.main(verbosity=2)
8 changes: 5 additions & 3 deletions onnxscript/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,9 +1102,11 @@ def ret(exp, i, suffix):

def translate_if_stmt(self, stmt: ast.If) -> None:
if hasattr(stmt, "live_out"):
live_defs = list(stmt.live_out.intersection(analysis.defs(stmt)))
live_defs = list(
stmt.live_out.intersection(analysis.assigned_vars(stmt, self.message))
)
else:
live_defs = list(analysis.defs(stmt))
live_defs = list(analysis.assigned_vars(stmt, self.message))
Comment on lines +1105 to +1109
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to explain this logic with a comment?

Copy link
Collaborator Author

@gramalingam gramalingam Jul 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I didn't add it, and it seems questionable to me. I need to check it out, I suspect it was added in some case where the liveness analysis had not run for some reason. but the fix should probably be something else.

test = self.translate_expr(stmt.test, "cond").name
lineno = self.source_of(stmt).lineno
thenGraph, sub_fct_then = self.translate_block(
Expand Down Expand Up @@ -1183,7 +1185,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.")
# analyze loop body
exposed_uses = analysis.exposed_uses(loop_stmt.body, self.message)
vars_def_in_loop = analysis.defs(loop_stmt.body)
vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self.message)
loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out)
scan_outputs = set() # TODO
outputs = list(loop_state_vars | scan_outputs)
Expand Down
12 changes: 12 additions & 0 deletions onnxscript/converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,18 @@ def float_list_as_tensor():
expected = np.array([13, 17], dtype=np.float32).reshape((2,))
self.check_run(float_list_as_tensor, [], expected)

def test_loop_inside_if(self):
@script(default_opset=op)
def sum(n: INT64) -> INT64:
sum = op.Constant(value=0)
if n > 0:
for i in range(n):
sum = sum + i
return sum

self.check_run(sum, [np.array(5, dtype=np.int64)], np.array(10, dtype=np.int64))
self.check_run(sum, [np.array(-5, dtype=np.int64)], np.array(0, dtype=np.int64))


if __name__ == "__main__":
unittest.main(verbosity=2)