Skip to content

Commit ea1617f

Browse files
committed
Fix redefinitions
1 parent e8b04b7 commit ea1617f

File tree

3 files changed

+49
-15
lines changed

3 files changed

+49
-15
lines changed

pytensor/graph/basic.py

+43-11
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
TypeVar,
2424
Union,
2525
cast,
26+
overload,
2627
)
2728

2829
import numpy as np
@@ -1301,9 +1302,31 @@ def clone_get_equiv(
13011302
return memo
13021303

13031304

1305+
@overload
1306+
def general_toposort(
1307+
outputs: Iterable[T],
1308+
deps: None,
1309+
compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
1310+
deps_cache: Optional[dict[T, list[T]]],
1311+
clients: Optional[dict[T, list[T]]],
1312+
) -> list[T]:
1313+
...
1314+
1315+
1316+
@overload
13041317
def general_toposort(
13051318
outputs: Iterable[T],
13061319
deps: Callable[[T], Union[OrderedSet, list[T]]],
1320+
compute_deps_cache: None,
1321+
deps_cache: None,
1322+
clients: Optional[dict[T, list[T]]],
1323+
) -> list[T]:
1324+
...
1325+
1326+
1327+
def general_toposort(
1328+
outputs: Iterable[T],
1329+
deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]],
13071330
compute_deps_cache: Optional[
13081331
Callable[[T], Optional[Union[OrderedSet, list[T]]]]
13091332
] = None,
@@ -1345,7 +1368,7 @@ def general_toposort(
13451368
if deps_cache is None:
13461369
deps_cache = {}
13471370

1348-
def _compute_deps_cache(io):
1371+
def _compute_deps_cache_(io):
13491372
if io not in deps_cache:
13501373
d = deps(io)
13511374

@@ -1363,6 +1386,8 @@ def _compute_deps_cache(io):
13631386
else:
13641387
return deps_cache[io]
13651388

1389+
_compute_deps_cache = _compute_deps_cache_
1390+
13661391
else:
13671392
_compute_deps_cache = compute_deps_cache
13681393

@@ -1451,15 +1476,14 @@ def io_toposort(
14511476
)
14521477
return order
14531478

1454-
compute_deps = None
1455-
compute_deps_cache = None
14561479
iset = set(inputs)
1457-
deps_cache: dict = {}
14581480

14591481
if not orderings: # ordering can be None or empty dict
14601482
# Specialized function that is faster when no ordering.
14611483
# Also include the cache in the function itself for speed up.
14621484

1485+
deps_cache: dict = {}
1486+
14631487
def compute_deps_cache(obj):
14641488
if obj in deps_cache:
14651489
return deps_cache[obj]
@@ -1478,6 +1502,14 @@ def compute_deps_cache(obj):
14781502
deps_cache[obj] = rval
14791503
return rval
14801504

1505+
topo = general_toposort(
1506+
outputs,
1507+
deps=None,
1508+
compute_deps_cache=compute_deps_cache,
1509+
deps_cache=deps_cache,
1510+
clients=clients,
1511+
)
1512+
14811513
else:
14821514
# the inputs are used only here in the function that decides what
14831515
# 'predecessors' to explore
@@ -1494,13 +1526,13 @@ def compute_deps(obj):
14941526
assert not orderings.get(obj, None)
14951527
return rval
14961528

1497-
topo = general_toposort(
1498-
outputs,
1499-
deps=compute_deps,
1500-
compute_deps_cache=compute_deps_cache,
1501-
deps_cache=deps_cache,
1502-
clients=clients,
1503-
)
1529+
topo = general_toposort(
1530+
outputs,
1531+
deps=compute_deps,
1532+
compute_deps_cache=None,
1533+
deps_cache=None,
1534+
clients=clients,
1535+
)
15041536
return [o for o in topo if isinstance(o, Apply)]
15051537

15061538

pytensor/graph/rewriting/basic.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2406,13 +2406,15 @@ def importer(node):
24062406
if node is not current_node:
24072407
q.append(node)
24082408

2409-
chin = None
2409+
chin: Optional[Callable] = None
24102410
if self.tracks_on_change_inputs:
24112411

2412-
def chin(node, i, r, new_r, reason):
2412+
def chin_(node, i, r, new_r, reason):
24132413
if node is not current_node and not isinstance(node, str):
24142414
q.append(node)
24152415

2416+
chin = chin_
2417+
24162418
u = self.attach_updater(
24172419
fgraph, importer, None, chin=chin, name=getattr(self, "name", None)
24182420
)

tests/tensor/test_math.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1398,7 +1398,7 @@ def test_bool(self):
13981398

13991399

14001400
rng = np.random.default_rng(seed=utt.fetch_seed())
1401-
TestClip = makeTester(
1401+
TestClip1 = makeTester(
14021402
name="ClipTester",
14031403
op=clip,
14041404
expected=lambda x, y, z: np.clip(x, y, z),
@@ -1465,7 +1465,7 @@ def test_bool(self):
14651465
)
14661466

14671467

1468-
class TestClip:
1468+
class TestClip2:
14691469
def test_complex_value(self):
14701470
for dtype in ["complex64", "complex128"]:
14711471
a = vector(dtype=dtype)

0 commit comments

Comments
 (0)