
Commit 6b4b71f

Add more type hints to distribution parameters
Parent: dc11a99

1 file changed: +13, −4 lines

pymc/distributions/continuous.py (13 additions, 4 deletions)
@@ -959,7 +959,7 @@ def dist(
         mu: Optional[DIST_PARAMETER_TYPES] = None,
         lam: Optional[DIST_PARAMETER_TYPES] = None,
         phi: Optional[DIST_PARAMETER_TYPES] = None,
-        alpha=0.0,
+        alpha: Optional[DIST_PARAMETER_TYPES] = 0.0,
         **kwargs,
     ):
         mu, lam, phi = cls.get_mu_lam_phi(mu, lam, phi)
@@ -1128,7 +1128,16 @@ class Beta(UnitContinuous):
     rv_op = pytensor.tensor.random.beta

     @classmethod
-    def dist(cls, alpha=None, beta=None, mu=None, sigma=None, nu=None, *args, **kwargs):
+    def dist(
+        cls,
+        alpha: Optional[DIST_PARAMETER_TYPES] = None,
+        beta: Optional[DIST_PARAMETER_TYPES] = None,
+        mu: Optional[DIST_PARAMETER_TYPES] = None,
+        sigma: Optional[DIST_PARAMETER_TYPES] = None,
+        nu: Optional[DIST_PARAMETER_TYPES] = None,
+        *args,
+        **kwargs,
+    ):
         alpha, beta = cls.get_alpha_beta(alpha, beta, mu, sigma, nu)
         alpha = pt.as_tensor_variable(floatX(alpha))
         beta = pt.as_tensor_variable(floatX(beta))
@@ -1256,7 +1265,7 @@ class Kumaraswamy(UnitContinuous):
     rv_op = kumaraswamy

     @classmethod
-    def dist(cls, a, b, *args, **kwargs):
+    def dist(cls, a: DIST_PARAMETER_TYPES, b: DIST_PARAMETER_TYPES, *args, **kwargs):
         a = pt.as_tensor_variable(floatX(a))
         b = pt.as_tensor_variable(floatX(b))

@@ -1342,7 +1351,7 @@ class Exponential(PositiveContinuous):
     rv_op = exponential

     @classmethod
-    def dist(cls, lam, *args, **kwargs):
+    def dist(cls, lam: DIST_PARAMETER_TYPES, *args, **kwargs):
         lam = pt.as_tensor_variable(floatX(lam))

         # PyTensor exponential op is parametrized in terms of mu (1/lam)
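Note on the annotations above: the diff does not show the definition of DIST_PARAMETER_TYPES. The sketch below is an assumption about what that alias covers, meant only to illustrate that an annotated dist classmethod accepts Python scalars, NumPy arrays, or PyTensor tensors, all of which pt.as_tensor_variable(floatX(...)) normalizes to a tensor. The alias members and variable names here are hypothetical; the pm.Exponential.dist and pm.Beta.dist calls are the public API being annotated.

# Illustrative sketch only: the real alias lives elsewhere in pymc and may differ.
from typing import Union

import numpy as np
import pymc as pm
import pytensor.tensor as pt

# Assumed members of the alias used in the annotations (not taken from this diff).
DIST_PARAMETER_TYPES = Union[int, float, np.ndarray, pt.TensorVariable]

# Any of these forms match the annotated signatures and are converted
# internally with pt.as_tensor_variable(floatX(...)).
rv_scalar = pm.Exponential.dist(lam=2.0)                       # Python float
rv_vector = pm.Exponential.dist(lam=np.array([1.0, 2.0]))      # NumPy array
rv_mixed = pm.Beta.dist(alpha=2.0, beta=np.array([1.0, 5.0]))  # mixed scalar/array

Each call returns an unregistered PyTensor random variable; pm.draw can be used to sample from it outside a model context.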
