Fix redefinitions #598

Merged (1 commit) on Jan 15, 2024
54 changes: 43 additions & 11 deletions pytensor/graph/basic.py
@@ -22,6 +22,7 @@
TypeVar,
Union,
cast,
overload,
)

import numpy as np
@@ -1301,9 +1302,31 @@ def clone_get_equiv(
return memo


@overload
def general_toposort(
outputs: Iterable[T],
deps: None,
compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
deps_cache: Optional[dict[T, list[T]]],
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...


@overload
def general_toposort(
outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, list[T]]],
compute_deps_cache: None,
deps_cache: None,
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...


def general_toposort(
outputs: Iterable[T],
deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]],
compute_deps_cache: Optional[
Callable[[T], Optional[Union[OrderedSet, list[T]]]]
] = None,
@@ -1345,7 +1368,7 @@ def general_toposort(
if deps_cache is None:
deps_cache = {}

def _compute_deps_cache(io):
def _compute_deps_cache_(io):
if io not in deps_cache:
d = deps(io)

@@ -1363,6 +1386,8 @@ def _compute_deps_cache(io):
else:
return deps_cache[io]

_compute_deps_cache = _compute_deps_cache_

else:
_compute_deps_cache = compute_deps_cache

@@ -1451,15 +1476,14 @@ def io_toposort(
)
return order

compute_deps = None
compute_deps_cache = None
iset = set(inputs)
deps_cache: dict = {}

if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.

deps_cache: dict = {}

def compute_deps_cache(obj):
if obj in deps_cache:
return deps_cache[obj]
@@ -1478,6 +1502,14 @@ def compute_deps_cache(obj):
deps_cache[obj] = rval
return rval

topo = general_toposort(
outputs,
deps=None,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)

else:
# the inputs are used only here in the function that decides what
# 'predecessors' to explore
@@ -1494,13 +1526,13 @@ def compute_deps(obj):
assert not orderings.get(obj, None)
return rval

topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)
topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=None,
deps_cache=None,
clients=clients,
)
return [o for o in topo if isinstance(o, Apply)]


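For readers skimming the diff, here is a minimal standalone sketch of the two patterns applied to `general_toposort` above. The function `toposort_like`, its integer node type, and its trivial body are hypothetical; only the typing idioms mirror the change: `typing.overload` declares the two mutually exclusive call signatures (pass `deps` or `compute_deps_cache`, not both and not neither), and the conditionally defined helper gets its own name before being bound to the shared variable, so the two branches are no longer reported as a conflicting redefinition.

```python
from typing import Callable, Optional, overload


@overload
def toposort_like(
    deps: None,
    compute_deps_cache: Callable[[int], list[int]],
) -> list[int]: ...


@overload
def toposort_like(
    deps: Callable[[int], list[int]],
    compute_deps_cache: None,
) -> list[int]: ...


def toposort_like(
    deps: Optional[Callable[[int], list[int]]],
    compute_deps_cache: Optional[Callable[[int], list[int]]] = None,
) -> list[int]:
    """Hypothetical sketch, not PyTensor code: only the typing patterns mirror the diff."""
    if compute_deps_cache is None:
        assert deps is not None
        deps_cache: dict[int, list[int]] = {}

        # Define the fallback under a distinct name, then bind it to the
        # shared variable. A plain `def` reusing the outer name in one branch
        # plus an assignment in the other is what a type checker flags as a
        # redefinition with an incompatible type.
        def _compute_deps_cache_(node: int) -> list[int]:
            if node not in deps_cache:
                deps_cache[node] = deps(node)
            return deps_cache[node]

        compute_deps_cache = _compute_deps_cache_

    # Placeholder "toposort": just expand the dependencies of node 0.
    return compute_deps_cache(0)
```

With the overloads in place, a call such as `toposort_like(deps=None, compute_deps_cache=None)` matches no variant and is rejected by the type checker, even though the runtime signature would accept it.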
6 changes: 4 additions & 2 deletions pytensor/graph/rewriting/basic.py
@@ -2405,13 +2405,15 @@ def importer(node):
if node is not current_node:
q.append(node)

chin = None
chin: Optional[Callable] = None
if self.tracks_on_change_inputs:

def chin(node, i, r, new_r, reason):
def chin_(node, i, r, new_r, reason):
if node is not current_node and not isinstance(node, str):
q.append(node)

chin = chin_

u = self.attach_updater(
fgraph, importer, None, chin=chin, name=getattr(self, "name", None)
)
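The `chin` change follows the same recipe. A hedged sketch with a hypothetical wrapper (`make_change_tracker` is not a PyTensor function): annotate the variable once with the broad `Optional[Callable]` type, define the conditional callback under a separate name, then assign it.

```python
from typing import Callable, Optional


def make_change_tracker(tracks_on_change_inputs: bool) -> Optional[Callable]:
    # Declare the variable once with the broad type it may hold.
    chin: Optional[Callable] = None

    if tracks_on_change_inputs:
        # The conditionally defined callback gets its own name ...
        def chin_(node, i, r, new_r, reason) -> None:
            print(f"input {i} of {node}: {r} -> {new_r} ({reason})")

        # ... and is then bound to the annotated variable, so the checker
        # sees one consistently typed name instead of `None` followed by a
        # conflicting `def chin(...)`.
        chin = chin_

    return chin
```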
4 changes: 2 additions & 2 deletions tests/tensor/test_math.py
@@ -1403,7 +1403,7 @@ def test_bool(self):


rng = np.random.default_rng(seed=utt.fetch_seed())
TestClip = makeTester(
TestClip1 = makeTester(
name="ClipTester",
op=clip,
expected=lambda x, y, z: np.clip(x, y, z),
@@ -1470,7 +1470,7 @@
)


class TestClip:
class TestClip2:
def test_complex_value(self):
for dtype in ["complex64", "complex128"]:
a = vector(dtype=dtype)
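For the test rename, a hypothetical illustration (`make_case` below stands in for `makeTester`) of why the two `TestClip` definitions could not share a name: the later `class TestClip:` statement rebound the module-level name, so the generated tester assigned to it earlier would no longer be collected by pytest.

```python
# Hypothetical stand-in for makeTester: builds a parametric test class.
def make_case(name: str) -> type:
    return type(name, (), {"test_generated": lambda self: None})


TestClip1 = make_case("ClipTester")  # generated tests keep their own name ...


class TestClip2:  # ... and the hand-written tests keep theirs
    def test_complex_value(self) -> None:
        assert True
```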