From 9f96ec997f60990dc9c7ad36f8122420f8ec78b6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 31 May 2024 14:56:07 +0200 Subject: [PATCH] Harmonize Scan rewrite and tag names --- pytensor/link/numba/dispatch/scan.py | 2 +- pytensor/scan/rewriting.py | 36 +++++++++++++++------------- tests/scan/test_rewriting.py | 6 ++--- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index c60c4c546f..34f088fd54 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -184,7 +184,7 @@ def add_inner_in_expr( # rotation for initially truncated storage. output_storage_post_proc_stmts: list[str] = [] - # In truncated storage situations (e.g. created by `save_mem_new_scan`), + # In truncated storage situations (e.g. created by `scan_save_mem`), # the taps and output storage overlap, instead of the standard situation in # which the output storage is large enough to contain both the initial taps # values and the output storage. In this truncated case, we use the diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index ae128c608f..4b4c632d8d 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -209,7 +209,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): @node_rewriter([Scan]) -def push_out_non_seq_scan(fgraph, node): +def scan_push_out_non_seq(fgraph, node): r"""Push out the variables inside the `Scan` that depend only on non-sequences. This optimizations pushes, out of `Scan`'s inner function and into the outer @@ -417,10 +417,10 @@ def add_to_replace(y): @node_rewriter([Scan]) -def push_out_seq_scan(fgraph, node): +def scan_push_out_seq(fgraph, node): r"""Push out the variables inside the `Scan` that depend only on constants and sequences. - This optimization resembles `push_out_non_seq_scan` but it tries to push--out of + This optimization resembles `scan_push_out_non_seq` but it tries to push--out of the inner function--the computation that only relies on sequence and non-sequence inputs. The idea behind this optimization is that, when it is possible to do so, it is generally more computationally efficient to perform @@ -822,10 +822,10 @@ def add_nitsot_outputs( @node_rewriter([Scan]) -def push_out_add_scan(fgraph, node): +def scan_push_out_add(fgraph, node): r"""Push `Add` operations performed at the end of the inner graph to the outside. - Like `push_out_seq_scan`, this optimization aims to replace many operations + Like `scan_push_out_seq`, this optimization aims to replace many operations on small tensors by few operations on large tensors. It can also lead to increased memory usage. """ @@ -1185,7 +1185,7 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node): @node_rewriter([Scan]) -def save_mem_new_scan(fgraph, node): +def scan_save_mem(fgraph, node): r"""Graph optimizer that reduces scan memory consumption. This optimizations attempts to determine if a `Scan` node, during its execution, @@ -2282,7 +2282,7 @@ def map_out(outer_i, inner_o, outer_o, seen): @node_rewriter([Scan]) -def push_out_dot1_scan(fgraph, node): +def scan_push_out_dot1(fgraph, node): r""" This is another optimization that attempts to detect certain patterns of computation in a `Scan` `Op`'s inner function and move this computation to the @@ -2483,7 +2483,7 @@ def push_out_dot1_scan(fgraph, node): # ScanSaveMem should execute only once per node. optdb.register( "scan_save_mem", - in2out(save_mem_new_scan, ignore_newtrees=True), + in2out(scan_save_mem, ignore_newtrees=True), "fast_run", "scan", position=1.61, @@ -2511,8 +2511,9 @@ def push_out_dot1_scan(fgraph, node): scan_seqopt1.register( - "scan_pushout_nonseqs_ops", - in2out(push_out_non_seq_scan, ignore_newtrees=True), + "scan_push_out_non_seq", + in2out(scan_push_out_non_seq, ignore_newtrees=True), + "scan_pushout_nonseqs_ops", # For backcompat: so it can be tagged with old name "fast_run", "scan", "scan_pushout", @@ -2521,8 +2522,9 @@ def push_out_dot1_scan(fgraph, node): scan_seqopt1.register( - "scan_pushout_seqs_ops", - in2out(push_out_seq_scan, ignore_newtrees=True), + "scan_push_out_seq", + in2out(scan_push_out_seq, ignore_newtrees=True), + "scan_pushout_seqs_ops", # For backcompat: so it can be tagged with old name "fast_run", "scan", "scan_pushout", @@ -2531,8 +2533,9 @@ def push_out_dot1_scan(fgraph, node): scan_seqopt1.register( - "scan_pushout_dot1", - in2out(push_out_dot1_scan, ignore_newtrees=True), + "scan_push_out_dot1", + in2out(scan_push_out_dot1, ignore_newtrees=True), + "scan_pushout_dot1", # For backcompat: so it can be tagged with old name "fast_run", "more_mem", "scan", @@ -2542,9 +2545,10 @@ def push_out_dot1_scan(fgraph, node): scan_seqopt1.register( - "scan_pushout_add", + "scan_push_out_add", # TODO: Perhaps this should be an `EquilibriumGraphRewriter`? - in2out(push_out_add_scan, ignore_newtrees=False), + in2out(scan_push_out_add, ignore_newtrees=False), + "scan_pushout_add", # For backcompat: so it can be tagged with old name "fast_run", "more_mem", "scan", diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index c9f11e891d..aebba785a5 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -304,7 +304,7 @@ def fn(i, i_tm1): class TestPushOutNonSeqScan: """ - Tests for the `push_out_non_seq_scan` optimization in the case where the inner + Tests for the `scan_push_out_non_seq` optimization in the case where the inner function of a `Scan` `Op` has an output which is the result of a `Dot` product on a non-sequence matrix input to `Scan` and a vector that is the result of computation in the inner function. @@ -595,7 +595,7 @@ def inner_func(x): class TestPushOutAddScan: """ - Test case for the `push_out_add_scan` optimization in the case where the `Scan` + Test case for the `scan_push_out_add` optimization in the case where the `Scan` is used to compute the sum over the dot products between the corresponding elements of two list of matrices. @@ -1208,7 +1208,7 @@ def test_inplace3(self): class TestSaveMem: - mode = get_default_mode().including("scan_save_mem", "save_mem_new_scan") + mode = get_default_mode().including("scan_save_mem", "scan_save_mem") def test_save_mem(self): rng = np.random.default_rng(utt.fetch_seed())