Skip to content

Commit 5bdfdde

Browse files
Revert "Allow parametrization through either shape, dims or size"
This reverts commit ed29203.
1 parent 366ff1b commit 5bdfdde

14 files changed

+154
-548
lines changed

RELEASE-NOTES.md

-5
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,6 @@
99

1010
### New Features
1111
- The `CAR` distribution has been added to allow for use of conditional autoregressions, which are often used in spatial and network models.
12-
- The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)):
13-
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Using `shape` restricts the model variable to the exact length and re-sizing is no longer possible.
14-
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects.
15-
- The `size` kwarg creates new dimensions in addition to what is implied by RV parameters.
16-
- An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions.
1712
- ...
1813

1914
### Maintenance

pymc3/distributions/distribution.py

+26-209
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,21 @@
2020

2121
from abc import ABCMeta
2222
from copy import copy
23-
from typing import Any, Optional, Sequence, Tuple, Union
23+
from typing import TYPE_CHECKING
2424

25-
import aesara
26-
import aesara.tensor as at
2725
import dill
2826

29-
from aesara.graph.basic import Variable
3027
from aesara.tensor.random.op import RandomVariable
3128

32-
from pymc3.aesaraf import change_rv_size, pandas_to_array
3329
from pymc3.distributions import _logcdf, _logp
30+
31+
if TYPE_CHECKING:
32+
from typing import Optional, Callable
33+
34+
import aesara
35+
import aesara.graph.basic
36+
import aesara.tensor as at
37+
3438
from pymc3.util import UNSET, get_repr_for_variable
3539
from pymc3.vartypes import string_types
3640

@@ -48,10 +52,6 @@
4852

4953
PLATFORM = sys.platform
5054

51-
Shape = Union[int, Sequence[Union[str, type(Ellipsis)]], Variable]
52-
Dims = Union[str, Sequence[Union[str, None, type(Ellipsis)]]]
53-
Size = Union[int, Tuple[int, ...]]
54-
5555

5656
class _Unpickling:
5757
pass
@@ -115,111 +115,13 @@ def logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
115115
return new_cls
116116

117117

118-
def _valid_ellipsis_position(items: Union[None, Shape, Dims, Size]) -> bool:
119-
if items is not None and not isinstance(items, Variable) and Ellipsis in items:
120-
if any(i == Ellipsis for i in items[:-1]):
121-
return False
122-
return True
123-
124-
125-
def _validate_shape_dims_size(
126-
shape: Any = None, dims: Any = None, size: Any = None
127-
) -> Tuple[Optional[Shape], Optional[Dims], Optional[Size]]:
128-
# Raise on unsupported parametrization
129-
if shape is not None and dims is not None:
130-
raise ValueError(f"Passing both `shape` ({shape}) and `dims` ({dims}) is not supported!")
131-
if dims is not None and size is not None:
132-
raise ValueError(f"Passing both `dims` ({dims}) and `size` ({size}) is not supported!")
133-
if shape is not None and size is not None:
134-
raise ValueError(f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!")
135-
136-
# Raise on invalid types
137-
if not isinstance(shape, (type(None), int, list, tuple, Variable)):
138-
raise ValueError("The `shape` parameter must be an int, list or tuple.")
139-
if not isinstance(dims, (type(None), str, list, tuple)):
140-
raise ValueError("The `dims` parameter must be a str, list or tuple.")
141-
if not isinstance(size, (type(None), int, list, tuple)):
142-
raise ValueError("The `size` parameter must be an int, list or tuple.")
143-
144-
# Auto-convert non-tupled parameters
145-
if isinstance(shape, int):
146-
shape = (shape,)
147-
if isinstance(dims, str):
148-
dims = (dims,)
149-
if isinstance(size, int):
150-
size = (size,)
151-
152-
# Convert to actual tuples
153-
if not isinstance(shape, (type(None), tuple, Variable)):
154-
shape = tuple(shape)
155-
if not isinstance(dims, (type(None), tuple)):
156-
dims = tuple(dims)
157-
if not isinstance(size, (type(None), tuple)):
158-
size = tuple(size)
159-
160-
if not _valid_ellipsis_position(shape):
161-
raise ValueError(
162-
f"Ellipsis in `shape` may only appear in the last position. Actual: {shape}"
163-
)
164-
if not _valid_ellipsis_position(dims):
165-
raise ValueError(f"Ellipsis in `dims` may only appear in the last position. Actual: {dims}")
166-
if size is not None and Ellipsis in size:
167-
raise ValueError(f"The `size` parameter cannot contain an Ellipsis. Actual: {size}")
168-
return shape, dims, size
169-
170-
171118
class Distribution(metaclass=DistributionMeta):
172119
"""Statistical distribution"""
173120

174121
rv_class = None
175122
rv_op = None
176123

177-
def __new__(
178-
cls,
179-
name: str,
180-
*args,
181-
rng=None,
182-
dims: Optional[Dims] = None,
183-
testval=None,
184-
observed=None,
185-
total_size=None,
186-
transform=UNSET,
187-
**kwargs,
188-
) -> RandomVariable:
189-
"""Adds a RandomVariable corresponding to a PyMC3 distribution to the current model.
190-
191-
Note that all remaining kwargs must be compatible with ``.dist()``
192-
193-
Parameters
194-
----------
195-
cls : type
196-
A PyMC3 distribution.
197-
name : str
198-
Name for the new model variable.
199-
rng : optional
200-
Random number generator to use with the RandomVariable.
201-
dims : tuple, optional
202-
A tuple of dimension names known to the model.
203-
testval : optional
204-
Test value to be attached to the output RV.
205-
Must match its shape exactly.
206-
observed : optional
207-
Observed data to be passed when registering the random variable in the model.
208-
See ``Model.register_rv``.
209-
total_size : float, optional
210-
See ``Model.register_rv``.
211-
transform : optional
212-
See ``Model.register_rv``.
213-
**kwargs
214-
Keyword arguments that will be forwarded to ``.dist()``.
215-
Most prominently: ``shape`` and ``size``
216-
217-
Returns
218-
-------
219-
rv : RandomVariable
220-
The created RV, registered in the Model.
221-
"""
222-
124+
def __new__(cls, name, *args, **kwargs):
223125
try:
224126
from pymc3.model import Model
225127

@@ -232,125 +134,40 @@ def __new__(
232134
"for a standalone distribution."
233135
)
234136

235-
if not isinstance(name, string_types):
236-
raise TypeError(f"Name needs to be a string but got: {name}")
137+
rng = kwargs.pop("rng", None)
237138

238139
if rng is None:
239140
rng = model.default_rng
240141

241-
_, dims, _ = _validate_shape_dims_size(dims=dims)
242-
resize = None
142+
if not isinstance(name, string_types):
143+
raise TypeError(f"Name needs to be a string but got: {name}")
243144

244-
# Create the RV without specifying testval, because the testval may have a shape
245-
# that only matches after replicating with a size implied by dims (see below).
246-
rv_out = cls.dist(*args, rng=rng, testval=None, **kwargs)
247-
n_implied = rv_out.ndim
145+
data = kwargs.pop("observed", None)
248146

249-
# `dims` are only available with this API, because `.dist()` can be used
250-
# without a modelcontext and dims are not tracked at the Aesara level.
251-
if dims is not None:
252-
if Ellipsis in dims:
253-
# Auto-complete the dims tuple to the full length
254-
dims = (*dims[:-1], *[None] * rv_out.ndim)
147+
total_size = kwargs.pop("total_size", None)
255148

256-
n_resize = len(dims) - n_implied
149+
dims = kwargs.pop("dims", None)
257150

258-
# All resize dims must be known already (numerically or symbolically).
259-
unknown_resize_dims = set(dims[:n_resize]) - set(model.dim_lengths)
260-
if unknown_resize_dims:
261-
raise KeyError(
262-
f"Dimensions {unknown_resize_dims} are unknown to the model and cannot be used to specify a `size`."
263-
)
151+
if "shape" in kwargs:
152+
raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")
264153

265-
# The numeric/symbolic resize tuple can be created using model.dim_lengths
266-
resize = tuple(model.dim_lengths[dname] for dname in dims[:n_resize])
267-
elif observed is not None:
268-
if not hasattr(observed, "shape"):
269-
observed = pandas_to_array(observed)
270-
n_resize = observed.ndim - n_implied
271-
resize = tuple(observed.shape[d] for d in range(n_resize))
272-
273-
if resize:
274-
# A batch size was specified through `dims`, or implied by `observed`.
275-
rv_out = change_rv_size(rv_var=rv_out, new_size=resize, expand=True)
276-
277-
if dims is not None:
278-
# Now that we have a handle on the output RV, we can register named implied dimensions that
279-
# were not yet known to the model, such that they can be used for size further downstream.
280-
for di, dname in enumerate(dims[n_resize:]):
281-
if not dname in model.dim_lengths:
282-
model.add_coord(dname, values=None, length=rv_out.shape[n_resize + di])
154+
transform = kwargs.pop("transform", UNSET)
283155

284-
if testval is not None:
285-
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
286-
rv_out.tag.test_value = testval
156+
rv_out = cls.dist(*args, rng=rng, **kwargs)
287157

288-
return model.register_rv(rv_out, name, observed, total_size, dims=dims, transform=transform)
158+
return model.register_rv(rv_out, name, data, total_size, dims=dims, transform=transform)
289159

290160
@classmethod
291-
def dist(
292-
cls,
293-
dist_params,
294-
*,
295-
shape: Optional[Shape] = None,
296-
size: Optional[Size] = None,
297-
testval=None,
298-
**kwargs,
299-
) -> RandomVariable:
300-
"""Creates a RandomVariable corresponding to the `cls` distribution.
161+
def dist(cls, dist_params, **kwargs):
301162

302-
Parameters
303-
----------
304-
dist_params
305-
shape : tuple, optional
306-
A tuple of sizes for each dimension of the new RV.
307-
308-
Ellipsis (...) may be used in the last position of the tuple,
309-
and automatically expand to the shape implied by RV inputs.
310-
size : int, tuple, Variable, optional
311-
A scalar or tuple for replicating the RV in addition
312-
to its implied shape/dimensionality.
313-
testval : optional
314-
Test value to be attached to the output RV.
315-
Must match its shape exactly.
316-
317-
Returns
318-
-------
319-
rv : RandomVariable
320-
The created RV.
321-
"""
322-
if "dims" in kwargs:
323-
raise NotImplementedError("The use of a `.dist(dims=...)` API is not yet supported.")
324-
325-
shape, _, size = _validate_shape_dims_size(shape=shape, size=size)
326-
327-
# Create the RV without specifying size or testval.
328-
# The size will be expanded later (if necessary) and only then the testval fits.
329-
rv_native = cls.rv_op(*dist_params, size=None, **kwargs)
163+
testval = kwargs.pop("testval", None)
330164

331-
if shape is None and size is None:
332-
size = ()
333-
elif shape is not None:
334-
if isinstance(shape, Variable):
335-
size = ()
336-
else:
337-
if Ellipsis in shape:
338-
size = tuple(shape[:-1])
339-
else:
340-
size = tuple(shape[: len(shape) - rv_native.ndim])
341-
# no-op conditions:
342-
# `elif size is not None` (User already specified how to expand the RV)
343-
# `else` (Unreachable)
344-
345-
if size:
346-
rv_out = change_rv_size(rv_var=rv_native, new_size=size, expand=True)
347-
else:
348-
rv_out = rv_native
165+
rv_var = cls.rv_op(*dist_params, **kwargs)
349166

350167
if testval is not None:
351-
rv_out.tag.test_value = testval
168+
rv_var.tag.test_value = testval
352169

353-
return rv_out
170+
return rv_var
354171

355172
def _distr_parameters_for_repr(self):
356173
"""Return the names of the parameters for this distribution (e.g. "mu"

0 commit comments

Comments
 (0)