|
51 | 51 | )
|
52 | 52 | from pytensor.tensor.random.op import RandomVariable
|
53 | 53 | from pytensor.tensor.random.utils import normalize_size_param
|
54 |
| -from pytensor.tensor.variable import TensorConstant |
| 54 | +from pytensor.tensor.variable import TensorConstant, TensorVariable |
55 | 55 |
|
56 | 56 | from pymc.logprob.abstract import _logprob_helper
|
57 |
| -from pymc.logprob.basic import icdf |
| 57 | +from pymc.logprob.basic import TensorLike, icdf |
58 | 58 | from pymc.pytensorf import normalize_rng_param
|
59 | 59 |
|
60 | 60 | try:
|
@@ -148,7 +148,7 @@ class BoundedContinuous(Continuous):
|
148 | 148 | """Base class for bounded continuous distributions."""
|
149 | 149 |
|
150 | 150 | # Indices of the arguments that define the lower and upper bounds of the distribution
|
151 |
| - bound_args_indices: list[int] | None = None |
| 151 | + bound_args_indices: tuple[int | None, int | None] | None = None |
152 | 152 |
|
153 | 153 |
|
154 | 154 | @_default_transform.register(PositiveContinuous)
|
@@ -210,7 +210,9 @@ def assert_negative_support(var, label, distname, value=-1e-6):
|
210 | 210 | return Assert(msg)(var, pt.all(pt.ge(var, 0.0)))
|
211 | 211 |
|
212 | 212 |
|
213 |
| -def get_tau_sigma(tau=None, sigma=None): |
| 213 | +def get_tau_sigma( |
| 214 | + tau: TensorLike | None = None, sigma: TensorLike | None = None |
| 215 | +) -> tuple[TensorVariable, TensorVariable]: |
214 | 216 | r"""
|
215 | 217 | Find precision and standard deviation.
|
216 | 218 |
|
@@ -239,13 +241,14 @@ def get_tau_sigma(tau=None, sigma=None):
|
239 | 241 | sigma = pt.as_tensor_variable(1.0)
|
240 | 242 | tau = pt.as_tensor_variable(1.0)
|
241 | 243 | elif tau is None:
|
| 244 | + assert sigma is not None # Just for type checker |
242 | 245 | sigma = pt.as_tensor_variable(sigma)
|
243 | 246 | # Keep tau negative, if sigma was negative, so that it will
|
244 | 247 | # fail when used
|
245 | 248 | tau = (sigma**-2.0) * pt.sign(sigma)
|
246 | 249 | else:
|
247 | 250 | tau = pt.as_tensor_variable(tau)
|
248 |
| - # Keep tau negative, if sigma was negative, so that it will |
| 251 | + # Keep sigma negative, if tau was negative, so that it will |
249 | 252 | # fail when used
|
250 | 253 | sigma = pt.abs(tau) ** -0.5 * pt.sign(tau)
|
251 | 254 |
|
|
0 commit comments