Skip to content

Update pre-commit hooks #581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ exclude: |
)$
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: debug-statements
exclude: |
Expand All @@ -20,23 +20,23 @@ repos:
)$
- id: check-merge-conflict
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py39-plus]
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.12.1
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies:
- flake8-comprehensions
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/humitos/mirrors-autoflake.git
Expand All @@ -54,7 +54,7 @@ repos:
)$
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.0
rev: v1.8.0
hooks:
- id: mypy
language: python
Expand Down
4 changes: 2 additions & 2 deletions pytensor/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def _lessbroken_deepcopy(a):
else:
rval = copy.deepcopy(a)

assert type(rval) == type(a), (type(rval), type(a))
assert type(rval) is type(a), (type(rval), type(a))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this surely equivalent? Could the classes override type equality somehow?

The many places where we had type(...) == type(...) make me wary of these changes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya, I am also wary of this commit. You should be very skeptical.

On my first attempt I also changed this line and it caused a test failure. (Note that this line is not like the others: .type instead of type(), and doesn't lead to a warning; I was just replacing == with is wherever I saw it next to type.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's different. Is there a reference for this rule that we can look at? Maybe that will confirm it is indeed always safe to do so.


if isinstance(rval, np.ndarray):
assert rval.dtype == a.dtype
Expand Down Expand Up @@ -1156,7 +1156,7 @@ def __str__(self):
return str(self.__dict__)

def __eq__(self, other):
rval = type(self) == type(other)
rval = type(self) is type(other)
if rval:
# nodes are not compared because this comparison is
# supposed to be true for corresponding events that happen
Expand Down
2 changes: 1 addition & 1 deletion pytensor/compile/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __init__(self, fn, itypes, otypes, infer_shape):
self.infer_shape = self._infer_shape

def __eq__(self, other):
return type(self) == type(other) and self.__fn == other.__fn
return type(self) is type(other) and self.__fn == other.__fn

def __hash__(self):
return hash(type(self)) ^ hash(self.__fn)
Expand Down
4 changes: 2 additions & 2 deletions pytensor/compile/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,8 +1084,8 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
viewof_change = []
# Use to track view_of changes

viewedby_add = defaultdict(lambda: [])
viewedby_remove = defaultdict(lambda: [])
viewedby_add = defaultdict(list)
viewedby_remove = defaultdict(list)
# Use to track viewed_by changes

for var in node.outputs:
Expand Down
56 changes: 44 additions & 12 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TypeVar,
Union,
cast,
overload,
)

import numpy as np
Expand Down Expand Up @@ -718,7 +719,7 @@ def __eq__(self, other):
return True

return (
type(self) == type(other)
type(self) is type(other)
and self.id == other.id
and self.type == other.type
)
Expand Down Expand Up @@ -1301,9 +1302,31 @@ def clone_get_equiv(
return memo


@overload
def general_toposort(
outputs: Iterable[T],
deps: None,
compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
deps_cache: Optional[dict[T, list[T]]],
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...


@overload
def general_toposort(
outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, list[T]]],
compute_deps_cache: None,
deps_cache: None,
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...


def general_toposort(
outputs: Iterable[T],
deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]],
compute_deps_cache: Optional[
Callable[[T], Optional[Union[OrderedSet, list[T]]]]
] = None,
Expand Down Expand Up @@ -1345,7 +1368,7 @@ def general_toposort(
if deps_cache is None:
deps_cache = {}

def _compute_deps_cache(io):
def _compute_deps_cache_(io):
if io not in deps_cache:
d = deps(io)

Expand All @@ -1363,6 +1386,8 @@ def _compute_deps_cache(io):
else:
return deps_cache[io]

_compute_deps_cache = _compute_deps_cache_

else:
_compute_deps_cache = compute_deps_cache

Expand Down Expand Up @@ -1451,15 +1476,14 @@ def io_toposort(
)
return order

compute_deps = None
compute_deps_cache = None
iset = set(inputs)
deps_cache: dict = {}

if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.

deps_cache: dict = {}

def compute_deps_cache(obj):
if obj in deps_cache:
return deps_cache[obj]
Expand All @@ -1478,6 +1502,14 @@ def compute_deps_cache(obj):
deps_cache[obj] = rval
return rval

topo = general_toposort(
outputs,
deps=None,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)

else:
# the inputs are used only here in the function that decides what
# 'predecessors' to explore
Expand All @@ -1494,13 +1526,13 @@ def compute_deps(obj):
assert not orderings.get(obj, None)
return rval

topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)
topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=None,
deps_cache=None,
clients=clients,
)
return [o for o in topo if isinstance(o, Apply)]


Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/null_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def values_eq(self, a, b, force_same_dtype=True):
raise ValueError("NullType has no values to compare")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down
10 changes: 6 additions & 4 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,8 @@ class MetaNodeRewriter(NodeRewriter):

def __init__(self):
self.verbose = config.metaopt__verbose
self.track_dict = defaultdict(lambda: [])
self.tag_dict = defaultdict(lambda: [])
self.track_dict = defaultdict(list)
self.tag_dict = defaultdict(list)
self._tracks = []
self.rewriters = []

Expand Down Expand Up @@ -2406,13 +2406,15 @@ def importer(node):
if node is not current_node:
q.append(node)

chin = None
chin: Optional[Callable] = None
if self.tracks_on_change_inputs:

def chin(node, i, r, new_r, reason):
def chin_(node, i, r, new_r, reason):
if node is not current_node and not isinstance(node, str):
q.append(node)

chin = chin_

u = self.attach_updater(
fgraph, importer, None, chin=chin, name=getattr(self, "name", None)
)
Expand Down
4 changes: 2 additions & 2 deletions pytensor/graph/rewriting/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __new__(cls, constraint, token=None, prefix=""):
return obj

def __eq__(self, other):
if type(self) == type(other):
return self.token == other.token and self.constraint == other.constraint
if type(self) is type(other):
return self.token is other.token and self.constraint == other.constraint
return NotImplemented

def __hash__(self):
Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __hash__(self):
if "__eq__" not in dct:

def __eq__(self, other):
return type(self) == type(other) and tuple(
return type(self) is type(other) and tuple(
getattr(self, a) for a in props
) == tuple(getattr(other, a) for a in props)

Expand Down
2 changes: 1 addition & 1 deletion pytensor/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, n_outs, as_view=False, name=None):
self.name = name

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False
if self.as_view != other.as_view:
return False
Expand Down
4 changes: 2 additions & 2 deletions pytensor/link/c/params_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def __hash__(self):

def __eq__(self, other):
return (
type(self) == type(other)
type(self) is type(other)
and self.__params_type__ == other.__params_type__
and all(
# NB: Params object should have been already filtered.
Expand Down Expand Up @@ -432,7 +432,7 @@ def __repr__(self):

def __eq__(self, other):
return (
type(self) == type(other)
type(self) is type(other)
and self.fields == other.fields
and self.types == other.types
)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/c/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def __hash__(self):

def __eq__(self, other):
return (
type(self) == type(other)
type(self) is type(other)
and self.ctype == other.ctype
and len(self) == len(other)
and len(self.aliases) == len(other.aliases)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def get_destroy_dependencies(fgraph: FunctionGraph) -> dict[Apply, list[Variable
in destroy_dependencies.
"""
order = fgraph.orderings()
destroy_dependencies = defaultdict(lambda: [])
destroy_dependencies = defaultdict(list)
for node in fgraph.apply_nodes:
for prereq in order.get(node, []):
destroy_dependencies[node].extend(prereq.outputs)
Expand Down
4 changes: 2 additions & 2 deletions pytensor/raise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class ExceptionType(Generic):
def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -51,7 +51,7 @@ def __str__(self):
return f"CheckAndRaise{{{self.exc_type}({self.msg})}}"

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False

if self.msg == other.msg and self.exc_type == other.exc_type:
Expand Down
6 changes: 3 additions & 3 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ def __call__(self, *types):
return [rval]

def __eq__(self, other):
return type(self) == type(other) and self.tbl == other.tbl
return type(self) is type(other) and self.tbl == other.tbl

def __hash__(self):
return hash(type(self)) # ignore hash of table
Expand Down Expand Up @@ -1160,7 +1160,7 @@ def L_op(self, inputs, outputs, output_gradients):
return self.grad(inputs, output_gradients)

def __eq__(self, other):
test = type(self) == type(other) and getattr(
test = type(self) is type(other) and getattr(
self, "output_types_preference", None
) == getattr(other, "output_types_preference", None)
return test
Expand Down Expand Up @@ -4132,7 +4132,7 @@ def __eq__(self, other):
if self is other:
return True
if (
type(self) != type(other)
type(self) is not type(other)
or self.nin != other.nin
or self.nout != other.nout
):
Expand Down
10 changes: 5 additions & 5 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -675,7 +675,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -724,7 +724,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -1074,7 +1074,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down
Loading