diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index f173b2fe96..65fe3750b9 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -22,6 +22,7 @@ TypeVar, Union, cast, + overload, ) import numpy as np @@ -1301,9 +1302,31 @@ def clone_get_equiv( return memo +@overload +def general_toposort( + outputs: Iterable[T], + deps: None, + compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]], + deps_cache: Optional[dict[T, list[T]]], + clients: Optional[dict[T, list[T]]], +) -> list[T]: + ... + + +@overload def general_toposort( outputs: Iterable[T], deps: Callable[[T], Union[OrderedSet, list[T]]], + compute_deps_cache: None, + deps_cache: None, + clients: Optional[dict[T, list[T]]], +) -> list[T]: + ... + + +def general_toposort( + outputs: Iterable[T], + deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]], compute_deps_cache: Optional[ Callable[[T], Optional[Union[OrderedSet, list[T]]]] ] = None, @@ -1345,7 +1368,7 @@ def general_toposort( if deps_cache is None: deps_cache = {} - def _compute_deps_cache(io): + def _compute_deps_cache_(io): if io not in deps_cache: d = deps(io) @@ -1363,6 +1386,8 @@ def _compute_deps_cache(io): else: return deps_cache[io] + _compute_deps_cache = _compute_deps_cache_ + else: _compute_deps_cache = compute_deps_cache @@ -1451,15 +1476,14 @@ def io_toposort( ) return order - compute_deps = None - compute_deps_cache = None iset = set(inputs) - deps_cache: dict = {} if not orderings: # ordering can be None or empty dict # Specialized function that is faster when no ordering. # Also include the cache in the function itself for speed up. + deps_cache: dict = {} + def compute_deps_cache(obj): if obj in deps_cache: return deps_cache[obj] @@ -1478,6 +1502,14 @@ def compute_deps_cache(obj): deps_cache[obj] = rval return rval + topo = general_toposort( + outputs, + deps=None, + compute_deps_cache=compute_deps_cache, + deps_cache=deps_cache, + clients=clients, + ) + else: # the inputs are used only here in the function that decides what # 'predecessors' to explore @@ -1494,13 +1526,13 @@ def compute_deps(obj): assert not orderings.get(obj, None) return rval - topo = general_toposort( - outputs, - deps=compute_deps, - compute_deps_cache=compute_deps_cache, - deps_cache=deps_cache, - clients=clients, - ) + topo = general_toposort( + outputs, + deps=compute_deps, + compute_deps_cache=None, + deps_cache=None, + clients=clients, + ) return [o for o in topo if isinstance(o, Apply)] diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index c24cd99358..d334b9894d 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -2405,13 +2405,15 @@ def importer(node): if node is not current_node: q.append(node) - chin = None + chin: Optional[Callable] = None if self.tracks_on_change_inputs: - def chin(node, i, r, new_r, reason): + def chin_(node, i, r, new_r, reason): if node is not current_node and not isinstance(node, str): q.append(node) + chin = chin_ + u = self.attach_updater( fgraph, importer, None, chin=chin, name=getattr(self, "name", None) ) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 045a776081..337ef20685 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -1403,7 +1403,7 @@ def test_bool(self): rng = np.random.default_rng(seed=utt.fetch_seed()) -TestClip = makeTester( +TestClip1 = makeTester( name="ClipTester", op=clip, expected=lambda x, y, z: np.clip(x, y, z), @@ -1470,7 +1470,7 @@ def test_bool(self): ) -class TestClip: +class TestClip2: def test_complex_value(self): for dtype in ["complex64", "complex128"]: a = vector(dtype=dtype)