20
20
21
21
from abc import ABCMeta
22
22
from copy import copy
23
- from typing import Any , Optional , Sequence , Tuple , Union
23
+ from typing import TYPE_CHECKING
24
24
25
- import aesara
26
- import aesara .tensor as at
27
25
import dill
28
26
29
- from aesara .graph .basic import Variable
30
27
from aesara .tensor .random .op import RandomVariable
31
28
32
- from pymc3 .aesaraf import change_rv_size , pandas_to_array
33
29
from pymc3 .distributions import _logcdf , _logp
30
+
31
+ if TYPE_CHECKING :
32
+ from typing import Optional , Callable
33
+
34
+ import aesara
35
+ import aesara .graph .basic
36
+ import aesara .tensor as at
37
+
34
38
from pymc3 .util import UNSET , get_repr_for_variable
35
39
from pymc3 .vartypes import string_types
36
40
48
52
49
53
PLATFORM = sys .platform
50
54
51
- Shape = Union [int , Sequence [Union [str , type (Ellipsis )]], Variable ]
52
- Dims = Union [str , Sequence [Union [str , None , type (Ellipsis )]]]
53
- Size = Union [int , Tuple [int , ...]]
54
-
55
55
56
56
class _Unpickling :
57
57
pass
@@ -115,111 +115,13 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
115
115
return new_cls
116
116
117
117
118
- def _valid_ellipsis_position (items : Union [None , Shape , Dims , Size ]) -> bool :
119
- if items is not None and not isinstance (items , Variable ) and Ellipsis in items :
120
- if any (i == Ellipsis for i in items [:- 1 ]):
121
- return False
122
- return True
123
-
124
-
125
- def _validate_shape_dims_size (
126
- shape : Any = None , dims : Any = None , size : Any = None
127
- ) -> Tuple [Optional [Shape ], Optional [Dims ], Optional [Size ]]:
128
- # Raise on unsupported parametrization
129
- if shape is not None and dims is not None :
130
- raise ValueError (f"Passing both `shape` ({ shape } ) and `dims` ({ dims } ) is not supported!" )
131
- if dims is not None and size is not None :
132
- raise ValueError (f"Passing both `dims` ({ dims } ) and `size` ({ size } ) is not supported!" )
133
- if shape is not None and size is not None :
134
- raise ValueError (f"Passing both `shape` ({ shape } ) and `size` ({ size } ) is not supported!" )
135
-
136
- # Raise on invalid types
137
- if not isinstance (shape , (type (None ), int , list , tuple , Variable )):
138
- raise ValueError ("The `shape` parameter must be an int, list or tuple." )
139
- if not isinstance (dims , (type (None ), str , list , tuple )):
140
- raise ValueError ("The `dims` parameter must be a str, list or tuple." )
141
- if not isinstance (size , (type (None ), int , list , tuple )):
142
- raise ValueError ("The `size` parameter must be an int, list or tuple." )
143
-
144
- # Auto-convert non-tupled parameters
145
- if isinstance (shape , int ):
146
- shape = (shape ,)
147
- if isinstance (dims , str ):
148
- dims = (dims ,)
149
- if isinstance (size , int ):
150
- size = (size ,)
151
-
152
- # Convert to actual tuples
153
- if not isinstance (shape , (type (None ), tuple , Variable )):
154
- shape = tuple (shape )
155
- if not isinstance (dims , (type (None ), tuple )):
156
- dims = tuple (dims )
157
- if not isinstance (size , (type (None ), tuple )):
158
- size = tuple (size )
159
-
160
- if not _valid_ellipsis_position (shape ):
161
- raise ValueError (
162
- f"Ellipsis in `shape` may only appear in the last position. Actual: { shape } "
163
- )
164
- if not _valid_ellipsis_position (dims ):
165
- raise ValueError (f"Ellipsis in `dims` may only appear in the last position. Actual: { dims } " )
166
- if size is not None and Ellipsis in size :
167
- raise ValueError (f"The `size` parameter cannot contain an Ellipsis. Actual: { size } " )
168
- return shape , dims , size
169
-
170
-
171
118
class Distribution (metaclass = DistributionMeta ):
172
119
"""Statistical distribution"""
173
120
174
121
rv_class = None
175
122
rv_op = None
176
123
177
- def __new__ (
178
- cls ,
179
- name : str ,
180
- * args ,
181
- rng = None ,
182
- dims : Optional [Dims ] = None ,
183
- testval = None ,
184
- observed = None ,
185
- total_size = None ,
186
- transform = UNSET ,
187
- ** kwargs ,
188
- ) -> RandomVariable :
189
- """Adds a RandomVariable corresponding to a PyMC3 distribution to the current model.
190
-
191
- Note that all remaining kwargs must be compatible with ``.dist()``
192
-
193
- Parameters
194
- ----------
195
- cls : type
196
- A PyMC3 distribution.
197
- name : str
198
- Name for the new model variable.
199
- rng : optional
200
- Random number generator to use with the RandomVariable.
201
- dims : tuple, optional
202
- A tuple of dimension names known to the model.
203
- testval : optional
204
- Test value to be attached to the output RV.
205
- Must match its shape exactly.
206
- observed : optional
207
- Observed data to be passed when registering the random variable in the model.
208
- See ``Model.register_rv``.
209
- total_size : float, optional
210
- See ``Model.register_rv``.
211
- transform : optional
212
- See ``Model.register_rv``.
213
- **kwargs
214
- Keyword arguments that will be forwarded to ``.dist()``.
215
- Most prominently: ``shape`` and ``size``
216
-
217
- Returns
218
- -------
219
- rv : RandomVariable
220
- The created RV, registered in the Model.
221
- """
222
-
124
+ def __new__ (cls , name , * args , ** kwargs ):
223
125
try :
224
126
from pymc3 .model import Model
225
127
@@ -232,125 +134,40 @@ def __new__(
232
134
"for a standalone distribution."
233
135
)
234
136
235
- if not isinstance (name , string_types ):
236
- raise TypeError (f"Name needs to be a string but got: { name } " )
137
+ rng = kwargs .pop ("rng" , None )
237
138
238
139
if rng is None :
239
140
rng = model .default_rng
240
141
241
- _ , dims , _ = _validate_shape_dims_size ( dims = dims )
242
- resize = None
142
+ if not isinstance ( name , string_types ):
143
+ raise TypeError ( f"Name needs to be a string but got: { name } " )
243
144
244
- # Create the RV without specifying testval, because the testval may have a shape
245
- # that only matches after replicating with a size implied by dims (see below).
246
- rv_out = cls .dist (* args , rng = rng , testval = None , ** kwargs )
247
- n_implied = rv_out .ndim
145
+ data = kwargs .pop ("observed" , None )
248
146
249
- # `dims` are only available with this API, because `.dist()` can be used
250
- # without a modelcontext and dims are not tracked at the Aesara level.
251
- if dims is not None :
252
- if Ellipsis in dims :
253
- # Auto-complete the dims tuple to the full length
254
- dims = (* dims [:- 1 ], * [None ] * rv_out .ndim )
147
+ total_size = kwargs .pop ("total_size" , None )
255
148
256
- n_resize = len ( dims ) - n_implied
149
+ dims = kwargs . pop ( " dims" , None )
257
150
258
- # All resize dims must be known already (numerically or symbolically).
259
- unknown_resize_dims = set (dims [:n_resize ]) - set (model .dim_lengths )
260
- if unknown_resize_dims :
261
- raise KeyError (
262
- f"Dimensions { unknown_resize_dims } are unknown to the model and cannot be used to specify a `size`."
263
- )
151
+ if "shape" in kwargs :
152
+ raise DeprecationWarning ("The `shape` keyword is deprecated; use `size`." )
264
153
265
- # The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
266
- resize = tuple (model .dim_lengths [dname ] for dname in dims [:n_resize ])
267
- elif observed is not None :
268
- if not hasattr (observed , "shape" ):
269
- observed = pandas_to_array (observed )
270
- n_resize = observed .ndim - n_implied
271
- resize = tuple (observed .shape [d ] for d in range (n_resize ))
272
-
273
- if resize :
274
- # A batch size was specified through `dims`, or implied by `observed`.
275
- rv_out = change_rv_size (rv_var = rv_out , new_size = resize , expand = True )
276
-
277
- if dims is not None :
278
- # Now that we have a handle on the output RV, we can register named implied dimensions that
279
- # were not yet known to the model, such that they can be used for size further downstream.
280
- for di , dname in enumerate (dims [n_resize :]):
281
- if not dname in model .dim_lengths :
282
- model .add_coord (dname , values = None , length = rv_out .shape [n_resize + di ])
154
+ transform = kwargs .pop ("transform" , UNSET )
283
155
284
- if testval is not None :
285
- # Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
286
- rv_out .tag .test_value = testval
156
+ rv_out = cls .dist (* args , rng = rng , ** kwargs )
287
157
288
- return model .register_rv (rv_out , name , observed , total_size , dims = dims , transform = transform )
158
+ return model .register_rv (rv_out , name , data , total_size , dims = dims , transform = transform )
289
159
290
160
@classmethod
291
- def dist (
292
- cls ,
293
- dist_params ,
294
- * ,
295
- shape : Optional [Shape ] = None ,
296
- size : Optional [Size ] = None ,
297
- testval = None ,
298
- ** kwargs ,
299
- ) -> RandomVariable :
300
- """Creates a RandomVariable corresponding to the `cls` distribution.
161
+ def dist (cls , dist_params , ** kwargs ):
301
162
302
- Parameters
303
- ----------
304
- dist_params
305
- shape : tuple, optional
306
- A tuple of sizes for each dimension of the new RV.
307
-
308
- Ellipsis (...) may be used in the last position of the tuple,
309
- and automatically expand to the shape implied by RV inputs.
310
- size : int, tuple, Variable, optional
311
- A scalar or tuple for replicating the RV in addition
312
- to its implied shape/dimensionality.
313
- testval : optional
314
- Test value to be attached to the output RV.
315
- Must match its shape exactly.
316
-
317
- Returns
318
- -------
319
- rv : RandomVariable
320
- The created RV.
321
- """
322
- if "dims" in kwargs :
323
- raise NotImplementedError ("The use of a `.dist(dims=...)` API is not yet supported." )
324
-
325
- shape , _ , size = _validate_shape_dims_size (shape = shape , size = size )
326
-
327
- # Create the RV without specifying size or testval.
328
- # The size will be expanded later (if necessary) and only then the testval fits.
329
- rv_native = cls .rv_op (* dist_params , size = None , ** kwargs )
163
+ testval = kwargs .pop ("testval" , None )
330
164
331
- if shape is None and size is None :
332
- size = ()
333
- elif shape is not None :
334
- if isinstance (shape , Variable ):
335
- size = ()
336
- else :
337
- if Ellipsis in shape :
338
- size = tuple (shape [:- 1 ])
339
- else :
340
- size = tuple (shape [: len (shape ) - rv_native .ndim ])
341
- # no-op conditions:
342
- # `elif size is not None` (User already specified how to expand the RV)
343
- # `else` (Unreachable)
344
-
345
- if size :
346
- rv_out = change_rv_size (rv_var = rv_native , new_size = size , expand = True )
347
- else :
348
- rv_out = rv_native
165
+ rv_var = cls .rv_op (* dist_params , ** kwargs )
349
166
350
167
if testval is not None :
351
- rv_out .tag .test_value = testval
168
+ rv_var .tag .test_value = testval
352
169
353
- return rv_out
170
+ return rv_var
354
171
355
172
def _distr_parameters_for_repr (self ):
356
173
"""Return the names of the parameters for this distribution (e.g. "mu"
0 commit comments