from aesara import scan
from aesara.tensor.random.op import RandomVariable
+from aesara.tensor.random.utils import normalize_size_param

from pymc.aesaraf import change_rv_size, floatX, intX
from pymc.distributions import distribution, logprob, multivariate
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
from pymc.distributions.dist_math import check_parameters
-from pymc.distributions.shape_utils import to_tuple
+from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
from pymc.util import check_dist_not_registered


__all__ = [
@@ -54,6 +55,16 @@ def make_node(self, rng, size, dtype, mu, sigma, init, steps):
        if not steps.ndim == 0 or not steps.dtype.startswith("int"):
            raise ValueError("steps must be an integer scalar (ndim=0).")

+        mu = at.as_tensor_variable(mu)
+        sigma = at.as_tensor_variable(sigma)
+        init = at.as_tensor_variable(init)
+
+        # Resize init distribution
+        size = normalize_size_param(size)
+        # If not explicit, size is determined by the shapes of mu, sigma, and init
+        init_size = size if not rv_size_is_none(size) else at.broadcast_shape(mu, sigma, init)
+        init = change_rv_size(init, init_size)
+
        return super().make_node(rng, size, dtype, mu, sigma, init, steps)

    def _supp_shape_from_params(self, dist_params, reop_param_idx=0, param_shapes=None):
@@ -160,15 +171,9 @@ def dist(
            raise ValueError("Must specify steps parameter")
        steps = at.as_tensor_variable(intX(steps))

-        shape = kwargs.get("shape", None)
-        if size is None and shape is None:
-            init_size = None
-        else:
-            init_size = to_tuple(size) if size is not None else to_tuple(shape)[:-1]
-
        # If no scalar distribution is passed then initialize with a Normal of same mu and sigma
        if init is None:
-            init = Normal.dist(mu, sigma, size=init_size)
+            init = Normal.dist(mu, sigma)
        else:
            if not (
                isinstance(init, at.TensorVariable)
@@ -178,13 +183,6 @@ def dist(
            ):
                raise TypeError("init must be a univariate distribution variable")

-            if init_size is not None:
-                init = change_rv_size(init, init_size)
-            else:
-                # If not explicit, size is determined by the shapes of mu, sigma, and init
-                bcast_shape = at.broadcast_arrays(mu, sigma, init)[0].shape
-                init = change_rv_size(init, bcast_shape)
-
            # Ignores logprob of init var because that's accounted for in the logp method
            init.tag.ignore_logprob = True
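
For context, a minimal sketch of the size-resolution rule the added make_node lines apply to init: an explicit size wins, otherwise init is resized to the broadcast shape of mu, sigma, and init. The snippet is not part of the commit; it assumes the PyMC v4 development branch this diff targets, and the toy parameter values are illustrative only.

import aesara.tensor as at
from aesara.tensor.random.utils import normalize_size_param

from pymc.aesaraf import change_rv_size
from pymc.distributions.continuous import Normal
from pymc.distributions.shape_utils import rv_size_is_none

# Illustrative parameters: batched mu, scalar sigma, scalar-sized init.
mu = at.as_tensor_variable([0.0, 1.0, 2.0])
sigma = at.as_tensor_variable(1.0)
init = Normal.dist(0.0, 1.0)

# No explicit size was given, so fall back to the broadcast shape of
# mu, sigma, and init, the same rule the new make_node code applies.
size = normalize_size_param(None)
init_size = size if not rv_size_is_none(size) else at.broadcast_shape(mu, sigma, init)
init = change_rv_size(init, init_size)

print(init.shape.eval())  # [3]: init now lines up with the batched mu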