@@ -32,7 +32,7 @@
 from aesara.tensor.nlinalg import det, eigh, matrix_inverse, trace
 from aesara.tensor.random.basic import MultinomialRV, dirichlet, multivariate_normal
 from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
-from aesara.tensor.random.utils import broadcast_params
+from aesara.tensor.random.utils import broadcast_params, normalize_size_param
 from aesara.tensor.slinalg import Cholesky
 from aesara.tensor.slinalg import solve_lower_triangular as solve_lower
 from aesara.tensor.slinalg import solve_upper_triangular as solve_upper
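The newly imported `normalize_size_param` turns the user-facing `size` argument (`None`, an int, or a sequence) into a symbolic integer vector, which is what lets the hunk below concatenate it with `(n,)`. A minimal sketch, assuming Aesara's `normalize_size_param` behaves as in its public source (the example values are illustrative):

```python
from aesara.tensor.random.utils import normalize_size_param

# `size=None` is normalized to an empty integer vector (no batch dimensions),
# while an explicit size becomes a symbolic int64 vector.
print(normalize_size_param(None).eval())    # []
print(normalize_size_param((2, 5)).eval())  # [2 5]
```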
@@ -1134,6 +1134,19 @@ def make_node(self, rng, size, dtype, n, eta, D):
 
         D = at.as_tensor_variable(D)
 
+        # We resize the sd_dist `D` automatically so that it has (size x n) independent
+        # draws, which is what `_LKJCholeskyCovRV.rng_fn` expects. This makes the
+        # random and logp methods equivalent, as the latter also assumes a unique value
+        # for each diagonal element.
+        # Since `eta` and `n` are forced to be scalars, we don't need to worry about
+        # implied batched dimensions for the time being.
+        size = normalize_size_param(size)
+        if D.owner.op.ndim_supp == 0:
+            D = change_rv_size(D, at.concatenate((size, (n,))))
+        else:
+            # The support shape must be `n`, but we have no way of controlling it.
+            D = change_rv_size(D, size)
+
         return super().make_node(rng, size, dtype, n, eta, D)
 
     def _infer_shape(self, size, dist_params, param_shapes=None):
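A minimal sketch of what the relocated resize does, assuming PyMC 4's `pm.Exponential.dist` and the `pymc.aesaraf.change_rv_size` helper this diff already relies on (the shape `(3,)` is illustrative):

```python
import pymc as pm
from pymc.aesaraf import change_rv_size

# A scalar sd_dist (ndim_supp == 0) produces a single draw by default...
sd_dist = pm.Exponential.dist(1.0)
print(sd_dist.eval().shape)  # ()

# ...so it is resized to one independent draw per diagonal element,
# which is what `_LKJCholeskyCovRV.rng_fn` expects.
n = 3
sd_dist = change_rv_size(sd_dist, (n,))
print(sd_dist.eval().shape)  # (3,)
```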
@@ -1179,7 +1192,7 @@ def __new__(cls, name, eta, n, sd_dist, **kwargs):
         return super().__new__(cls, name, eta, n, sd_dist, **kwargs)
 
     @classmethod
-    def dist(cls, eta, n, sd_dist, size=None, **kwargs):
+    def dist(cls, eta, n, sd_dist, **kwargs):
         eta = at.as_tensor_variable(floatX(eta))
         n = at.as_tensor_variable(intX(n))
 
@@ -1191,18 +1204,6 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
         ):
             raise TypeError("sd_dist must be a scalar or vector distribution variable")
 
-        # We resize the sd_dist automatically so that it has (size x n) independent draws
-        # which is what the `_LKJCholeskyCovRV.rng_fn` expects. This makes the random
-        # and logp methods equivalent, as the latter also assumes a unique value for each
-        # diagonal element.
-        # Since `eta` and `n` are forced to be scalars we don't need to worry about
-        # implied batched dimensions for the time being.
-        if sd_dist.owner.op.ndim_supp == 0:
-            sd_dist = change_rv_size(sd_dist, to_tuple(size) + (n,))
-        else:
-            # The support shape must be `n` but we have no way of controlling it
-            sd_dist = change_rv_size(sd_dist, to_tuple(size))
-
         # sd_dist is part of the generative graph, but should be completely ignored
         # by the logp graph, since the LKJ logp explicitly includes these terms.
         # Setting sd_dist.tag.ignore_logprob to True, will prevent Aeppl warning about
@@ -1211,7 +1212,7 @@ def dist(cls, eta, n, sd_dist, size=None, **kwargs):
         # sd_dist prior components from the logp expression.
         sd_dist.tag.ignore_logprob = True
 
-        return super().dist([n, eta, sd_dist], size=size, **kwargs)
+        return super().dist([n, eta, sd_dist], **kwargs)
 
     def moment(rv, size, n, eta, sd_dists):
         diag_idxs = (at.cumsum(at.arange(1, n + 1)) - 1).astype("int32")
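For reference, a hedged usage sketch against PyMC 4's public `LKJCholeskyCov` API (the prior and values are made up): after this change, callers pass `sd_dist` unresized and without a `size` argument to `dist`; the RV op handles the expansion internally.

```python
import pymc as pm

with pm.Model():
    # An unresized scalar prior for the standard deviations; the RV op
    # now expands it to `n` independent draws internally.
    sd_dist = pm.Exponential.dist(1.0)
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol_cov", eta=2.0, n=3, sd_dist=sd_dist, compute_corr=True
    )
```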