Skip to content

Commit 5bc7070

Browse files
committed
Make scan helper return sequences to match old API
This was not possible prior to use of TypedListType for non TensorVariable sequences, as it would otherwise not be possible to represent indexing of last sequence state, which is needed e.g., for shared random generator updates.
1 parent db7068e commit 5bc7070

File tree

3 files changed

+53
-32
lines changed

3 files changed

+53
-32
lines changed

pytensor/loop/basic.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import functools
2-
from typing import List, Tuple
2+
from typing import List, Union
33

44
import numpy as np
55

@@ -18,7 +18,7 @@ def scan(
1818
non_sequences=None,
1919
n_steps=None,
2020
go_backwards=False,
21-
) -> Tuple[List[Variable], List[Variable]]:
21+
) -> Union[Variable, List[Variable]]:
2222
if sequences is None and n_steps is None:
2323
raise ValueError("Must provide n_steps when scanning without sequences")
2424

@@ -126,10 +126,11 @@ def scan(
126126
n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
127127
)
128128
assert isinstance(scan_outs, list)
129-
last_states = scan_outs[: scan_op.n_states]
130-
traces = scan_outs[scan_op.n_states :]
131-
# Don't return the inner index state
132-
return last_states[1:], traces[1:]
129+
# Don't return the last states or the trace for the inner index
130+
traces = scan_outs[scan_op.n_states + 1 :]
131+
if len(traces) == 1:
132+
return traces[0]
133+
return traces
133134

134135

135136
def map(
@@ -138,14 +139,12 @@ def map(
138139
non_sequences=None,
139140
go_backwards=False,
140141
):
141-
_, traces = scan(
142+
traces = scan(
142143
fn=fn,
143144
sequences=sequences,
144145
non_sequences=non_sequences,
145146
go_backwards=go_backwards,
146147
)
147-
if len(traces) == 1:
148-
return traces[0]
149148
return traces
150149

151150

@@ -156,16 +155,16 @@ def reduce(
156155
non_sequences=None,
157156
go_backwards=False,
158157
):
159-
final_states, _ = scan(
158+
traces = scan(
160159
fn=fn,
161160
init_states=init_states,
162161
sequences=sequences,
163162
non_sequences=non_sequences,
164163
go_backwards=go_backwards,
165164
)
166-
if len(final_states) == 1:
167-
return final_states[0]
168-
return final_states
165+
if not isinstance(traces, list):
166+
return traces[-1]
167+
return [trace[-1] for trace in traces]
169168

170169

171170
def filter(
@@ -177,21 +176,21 @@ def filter(
177176
if not isinstance(sequences, (tuple, list)):
178177
sequences = [sequences]
179178

180-
_, masks = scan(
179+
masks = scan(
181180
fn=fn,
182181
sequences=sequences,
183182
non_sequences=non_sequences,
184183
go_backwards=go_backwards,
185184
)
186185

187-
if not all(mask.dtype == "bool" for mask in masks):
188-
raise TypeError("The output of filter fn should be a boolean variable")
189-
if len(masks) == 1:
190-
masks = [masks[0]] * len(sequences)
186+
if not isinstance(masks, list):
187+
masks = [masks] * len(sequences)
191188
elif len(masks) != len(sequences):
192189
raise ValueError(
193190
"filter fn must return one variable or len(sequences), but it returned {len(masks)}"
194191
)
192+
if not all(mask.dtype == "bool" for mask in masks):
193+
raise TypeError("The output of filter fn should be a boolean variable")
195194

196195
filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)]
197196

tests/link/jax/test_loop.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313

1414
def test_scan_with_single_sequence():
1515
xs = vector("xs")
16-
_, [ys] = scan(lambda x: x * 100, sequences=[xs])
16+
ys = scan(lambda x: x * 100, sequences=[xs])
1717

1818
out_fg = FunctionGraph([xs], [ys])
1919
compare_jax_and_py(out_fg, [np.arange(10, dtype=config.floatX)])
2020

2121

2222
def test_scan_with_single_sequence_shortened_by_nsteps():
2323
xs = vector("xs", shape=(10,)) # JAX needs the length to be constant
24-
_, [ys] = scan(
24+
ys = scan(
2525
lambda x: x * 100,
2626
sequences=[xs],
2727
n_steps=9,
@@ -35,7 +35,7 @@ def test_scan_with_multiple_sequences():
3535
# JAX can only handle constant n_steps
3636
xs = vector("xs", shape=(10,))
3737
ys = vector("ys", shape=(10,))
38-
_, [zs] = scan(
38+
zs = scan(
3939
fn=lambda x, y: x * y,
4040
sequences=[xs, ys],
4141
)
@@ -48,7 +48,7 @@ def test_scan_with_multiple_sequences():
4848

4949
def test_scan_with_carried_and_non_carried_states():
5050
x = scalar("x")
51-
_, [ys1, ys2] = scan(
51+
[ys1, ys2] = scan(
5252
fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
5353
init_states=[x, None],
5454
n_steps=10,
@@ -59,7 +59,7 @@ def test_scan_with_carried_and_non_carried_states():
5959

6060
def test_scan_with_sequence_and_carried_state():
6161
xs = vector("xs")
62-
_, [ys] = scan(
62+
ys = scan(
6363
fn=lambda x, ytm1: (ytm1 + 1) * x,
6464
init_states=[zeros(())],
6565
sequences=[xs],
@@ -71,11 +71,12 @@ def test_scan_with_sequence_and_carried_state():
7171
def test_scan_with_rvs():
7272
rng = shared(np.random.default_rng(123))
7373

74-
[final_rng, _], [rngs, xs] = scan(
74+
[rngs, xs] = scan(
7575
fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
7676
init_states=[rng, None],
7777
n_steps=10,
7878
)
79+
final_rng = rngs[-1]
7980

8081
# First without updates
8182
fn = function([], xs, mode="JAX", updates=None)
@@ -99,7 +100,7 @@ def test_scan_with_rvs():
99100

100101

101102
def test_while_scan_fails():
102-
_, [xs] = scan(
103+
xs = scan(
103104
fn=lambda x: (x + 1, until((x + 1) >= 9)),
104105
init_states=[-1],
105106
n_steps=20,

tests/loop/test_basic.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import numpy as np
22

33
import pytensor
4-
from pytensor import config, function, grad
4+
from pytensor import config, function, grad, shared
55
from pytensor.loop.basic import filter, map, reduce, scan
66
from pytensor.scan import until
77
from pytensor.tensor import arange, eq, scalar, vector, zeros
8+
from pytensor.tensor.random import normal
89

910

1011
def test_scan_with_sequences():
1112
xs = vector("xs")
1213
ys = vector("ys")
13-
_, [zs] = scan(
14+
zs = scan(
1415
fn=lambda x, y: x * y,
1516
sequences=[xs, ys],
1617
)
@@ -28,7 +29,7 @@ def test_scan_with_sequences():
2829

2930
def test_scan_with_carried_and_non_carried_states():
3031
x = scalar("x")
31-
_, [ys1, ys2] = scan(
32+
[ys1, ys2] = scan(
3233
fn=lambda xtm1: (xtm1 + 1, (xtm1 + 1) * 2),
3334
init_states=[x, None],
3435
n_steps=10,
@@ -41,7 +42,7 @@ def test_scan_with_carried_and_non_carried_states():
4142

4243
def test_scan_with_sequence_and_carried_state():
4344
xs = vector("xs")
44-
_, [ys] = scan(
45+
ys = scan(
4546
fn=lambda x, ytm1: (ytm1 + 1) * x,
4647
init_states=[zeros(())],
4748
sequences=[xs],
@@ -55,7 +56,7 @@ def test_scan_taking_grads_wrt_non_sequence():
5556
xs = vector("xs")
5657
ys = xs**2
5758

58-
_, [J] = scan(
59+
J = scan(
5960
lambda i, ys, xs: grad(ys[i], wrt=xs),
6061
sequences=arange(ys.shape[0]),
6162
non_sequences=[ys, xs],
@@ -70,7 +71,7 @@ def test_scan_taking_grads_wrt_sequence():
7071
xs = vector("xs")
7172
ys = xs**2
7273

73-
_, [J] = scan(
74+
J = scan(
7475
lambda y, xs: grad(y, wrt=xs),
7576
sequences=[ys],
7677
non_sequences=[xs],
@@ -81,7 +82,7 @@ def test_scan_taking_grads_wrt_sequence():
8182

8283

8384
def test_while_scan():
84-
_, [xs] = scan(
85+
xs = scan(
8586
fn=lambda x: (x + 1, until((x + 1) >= 9)),
8687
init_states=[-1],
8788
n_steps=20,
@@ -91,6 +92,26 @@ def test_while_scan():
9192
np.testing.assert_array_equal(f(), np.arange(10))
9293

9394

95+
def test_scan_rvs():
96+
rng = shared(np.random.default_rng(123))
97+
test_rng = np.random.default_rng(123)
98+
99+
def normal_fn(prev_rng):
100+
next_rng, x = normal(rng=prev_rng).owner.outputs
101+
return next_rng, x
102+
103+
[rngs, xs] = scan(
104+
fn=normal_fn,
105+
init_states=[rng, None],
106+
n_steps=5,
107+
)
108+
fn = function([], xs, updates={rng: rngs[-1]})
109+
110+
for i in range(3):
111+
res = fn()
112+
np.testing.assert_almost_equal(res, test_rng.normal(size=5))
113+
114+
94115
def test_map():
95116
xs = vector("xs")
96117
ys = map(

0 commit comments

Comments
 (0)