diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 272dfcc8fd..d925cad991 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ exclude: | )$ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: debug-statements exclude: | @@ -20,23 +20,23 @@ repos: )$ - id: check-merge-conflict - repo: https://github.com/asottile/pyupgrade - rev: v3.3.1 + rev: v3.15.0 hooks: - id: pyupgrade args: [--py39-plus] - repo: https://github.com/psf/black - rev: 23.1.0 + rev: 23.12.1 hooks: - id: black language_version: python3 - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 7.0.0 hooks: - id: flake8 additional_dependencies: - flake8-comprehensions - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/humitos/mirrors-autoflake.git @@ -54,7 +54,7 @@ repos: )$ args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable'] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.0.0 + rev: v1.8.0 hooks: - id: mypy language: python diff --git a/pytensor/compile/debugmode.py b/pytensor/compile/debugmode.py index 4c2d1296d3..9dce2d7e8c 100644 --- a/pytensor/compile/debugmode.py +++ b/pytensor/compile/debugmode.py @@ -689,7 +689,7 @@ def _lessbroken_deepcopy(a): else: rval = copy.deepcopy(a) - assert type(rval) == type(a), (type(rval), type(a)) + assert type(rval) is type(a), (type(rval), type(a)) if isinstance(rval, np.ndarray): assert rval.dtype == a.dtype @@ -1156,7 +1156,7 @@ def __str__(self): return str(self.__dict__) def __eq__(self, other): - rval = type(self) == type(other) + rval = type(self) is type(other) if rval: # nodes are not compared because this comparison is # supposed to be true for corresponding events that happen diff --git a/pytensor/compile/ops.py b/pytensor/compile/ops.py index 1cb5893a08..1e276aedbe 100644 --- a/pytensor/compile/ops.py +++ b/pytensor/compile/ops.py @@ -246,7 +246,7 @@ def __init__(self, fn, itypes, otypes, infer_shape): self.infer_shape = self._infer_shape def __eq__(self, other): - return type(self) == type(other) and self.__fn == other.__fn + return type(self) is type(other) and self.__fn == other.__fn def __hash__(self): return hash(type(self)) ^ hash(self.__fn) diff --git a/pytensor/compile/profiling.py b/pytensor/compile/profiling.py index 986f6dc108..db46c3d5b4 100644 --- a/pytensor/compile/profiling.py +++ b/pytensor/compile/profiling.py @@ -1084,8 +1084,8 @@ def min_memory_generator(executable_nodes, viewed_by, view_of): viewof_change = [] # Use to track view_of changes - viewedby_add = defaultdict(lambda: []) - viewedby_remove = defaultdict(lambda: []) + viewedby_add = defaultdict(list) + viewedby_remove = defaultdict(list) # Use to track viewed_by changes for var in node.outputs: diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 185caa77b6..14cb1c221a 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -23,6 +23,7 @@ TypeVar, Union, cast, + overload, ) import numpy as np @@ -718,7 +719,7 @@ def __eq__(self, other): return True return ( - type(self) == type(other) + type(self) is type(other) and self.id == other.id and self.type == other.type ) @@ -1301,9 +1302,31 @@ def clone_get_equiv( return memo +@overload +def general_toposort( + outputs: Iterable[T], + deps: None, + compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]], + deps_cache: Optional[dict[T, list[T]]], + clients: Optional[dict[T, list[T]]], +) -> list[T]: + ... + + +@overload def general_toposort( outputs: Iterable[T], deps: Callable[[T], Union[OrderedSet, list[T]]], + compute_deps_cache: None, + deps_cache: None, + clients: Optional[dict[T, list[T]]], +) -> list[T]: + ... + + +def general_toposort( + outputs: Iterable[T], + deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]], compute_deps_cache: Optional[ Callable[[T], Optional[Union[OrderedSet, list[T]]]] ] = None, @@ -1345,7 +1368,7 @@ def general_toposort( if deps_cache is None: deps_cache = {} - def _compute_deps_cache(io): + def _compute_deps_cache_(io): if io not in deps_cache: d = deps(io) @@ -1363,6 +1386,8 @@ def _compute_deps_cache(io): else: return deps_cache[io] + _compute_deps_cache = _compute_deps_cache_ + else: _compute_deps_cache = compute_deps_cache @@ -1451,15 +1476,14 @@ def io_toposort( ) return order - compute_deps = None - compute_deps_cache = None iset = set(inputs) - deps_cache: dict = {} if not orderings: # ordering can be None or empty dict # Specialized function that is faster when no ordering. # Also include the cache in the function itself for speed up. + deps_cache: dict = {} + def compute_deps_cache(obj): if obj in deps_cache: return deps_cache[obj] @@ -1478,6 +1502,14 @@ def compute_deps_cache(obj): deps_cache[obj] = rval return rval + topo = general_toposort( + outputs, + deps=None, + compute_deps_cache=compute_deps_cache, + deps_cache=deps_cache, + clients=clients, + ) + else: # the inputs are used only here in the function that decides what # 'predecessors' to explore @@ -1494,13 +1526,13 @@ def compute_deps(obj): assert not orderings.get(obj, None) return rval - topo = general_toposort( - outputs, - deps=compute_deps, - compute_deps_cache=compute_deps_cache, - deps_cache=deps_cache, - clients=clients, - ) + topo = general_toposort( + outputs, + deps=compute_deps, + compute_deps_cache=None, + deps_cache=None, + clients=clients, + ) return [o for o in topo if isinstance(o, Apply)] diff --git a/pytensor/graph/null_type.py b/pytensor/graph/null_type.py index d2a77c67df..66f5c18fd1 100644 --- a/pytensor/graph/null_type.py +++ b/pytensor/graph/null_type.py @@ -33,7 +33,7 @@ def values_eq(self, a, b, force_same_dtype=True): raise ValueError("NullType has no values to compare") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 4d4cfe02f5..72fc48ed87 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -951,8 +951,8 @@ class MetaNodeRewriter(NodeRewriter): def __init__(self): self.verbose = config.metaopt__verbose - self.track_dict = defaultdict(lambda: []) - self.tag_dict = defaultdict(lambda: []) + self.track_dict = defaultdict(list) + self.tag_dict = defaultdict(list) self._tracks = [] self.rewriters = [] @@ -2406,13 +2406,15 @@ def importer(node): if node is not current_node: q.append(node) - chin = None + chin: Optional[Callable] = None if self.tracks_on_change_inputs: - def chin(node, i, r, new_r, reason): + def chin_(node, i, r, new_r, reason): if node is not current_node and not isinstance(node, str): q.append(node) + chin = chin_ + u = self.attach_updater( fgraph, importer, None, chin=chin, name=getattr(self, "name", None) ) diff --git a/pytensor/graph/rewriting/unify.py b/pytensor/graph/rewriting/unify.py index 9e6b7494b1..463ee3138b 100644 --- a/pytensor/graph/rewriting/unify.py +++ b/pytensor/graph/rewriting/unify.py @@ -58,8 +58,8 @@ def __new__(cls, constraint, token=None, prefix=""): return obj def __eq__(self, other): - if type(self) == type(other): - return self.token == other.token and self.constraint == other.constraint + if type(self) is type(other): + return self.token is other.token and self.constraint == other.constraint return NotImplemented def __hash__(self): diff --git a/pytensor/graph/utils.py b/pytensor/graph/utils.py index 580678672e..2baec0e5ff 100644 --- a/pytensor/graph/utils.py +++ b/pytensor/graph/utils.py @@ -229,7 +229,7 @@ def __hash__(self): if "__eq__" not in dct: def __eq__(self, other): - return type(self) == type(other) and tuple( + return type(self) is type(other) and tuple( getattr(self, a) for a in props ) == tuple(getattr(other, a) for a in props) diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 7858f51eba..e56f404826 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -78,7 +78,7 @@ def __init__(self, n_outs, as_view=False, name=None): self.name = name def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return False if self.as_view != other.as_view: return False diff --git a/pytensor/link/c/params_type.py b/pytensor/link/c/params_type.py index ffa57b0949..a62888dd04 100644 --- a/pytensor/link/c/params_type.py +++ b/pytensor/link/c/params_type.py @@ -297,7 +297,7 @@ def __hash__(self): def __eq__(self, other): return ( - type(self) == type(other) + type(self) is type(other) and self.__params_type__ == other.__params_type__ and all( # NB: Params object should have been already filtered. @@ -432,7 +432,7 @@ def __repr__(self): def __eq__(self, other): return ( - type(self) == type(other) + type(self) is type(other) and self.fields == other.fields and self.types == other.types ) diff --git a/pytensor/link/c/type.py b/pytensor/link/c/type.py index 24ced701ed..6cb4d95b8c 100644 --- a/pytensor/link/c/type.py +++ b/pytensor/link/c/type.py @@ -519,7 +519,7 @@ def __hash__(self): def __eq__(self, other): return ( - type(self) == type(other) + type(self) is type(other) and self.ctype == other.ctype and len(self) == len(other) and len(self.aliases) == len(other.aliases) diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index 444b8f28f8..ec0c03a7be 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -818,7 +818,7 @@ def get_destroy_dependencies(fgraph: FunctionGraph) -> dict[Apply, list[Variable in destroy_dependencies. """ order = fgraph.orderings() - destroy_dependencies = defaultdict(lambda: []) + destroy_dependencies = defaultdict(list) for node in fgraph.apply_nodes: for prereq in order.get(node, []): destroy_dependencies[node].extend(prereq.outputs) diff --git a/pytensor/raise_op.py b/pytensor/raise_op.py index 52d674b801..737223b53a 100644 --- a/pytensor/raise_op.py +++ b/pytensor/raise_op.py @@ -16,7 +16,7 @@ class ExceptionType(Generic): def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -51,7 +51,7 @@ def __str__(self): return f"CheckAndRaise{{{self.exc_type}({self.msg})}}" def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return False if self.msg == other.msg and self.exc_type == other.exc_type: diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 5b5ae05838..d3503e6148 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1074,7 +1074,7 @@ def __call__(self, *types): return [rval] def __eq__(self, other): - return type(self) == type(other) and self.tbl == other.tbl + return type(self) is type(other) and self.tbl == other.tbl def __hash__(self): return hash(type(self)) # ignore hash of table @@ -1160,7 +1160,7 @@ def L_op(self, inputs, outputs, output_gradients): return self.grad(inputs, output_gradients) def __eq__(self, other): - test = type(self) == type(other) and getattr( + test = type(self) is type(other) and getattr( self, "output_types_preference", None ) == getattr(other, "output_types_preference", None) return test @@ -4132,7 +4132,7 @@ def __eq__(self, other): if self is other: return True if ( - type(self) != type(other) + type(self) is not type(other) or self.nin != other.nin or self.nout != other.nout ): diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 7eba128100..9f73b67859 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -626,7 +626,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -675,7 +675,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -724,7 +724,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -1033,7 +1033,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -1074,7 +1074,7 @@ def c_code(self, node, name, inp, out, sub): raise NotImplementedError("only floatingpoint is implemented") def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index b7cc7fa276..14999f8bfb 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -1244,7 +1244,7 @@ def is_cpu_vector(s): return apply_node def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return False if self.info != other.info: diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index 96105adc5c..56b0400307 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -451,7 +451,7 @@ def __eq__(self, other): return ( a == x and (b.dtype == y.dtype) - and (type(b) == type(y)) + and (type(b) is type(y)) and (b.shape == y.shape) and (abs(b - y).sum() < 1e-6 * b.nnz) ) diff --git a/pytensor/tensor/random/type.py b/pytensor/tensor/random/type.py index 527d3f3d6b..fccea2a241 100644 --- a/pytensor/tensor/random/type.py +++ b/pytensor/tensor/random/type.py @@ -102,7 +102,7 @@ def _eq(sa, sb): return _eq(sa, sb) def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) @@ -198,7 +198,7 @@ def _eq(sa, sb): return _eq(sa, sb) def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 43aecb8816..07651d5056 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -126,7 +126,7 @@ def apply(self, fgraph): "nb_call_replace": 0, "nb_call_validate": 0, "nb_inconsistent": 0, - "ndim": defaultdict(lambda: 0), + "ndim": defaultdict(int), } check_each_change = config.tensor__insert_inplace_optimizer_validate_nb diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 670fec4211..6d3431b010 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1770,7 +1770,7 @@ def local_reduce_broadcastable(fgraph, node): ii += 1 new_reduced = reduced.dimshuffle(*pattern) if new_axis: - if type(node.op) == CAReduce: + if type(node.op) is CAReduce: # This case handles `CAReduce` instances # (e.g. generated by `scalar_elemwise`), and not the # scalar `Op`-specific subclasses diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 153879b77e..d3fd02ab44 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -370,7 +370,7 @@ def values_eq_approx( return values_eq_approx(a, b, allow_remove_inf, allow_remove_nan, rtol, atol) def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return NotImplemented return other.dtype == self.dtype and other.shape == self.shape @@ -639,7 +639,7 @@ def c_code_cache_version(self): class DenseTypeMeta(MetaType): def __instancecheck__(self, o): - if type(o) == TensorType or isinstance(o, DenseTypeMeta): + if type(o) is TensorType or isinstance(o, DenseTypeMeta): return True return False diff --git a/pytensor/tensor/type_other.py b/pytensor/tensor/type_other.py index 0ac1fd4415..dc46810321 100644 --- a/pytensor/tensor/type_other.py +++ b/pytensor/tensor/type_other.py @@ -64,7 +64,7 @@ def __str__(self): return "slice" def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index d4b3df6975..9a2f8d3b1f 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -931,7 +931,7 @@ class TensorConstantSignature(tuple): """ def __eq__(self, other): - if type(self) != type(other): + if type(self) is not type(other): return False try: (t0, d0), (t1, d1) = self, other @@ -1091,7 +1091,7 @@ def __deepcopy__(self, memo): class DenseVariableMeta(MetaType): def __instancecheck__(self, o): - if type(o) == TensorVariable or isinstance(o, DenseVariableMeta): + if type(o) is TensorVariable or isinstance(o, DenseVariableMeta): return True return False @@ -1106,7 +1106,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta): class DenseConstantMeta(MetaType): def __instancecheck__(self, o): - if type(o) == TensorConstant or isinstance(o, DenseConstantMeta): + if type(o) is TensorConstant or isinstance(o, DenseConstantMeta): return True return False diff --git a/pytensor/typed_list/type.py b/pytensor/typed_list/type.py index 7b43252658..e71b7e2800 100644 --- a/pytensor/typed_list/type.py +++ b/pytensor/typed_list/type.py @@ -55,7 +55,7 @@ def __eq__(self, other): Two lists are equal if they contain the same type. """ - return type(self) == type(other) and self.ttype == other.ttype + return type(self) is type(other) and self.ttype == other.ttype def __hash__(self): return hash((type(self), self.ttype)) diff --git a/tests/graph/rewriting/test_unify.py b/tests/graph/rewriting/test_unify.py index a152fcee17..da430a1587 100644 --- a/tests/graph/rewriting/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -42,7 +42,7 @@ def perform(self, node, inputs, outputs): class CustomOpNoProps(CustomOpNoPropsNoEq): def __eq__(self, other): - return type(self) == type(other) and self.a == other.a + return type(self) is type(other) and self.a == other.a def __hash__(self): return hash((type(self), self.a)) diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 77ca574773..068f19b946 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -30,8 +30,8 @@ def test_pickle(self): s = pickle.dumps(func) new_func = pickle.loads(s) - assert all(type(a) == type(b) for a, b in zip(func.inputs, new_func.inputs)) - assert all(type(a) == type(b) for a, b in zip(func.outputs, new_func.outputs)) + assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs)) + assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs)) assert all( type(a.op) is type(b.op) # noqa: E721 for a, b in zip(func.apply_nodes, new_func.apply_nodes) diff --git a/tests/graph/test_op.py b/tests/graph/test_op.py index 59d81ad59e..c2f81649d8 100644 --- a/tests/graph/test_op.py +++ b/tests/graph/test_op.py @@ -25,7 +25,7 @@ def __init__(self, thingy): self.thingy = thingy def __eq__(self, other): - return type(other) == type(self) and other.thingy == self.thingy + return type(other) is type(self) and other.thingy == self.thingy def __str__(self): return str(self.thingy) diff --git a/tests/link/c/test_basic.py b/tests/link/c/test_basic.py index b8ea0d3400..68da2dd86a 100644 --- a/tests/link/c/test_basic.py +++ b/tests/link/c/test_basic.py @@ -73,7 +73,7 @@ def c_code_cache_version(self): return (1,) def __eq__(self, other): - return type(self) == type(other) + return type(self) is type(other) def __hash__(self): return hash(type(self)) diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 590c76e008..3ceb2bfa17 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -348,7 +348,7 @@ def __init__(self, structured): self.structured = structured def __eq__(self, other): - return (type(self) == type(other)) and self.structured == other.structured + return (type(self) is type(other)) and self.structured == other.structured def __hash__(self): return hash(type(self)) ^ hash(self.structured) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index ba5e1cf648..3bd1b31b6f 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3156,7 +3156,7 @@ def test_stack(): sx, sy = dscalar(), dscalar() rval = inplace_func([sx, sy], stack([sx, sy]))(-4.0, -2.0) - assert type(rval) == np.ndarray + assert type(rval) is np.ndarray assert [-4, -2] == list(rval) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index a3a6be4235..172204f250 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -1398,7 +1398,7 @@ def test_bool(self): rng = np.random.default_rng(seed=utt.fetch_seed()) -TestClip = makeTester( +TestClip1 = makeTester( name="ClipTester", op=clip, expected=lambda x, y, z: np.clip(x, y, z), @@ -1465,7 +1465,7 @@ def test_bool(self): ) -class TestClip: +class TestClip2: def test_complex_value(self): for dtype in ["complex64", "complex128"]: a = vector(dtype=dtype) diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 63acbabb29..bfee070820 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -819,7 +819,7 @@ def test_ok_list(self): assert np.allclose(val, good), (val, good) # Test reuse of output memory - if type(AdvancedSubtensor1) == AdvancedSubtensor1: + if type(AdvancedSubtensor1) is AdvancedSubtensor1: op = AdvancedSubtensor1() # When idx is a TensorConstant. if hasattr(idx, "data"):