Skip to content

Commit 938aff4

Browse files
Fix some type hinting to help with migrating Distribution (#7484)
1 parent 1457626 commit 938aff4

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

pymc/distributions/continuous.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@
5151
)
5252
from pytensor.tensor.random.op import RandomVariable
5353
from pytensor.tensor.random.utils import normalize_size_param
54-
from pytensor.tensor.variable import TensorConstant
54+
from pytensor.tensor.variable import TensorConstant, TensorVariable
5555

5656
from pymc.logprob.abstract import _logprob_helper
57-
from pymc.logprob.basic import icdf
57+
from pymc.logprob.basic import TensorLike, icdf
5858
from pymc.pytensorf import normalize_rng_param
5959

6060
try:
@@ -148,7 +148,7 @@ class BoundedContinuous(Continuous):
148148
"""Base class for bounded continuous distributions."""
149149

150150
# Indices of the arguments that define the lower and upper bounds of the distribution
151-
bound_args_indices: list[int] | None = None
151+
bound_args_indices: tuple[int | None, int | None] | None = None
152152

153153

154154
@_default_transform.register(PositiveContinuous)
@@ -210,7 +210,9 @@ def assert_negative_support(var, label, distname, value=-1e-6):
210210
return Assert(msg)(var, pt.all(pt.ge(var, 0.0)))
211211

212212

213-
def get_tau_sigma(tau=None, sigma=None):
213+
def get_tau_sigma(
214+
tau: TensorLike | None = None, sigma: TensorLike | None = None
215+
) -> tuple[TensorVariable, TensorVariable]:
214216
r"""
215217
Find precision and standard deviation.
216218
@@ -239,13 +241,14 @@ def get_tau_sigma(tau=None, sigma=None):
239241
sigma = pt.as_tensor_variable(1.0)
240242
tau = pt.as_tensor_variable(1.0)
241243
elif tau is None:
244+
assert sigma is not None # Just for type checker
242245
sigma = pt.as_tensor_variable(sigma)
243246
# Keep tau negative, if sigma was negative, so that it will
244247
# fail when used
245248
tau = (sigma**-2.0) * pt.sign(sigma)
246249
else:
247250
tau = pt.as_tensor_variable(tau)
248-
# Keep tau negative, if sigma was negative, so that it will
251+
# Keep sigma negative, if tau was negative, so that it will
249252
# fail when used
250253
sigma = pt.abs(tau) ** -0.5 * pt.sign(tau)
251254

pymc/distributions/distribution.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from abc import ABCMeta
2222
from collections.abc import Callable, Sequence
2323
from functools import singledispatch
24-
from typing import TypeAlias
24+
from typing import Any, TypeAlias
2525

2626
import numpy as np
2727

@@ -423,8 +423,12 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) ->
423423
class Distribution(metaclass=DistributionMeta):
424424
"""Statistical distribution."""
425425

426-
rv_op: [RandomVariable, SymbolicRandomVariable] = None
427-
rv_type: MetaType = None
426+
# rv_op and _type are set to None via the DistributionMeta.__new__
427+
# if not specified as class attributes in subclasses of Distribution.
428+
# rv_op can either be a class (see the Normal class) or a method
429+
# (see the Censored class), both callable to return a TensorVariable.
430+
rv_op: Any = None
431+
rv_type: MetaType | None = None
428432

429433
def __new__(
430434
cls,

0 commit comments

Comments
 (0)