
Apply scan memory save rewrite to while scans #178



Closed
ricardoV94 opened this issue Jan 6, 2023 · 0 comments · Fixed by #216


@ricardoV94 (Member)

Description

The rewrite save_mem_new_scan

def save_mem_new_scan(fgraph, node):

seems to work for scan loops with a fixed number of steps, whether n_steps is static or dynamic, but not for while loops (scans with an until condition).
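To make the idea concrete, here is a minimal plain-Python sketch of what a save-memory rewrite achieves for a loop with a fixed number of steps. It is an illustration only, not PyTensor's implementation, and both function names are hypothetical. When the graph only reads y[-1], the output buffer holding all n_steps states can be replaced by a buffer holding just the most recent state, and the result is unchanged.

```python
def scan_full_history(x0: float, n_steps: int) -> list[float]:
    """Naive scan: stores every output state (n_steps values)."""
    history = [0.0] * n_steps  # preallocated output buffer, one slot per step
    prev = x0
    for t in range(n_steps):
        prev = prev + 1  # step function: xtm1 + 1
        history[t] = prev
    return history


def scan_last_only(x0: float, n_steps: int) -> list[float]:
    """Memory-saving scan: a buffer of length 1 holding the latest state."""
    buffer = [x0]
    for _ in range(n_steps):
        buffer[0] = buffer[0] + 1  # overwrite in place instead of storing history
    return buffer


if __name__ == "__main__":
    n = 1_000
    full = scan_full_history(0.0, n)
    saved = scan_last_only(0.0, n)
    # Both agree on the only value the graph actually uses, y[-1].
    assert full[-1] == saved[-1] == 1000.0
    print(len(full), len(saved))  # 1000 1
```

The saving grows linearly with n_steps, while the retained buffer size depends only on how many trailing values are read.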

import pytensor
import pytensor.tensor as pt
from pytensor.scan import until

x = pt.scalar("x")
n_steps = pt.iscalar("n_steps")

y, _ = pytensor.scan(
    lambda xtm1: xtm1 + 1,  # for loop
    # lambda xtm1: (xtm1 + 1, {}, until(xtm1 >= 100)),  # while loop (uncomment to reproduce)
    outputs_info=[x],
    n_steps=n_steps,  # dynamic number of steps
    # n_steps=100,  # static number of steps
    strict=True,
)
# The save-memory rewrite is triggered by using only the last value
y = y[-1]

pytensor.config.optimizer_verbose = True
f = pytensor.function([x, n_steps], y, on_unused_input="ignore")

Applying the rewrite can make a big difference in both memory use and performance, because it avoids allocating large output arrays when only a few of their values (here, the last one) are needed (see #174).
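For a while loop, the number of steps actually taken is only known at run time, so a memory-saving version cannot locate the final values from n_steps alone. One way to picture this is a fixed-size ring buffer of length k written at index t % k, where the last values are read back using the runtime step count. This is a hedged sketch, not how PyTensor handles it; scan_while_last_k and its parameters are hypothetical, and the stopping semantics (the value computed in the stopping step is kept) are an assumption of the sketch.

```python
from typing import Callable


def scan_while_last_k(
    x0: float,
    max_steps: int,
    stop: Callable[[float], bool],
    k: int = 1,
) -> tuple[list[float], int]:
    """Emulate a while-style scan that keeps only its last k outputs.

    Step function: x_t = x_{t-1} + 1. Assumption of this sketch: the value
    computed in the step where stop(x_{t-1}) is true is kept, then the loop ends.
    Returns the retained outputs in chronological order and the steps taken.
    """
    ring = [0.0] * k  # fixed-size buffer, allocated once regardless of max_steps
    prev = x0
    steps = 0
    for t in range(max_steps):
        new = prev + 1
        ring[t % k] = new  # overwrite the oldest slot
        steps += 1
        if stop(prev):
            break
        prev = new
    # Which slots hold the newest values depends on `steps`,
    # which is only known after the loop has finished.
    n_kept = min(steps, k)
    ordered = [ring[(steps - n_kept + i) % k] for i in range(n_kept)]
    return ordered, steps


if __name__ == "__main__":
    last, steps = scan_while_last_k(0.0, max_steps=1000, stop=lambda x: x >= 100, k=3)
    print(steps, last)  # 101 [99.0, 100.0, 101.0]
```

In this picture the buffer size k depends only on how many trailing values the graph reads (y[-1] needs k = 1), not on how long the loop runs, which suggests the same saving could in principle apply to while loops.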
