diff --git a/conda-envs/environment-dev-py37.yml b/conda-envs/environment-dev-py37.yml
index 1cf01e6cca..1d1e936ab9 100644
--- a/conda-envs/environment-dev-py37.yml
+++ b/conda-envs/environment-dev-py37.yml
@@ -5,7 +5,7 @@ channels:
 - defaults
 dependencies:
 - aeppl=0.0.26
-- aesara=2.3.8
+- aesara=2.4.0
 - arviz>=0.11.4
 - blas
 - cachetools>=4.2.1
@@ -24,7 +24,7 @@ dependencies:
 - pytest>=3.0
 - python-graphviz
 - python=3.7
-- scipy>=1.4.1,<1.8.0
+- scipy>=1.4.1
 - sphinx-copybutton
 - sphinx-notfound-page
 - sphinx>=1.5
diff --git a/conda-envs/environment-dev-py38.yml b/conda-envs/environment-dev-py38.yml
index 1cb5abf637..d4426963e1 100644
--- a/conda-envs/environment-dev-py38.yml
+++ b/conda-envs/environment-dev-py38.yml
@@ -5,7 +5,7 @@ channels:
 - defaults
 dependencies:
 - aeppl=0.0.26
-- aesara=2.3.8
+- aesara=2.4.0
 - arviz>=0.11.4
 - blas
 - cachetools>=4.2.1
@@ -24,7 +24,7 @@ dependencies:
 - pytest>=3.0
 - python-graphviz
 - python=3.8
-- scipy>=1.4.1,<1.8.0
+- scipy>=1.4.1
 - sphinx-copybutton
 - sphinx-notfound-page
 - sphinx>=1.5
diff --git a/conda-envs/environment-dev-py39.yml b/conda-envs/environment-dev-py39.yml
index 4c683a70a7..6e50880151 100644
--- a/conda-envs/environment-dev-py39.yml
+++ b/conda-envs/environment-dev-py39.yml
@@ -5,7 +5,7 @@ channels:
 - defaults
 dependencies:
 - aeppl=0.0.26
-- aesara=2.3.8
+- aesara=2.4.0
 - arviz>=0.11.4
 - blas
 - cachetools>=4.2.1
@@ -24,7 +24,7 @@ dependencies:
 - pytest>=3.0
 - python-graphviz
 - python=3.9
-- scipy>=1.4.1,<1.8.0
+- scipy>=1.4.1
 - sphinx-copybutton
 - sphinx-notfound-page
 - sphinx>=1.5
diff --git a/conda-envs/environment-test-py37.yml b/conda-envs/environment-test-py37.yml
index 2c93125602..ab8bdbe386 100644
--- a/conda-envs/environment-test-py37.yml
+++ b/conda-envs/environment-test-py37.yml
@@ -5,7 +5,7 @@ channels:
 - defaults
 dependencies:
 - aeppl=0.0.26
-- aesara=2.3.8
+- aesara=2.4.0
 - arviz>=0.11.4
 - blas
 - cachetools>=4.2.1
@@ -22,5 +22,5 @@ dependencies:
 - pytest>=3.0
 - python-graphviz
 - python=3.7
-- scipy>=1.4.1,<1.8.0
+- scipy>=1.4.1
 - typing-extensions>=3.7.4
diff --git a/conda-envs/environment-test-py38.yml b/conda-envs/environment-test-py38.yml
index 5d87f33945..9c51bd0d67 100644
--- a/conda-envs/environment-test-py38.yml
+++ b/conda-envs/environment-test-py38.yml
@@ -5,7 +5,7 @@ channels:
 - defaults
 dependencies:
 - aeppl=0.0.26
-- aesara=2.3.8
+- aesara=2.4.0
 - arviz>=0.11.4
 - blas
 - cachetools>=4.2.1
@@ -22,5 +22,5 @@ dependencies:
 - pytest>=3.0
 - python-graphviz
 - python=3.8
-- scipy>=1.4.1,<1.8.0
+- scipy>=1.4.1
 - typing-extensions>=3.7.4
diff --git a/conda-envs/environment-test-py39.yml b/conda-envs/environment-test-py39.yml
index 02d16cb22d..cc869f9fe6 100644
--- a/conda-envs/environment-test-py39.yml
+++ b/conda-envs/environment-test-py39.yml
@@ -5,7 +5,7 @@ channels:
 - defaults
 dependencies:
 - aeppl=0.0.26
-- aesara=2.3.8
+- aesara=2.4.0
 - arviz>=0.11.4
 - blas
 - cachetools>=4.2.1
@@ -22,5 +22,5 @@ dependencies:
 - pytest>=3.0
 - python-graphviz
 - python=3.9
-- scipy>=1.4.1,<1.8.0
+- scipy>=1.4.1
 - typing-extensions>=3.7.4
diff --git a/conda-envs/windows-environment-dev-py38.yml b/conda-envs/windows-environment-dev-py38.yml
index f165a7a5b6..dfacd28c09 100644
--- a/conda-envs/windows-environment-dev-py38.yml
+++ b/conda-envs/windows-environment-dev-py38.yml
@@ -5,7 +5,7 @@ channels:
 dependencies:
 # base dependencies (see install guide for Windows)
 - aeppl=0.0.26
-- aesara=2.3.8
+- aesara=2.4.0
 - arviz>=0.11.4
 - blas
 - cachetools>=4.2.1
@@ -17,7 +17,7 @@ dependencies:
 - pip
 - python=3.8
 - python-graphviz
-- scipy>=1.4.1,<1.8.0
+- scipy>=1.4.1
 - typing-extensions>=3.7.4
 # Extra stuff for dev, testing and docs build
 - ipython>=7.16
diff --git a/conda-envs/windows-environment-test-py38.yml b/conda-envs/windows-environment-test-py38.yml
index 3adac2cd4a..11beba6bcc 100644
--- a/conda-envs/windows-environment-test-py38.yml
+++ b/conda-envs/windows-environment-test-py38.yml
@@ -5,7 +5,7 @@ channels:
 dependencies:
 # base dependencies (see install guide for Windows)
 - aeppl=0.0.26
-- aesara=2.3.8
+- aesara=2.4.0
 - arviz>=0.11.4
 - blas
 - cachetools>=4.2.1
@@ -21,7 +21,7 @@ dependencies:
 - pip
 - python=3.8
 - python-graphviz
-- scipy>=1.4.1,<1.8.0
+- scipy>=1.4.1
 - typing-extensions>=3.7.4
 # Extra stuff for testing
 - ipython>=7.16
diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py
index 8616f6ded6..1b04575e8c 100644
--- a/pymc/aesaraf.py
+++ b/pymc/aesaraf.py
@@ -173,9 +173,9 @@ def change_rv_size(
     tag = rv_var.tag
 
     if expand:
-        if rv_node.op.ndim_supp == 0 and at.get_vector_length(size) == 0:
-            size = rv_node.op._infer_shape(size, dist_params)
-        new_size = tuple(new_size) + tuple(size)
+        old_shape = tuple(rv_node.op._infer_shape(size, dist_params))
+        old_size = old_shape[: len(old_shape) - rv_node.op.ndim_supp]
+        new_size = tuple(new_size) + tuple(old_size)
 
     # Make sure the new size is a tensor. This dtype-aware conversion helps
     # to not unnecessarily pick up a `Cast` in some cases (see #4652).
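With this change, `expand=True` recovers the old batch shape through `Op._infer_shape` for any RV, instead of special-casing univariate RVs with an empty `size`. A minimal sketch of the resulting behavior (the distribution and shape values here are illustrative, not taken from the patch):

    import numpy as np
    import pymc as pm
    from pymc.aesaraf import change_rv_size

    # One batch dimension (2,) and one support dimension (3,)
    rv = pm.MvNormal.dist(mu=np.zeros((2, 3)), cov=np.eye(3))

    # expand=True prepends the new size to the old batch size,
    # leaving the support dimension untouched
    expanded = change_rv_size(rv, new_size=(5,), expand=True)
    assert tuple(expanded.shape.eval()) == (5, 2, 3)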
diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py
index 384dcb4c22..9457039db0 100644
--- a/pymc/distributions/censored.py
+++ b/pymc/distributions/censored.py
@@ -100,11 +100,12 @@ def ndim_supp(cls, *dist_params):
         return 0
 
     @classmethod
-    def change_size(cls, rv, new_size):
+    def change_size(cls, rv, new_size, expand=False):
         dist_node = rv.tag.dist.owner
         lower = rv.tag.lower
         upper = rv.tag.upper
         rng, old_size, dtype, *dist_params = dist_node.inputs
+        new_size = new_size if not expand else tuple(new_size) + tuple(old_size)
         new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
         return cls.rv_op(new_dist, lower, upper)
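The new `expand` flag mirrors the one on `change_rv_size`; the `test_change_size` test added to `test_distributions.py` further below exercises exactly this contract:

    import pymc as pm

    base = pm.Censored.dist(pm.Normal.dist(), -1, 1, size=(3, 2))

    # Without expand, the new size replaces the old one
    assert pm.Censored.change_size(base, (4,)).eval().shape == (4,)

    # With expand=True, the new size is prepended to the old one
    assert pm.Censored.change_size(base, (4,), expand=True).eval().shape == (4, 3, 2)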
diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index 34bb0a7172..13eb9c60a7 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -19,7 +19,7 @@
 
 from abc import ABCMeta
 from functools import singledispatch
-from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
+from typing import Callable, Iterable, Optional, Sequence, Tuple, Union, cast
 
 import aesara
 import numpy as np
@@ -45,7 +45,6 @@
     convert_shape,
     convert_size,
     find_size,
-    maybe_resize,
     resize_from_dims,
     resize_from_observed,
 )
@@ -353,17 +352,11 @@ def dist(
         # Create the RV with a `size` right away.
         # This is not necessarily the final result.
         rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
-        rv_out = maybe_resize(
-            rv_out,
-            cls.rv_op,
-            dist_params,
-            ndim_expected,
-            ndim_batch,
-            ndim_supp,
-            shape,
-            size,
-            **kwargs,
-        )
+
+        # Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
+        if shape is not None and Ellipsis in shape:
+            replicate_shape = cast(StrongShape, shape[:-1])
+            rv_out = change_rv_size(rv_var=rv_out, new_size=replicate_shape, expand=True)
 
         rng = kwargs.pop("rng", None)
         if (
@@ -589,18 +582,11 @@ def dist(
         # Create the RV with a `size` right away.
         # This is not necessarily the final result.
         graph = cls.rv_op(*dist_params, size=create_size, **kwargs)
-        graph = maybe_resize(
-            graph,
-            cls.rv_op,
-            dist_params,
-            ndim_expected,
-            ndim_batch,
-            ndim_supp,
-            shape,
-            size,
-            change_rv_size_fn=cls.change_size,
-            **kwargs,
-        )
+
+        # Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
+        if shape is not None and Ellipsis in shape:
+            replicate_shape = cast(StrongShape, shape[:-1])
+            graph = cls.change_size(rv=graph, new_size=replicate_shape, expand=True)
 
         rngs = kwargs.pop("rngs", None)
         if rngs is not None:
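With `maybe_resize` gone, an `Ellipsis`-terminated `shape` is the only remaining resize path in `dist()`: everything before the `...` is treated as replicate dimensions and prepended onto whatever the parameters imply. A sketch, assuming the usual v4 `shape` semantics:

    import numpy as np
    import pymc as pm

    # Parameters imply the support shape (3,); shape=(2, ...) prepends a
    # replicate dimension of length 2 in front of it
    x = pm.MvNormal.dist(mu=np.zeros(3), cov=np.eye(3), shape=(2, ...))
    assert tuple(x.shape.eval()) == (2, 3)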
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index ff4e34d68d..27be614a67 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -247,7 +247,7 @@ def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
 
     def get_moment(rv, size, mu, cov):
         moment = mu
         if not rv_size_is_none(size):
-            moment_size = at.concatenate([size, mu.shape])
+            moment_size = at.concatenate([size, [mu.shape[-1]]])
             moment = at.full(moment_size, mu)
         return moment
@@ -301,18 +301,13 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
 
     @classmethod
     def rng_fn(cls, rng, nu, mu, cov, size):
-        # Don't reassign broadcasted cov, since MvNormal expects two dimensional cov only.
-        mu, _ = broadcast_params([mu, cov], cls.ndims_params[1:])
-
-        chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)
-        # Add distribution shape to chi2 samples
-        chi2_samples = chi2_samples.reshape(chi2_samples.shape + (1,) * len(mu.shape))
-
         mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)
 
         size = tuple(size or ())
+
+        # Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
+        chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]
+
         if size:
-            mu = np.broadcast_to(mu, size + mu.shape)
+            mu = np.broadcast_to(mu, size + (mu.shape[-1],))
 
         return (mv_samples / chi2_samples) + mu
@@ -379,7 +374,7 @@ def dist(cls, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None, lower=True
 
     def get_moment(rv, size, nu, mu, cov):
         moment = mu
         if not rv_size_is_none(size):
-            moment_size = at.concatenate([size, moment.shape])
+            moment_size = at.concatenate([size, [mu.shape[-1]]])
             moment = at.full(moment_size, moment)
         return moment
@@ -453,7 +448,7 @@ def get_moment(rv, size, a):
         norm_constant = at.sum(a, axis=-1)[..., None]
         moment = a / norm_constant
         if not rv_size_is_none(size):
-            moment = at.full(at.concatenate([size, a.shape]), moment)
+            moment = at.full(at.concatenate([size, [a.shape[-1]]]), moment)
         return moment
 
     def logp(value, a):
@@ -497,8 +492,8 @@ def rng_fn(cls, rng, n, p, size):
         size = tuple(size or ())
 
         if size:
-            n = np.broadcast_to(n, size + n.shape)
-            p = np.broadcast_to(p, size + p.shape)
+            n = np.broadcast_to(n, size)
+            p = np.broadcast_to(p, size + (p.shape[-1],))
 
         res = np.empty(p.shape)
         for idx in np.ndindex(p.shape[:-1]):
@@ -571,7 +566,7 @@ def get_moment(rv, size, n, p):
         inc_bool_arr = at.abs_(diff) > 0
         mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
         if not rv_size_is_none(size):
-            output_size = at.concatenate([size, p.shape])
+            output_size = at.concatenate([size, [p.shape[-1]]])
             mode = at.full(output_size, mode)
         return mode
@@ -623,8 +618,8 @@ def rng_fn(cls, rng, n, a, size):
         size = tuple(size or ())
 
         if size:
-            n = np.broadcast_to(n, size + n.shape)
-            a = np.broadcast_to(a, size + a.shape)
+            n = np.broadcast_to(n, size)
+            a = np.broadcast_to(a, size + (a.shape[-1],))
 
         res = np.empty(a.shape)
         for idx in np.ndindex(a.shape[:-1]):
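These `rng_fn` changes all follow the Aesara 2.4.0 convention that `size` spells out the full batch shape, including dimensions already implied by the parameters, so count parameters broadcast to `size` and simplex parameters to `size + (k,)`. A plain NumPy sketch of that convention (the parameter values are hypothetical):

    import numpy as np

    rng = np.random.default_rng(0)
    n = np.array([5, 100])     # batch shape (2,)
    p = np.full((2, 4), 0.25)  # batch shape (2,), support shape (4,)

    size = (3, 2)              # must repeat the batch shape (2,)
    n_b = np.broadcast_to(n, size)                   # (3, 2)
    p_b = np.broadcast_to(p, size + (p.shape[-1],))  # (3, 2, 4)

    draws = np.empty(p_b.shape)
    for idx in np.ndindex(p_b.shape[:-1]):
        draws[idx] = rng.multinomial(n_b[idx], p_b[idx])
    assert draws.shape == (3, 2, 4)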
@@ -688,10 +683,14 @@ def get_moment(rv, size, n, a):
         diff = n - at.sum(mode, axis=-1, keepdims=True)
         inc_bool_arr = at.abs_(diff) > 0
         mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
-        # Reshape mode according to base shape (ignoring size)
-        mode = at.reshape(mode, rv.shape[size.size :])
+
+        # Reshape mode according to dimensions implied by the parameters
+        # This can include axes of length 1
+        _, p_bcast = broadcast_params([n, p], ndims_params=[0, 1])
+        mode = at.reshape(mode, p_bcast.shape)
+
         if not rv_size_is_none(size):
-            output_size = at.concatenate([size, mode.shape])
+            output_size = at.concatenate([size, [p.shape[-1]]])
             mode = at.full(output_size, mode)
         return mode
@@ -906,7 +905,7 @@ class WishartRV(RandomVariable):
 
     def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
         # The shape of second parameter `V` defines the shape of the output.
-        return dist_params[1].shape
+        return dist_params[1].shape[-2:]
 
     @classmethod
     def rng_fn(cls, rng, nu, V, size):
@@ -1617,7 +1616,7 @@ class MatrixNormalRV(RandomVariable):
     _print_name = ("MatrixNormal", "\\operatorname{MatrixNormal}")
 
     def _infer_shape(self, size, dist_params, param_shapes=None):
-        shape = tuple(size) + tuple(dist_params[0].shape)
+        shape = tuple(size) + tuple(dist_params[0].shape[-2:])
         return shape
 
     @classmethod
@@ -1747,18 +1746,6 @@ def dist(
 
         cholesky = Cholesky(lower=True, on_error="raise")
 
-        if kwargs.get("size", None) is not None:
-            raise NotImplementedError("MatrixNormal doesn't support size argument")
-
-        if "shape" in kwargs:
-            kwargs.pop("shape")
-            warnings.warn(
-                "The shape argument in MatrixNormal is deprecated and will be ignored."
-                "MatrixNormal automatically derives the shape"
-                "from row and column matrix dimensions.",
-                FutureWarning,
-            )
-
         # Among-row matrices
         if len([i for i in [rowcov, rowchol] if i is not None]) != 1:
             raise ValueError(
@@ -1788,22 +1775,16 @@ def dist(
                 raise ValueError("colchol must be two dimensional.")
             colchol_cov = at.as_tensor_variable(colchol)
 
-        dist_shape = (rowchol_cov.shape[0], colchol_cov.shape[0])
+        dist_shape = (rowchol_cov.shape[-1], colchol_cov.shape[-1])
 
         # Broadcasting mu
         mu = at.extra_ops.broadcast_to(mu, shape=dist_shape)
-
         mu = at.as_tensor_variable(floatX(mu))
-        # mean = median = mode = mu
 
         return super().dist([mu, rowchol_cov, colchol_cov], **kwargs)
 
     def get_moment(rv, size, mu, rowchol, colchol):
-        output_shape = (rowchol.shape[0], colchol.shape[0])
-        if not rv_size_is_none(size):
-            output_shape = at.concatenate([size, output_shape])
-        moment = at.full(output_shape, mu)
-        return moment
+        return at.full_like(rv, mu)
 
     def logp(value, mu, rowchol, colchol):
         """
@@ -1848,14 +1829,11 @@ def logp(value, mu, rowchol, colchol):
 
 class KroneckerNormalRV(RandomVariable):
     name = "kroneckernormal"
-    ndim_supp = 2
+    ndim_supp = 1
     ndims_params = [1, 0, 2]
     dtype = "floatX"
     _print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")
 
-    def _shape_from_params(self, dist_params, rep_param_idx=0, param_shapes=None):
-        return default_shape_from_params(1, dist_params, rep_param_idx, param_shapes)
-
     def rng_fn(self, rng, mu, sigma, *covs, size=None):
         size = size if size else covs[-1]
         covs = covs[:-1] if covs[-1] == size else covs
@@ -1984,7 +1962,6 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs)
 
         mu = at.as_tensor_variable(mu)
 
-        # mean = median = mode = mu
         return super().dist([mu, sigma, *covs], **kwargs)
 
     def get_moment(rv, size, mu, covs, chols, evds):
@@ -2070,7 +2047,7 @@ def make_node(self, rng, size, dtype, mu, W, alpha, tau):
         return super().make_node(rng, size, dtype, mu, W, alpha, tau)
 
     def _infer_shape(self, size, dist_params, param_shapes=None):
-        shape = tuple(size) + tuple(dist_params[0].shape)
+        shape = tuple(size) + (dist_params[0].shape[-1],)
         return shape
 
     @classmethod
@@ -2105,7 +2082,7 @@ def rng_fn(cls, rng: np.random.RandomState, mu, W, alpha, tau, size):
         size = tuple(size or ())
 
         if size:
-            mu = np.broadcast_to(mu, size + mu.shape)
+            mu = np.broadcast_to(mu, size + (mu.shape[-1],))
         z = rng.normal(size=mu.shape)
         samples = np.empty(z.shape)
         for idx in np.ndindex(mu.shape[:-1]):
@@ -2165,11 +2142,7 @@ def dist(cls, mu, W, alpha, tau, *args, **kwargs):
         return super().dist([mu, W, alpha, tau], **kwargs)
 
     def get_moment(rv, size, mu, W, alpha, tau):
-        moment = mu
-        if not rv_size_is_none(size):
-            moment_size = at.concatenate([size, moment.shape])
-            moment = at.full(moment_size, mu)
-        return moment
+        return at.full_like(rv, mu)
 
     def logp(value, mu, W, alpha, tau):
         """
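Dropping the `NotImplementedError` and the `shape` deprecation branch means `MatrixNormal` now takes `size` like any other distribution, with its two support dimensions appended automatically; the updated `sizes_expected` in `test_distributions_random.py` below pins this down. A sketch:

    import numpy as np
    import pymc as pm

    mn = pm.MatrixNormal.dist(
        mu=np.zeros((3, 3)),
        rowcov=np.eye(3),
        colcov=np.eye(3),
        size=(2, 4),  # batch dimensions only; the (3, 3) support is appended
    )
    assert tuple(mn.shape.eval()) == (2, 4, 3, 3)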
diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py
index 7f8ea91cbe..9bcb31a331 100644
--- a/pymc/distributions/shape_utils.py
+++ b/pymc/distributions/shape_utils.py
@@ -18,10 +18,7 @@
 samples from probability distributions for stochastic nodes in PyMC.
 """
 
-import warnings
-
-from functools import partial
-from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union, cast
 
 import numpy as np
 
@@ -29,8 +26,7 @@
 from aesara.tensor.var import TensorVariable
 from typing_extensions import TypeAlias
 
-from pymc.aesaraf import change_rv_size, pandas_to_array
-from pymc.exceptions import ShapeError, ShapeWarning
+from pymc.aesaraf import pandas_to_array
 
 __all__ = [
     "to_tuple",
@@ -524,19 +520,22 @@ def resize_from_dims(dims: WeakDims, ndim_implied: int, model) -> Tuple[StrongSi
         # We don't have a way to know the names of implied
         # dimensions, so they will be `None`.
         dims = (*dims[:-1], *[None] * ndim_implied)
+    sdims = cast(StrongDims, dims)
 
-    ndim_resize = len(dims) - ndim_implied
+    ndim_resize = len(sdims) - ndim_implied
 
     # All resize dims must be known already (numerically or symbolically).
-    unknowndim_resize_dims = set(dims[:ndim_resize]) - set(model.dim_lengths)
+    unknowndim_resize_dims = set(sdims[:ndim_resize]) - set(model.dim_lengths)
     if unknowndim_resize_dims:
         raise KeyError(
             f"Dimensions {unknowndim_resize_dims} are unknown to the model and cannot be used to specify a `size`."
         )
 
     # The numeric/symbolic resize tuple can be created using model.RV_dim_lengths
-    resize_shape = tuple(model.dim_lengths[dname] for dname in dims[:ndim_resize])
-    return resize_shape, dims
+    resize_shape: Tuple[Variable, ...] = tuple(
+        model.dim_lengths[dname] for dname in sdims[:ndim_resize]
+    )
+    return resize_shape, sdims
 
 
 def resize_from_observed(
@@ -565,26 +564,30 @@ def resize_from_observed(
     return resize_shape, observed
 
 
-def find_size(shape=None, size=None, ndim_supp=None):
+def find_size(
+    shape: Optional[WeakShape],
+    size: Optional[StrongSize],
+    ndim_supp: int,
+) -> Tuple[Optional[StrongSize], Optional[int], Optional[int], int]:
     """Determines the size keyword argument for creating a Distribution.
 
     Parameters
    ----------
-    shape : tuple
+    shape
         A tuple specifying the final shape of a distribution
-    size : tuple
+    size
         A tuple specifying the size of a distribution
     ndim_supp : int
         The support dimension of the distribution.
-        0 if a univariate distribution, 1 if a multivariate distribution.
+        0 if a univariate distribution, 1 or higher for multivariate distributions.
 
     Returns
     -------
-    create_size : int
+    create_size : int, optional
         The size argument to be passed to the distribution
-    ndim_expected : int
+    ndim_expected : int, optional
         Number of dimensions expected after distribution was created
-    ndim_batch : int
+    ndim_batch : int, optional
         Number of batch dimensions
     ndim_supp : int
         Number of support dimensions
@@ -613,84 +616,6 @@
     return create_size, ndim_expected, ndim_batch, ndim_supp
 
 
-def maybe_resize(
-    rv_out,
-    rv_op,
-    dist_params,
-    ndim_expected,
-    ndim_batch,
-    ndim_supp,
-    shape,
-    size,
-    *,
-    change_rv_size_fn=partial(change_rv_size, expand=True),
-    **kwargs,
-):
-    """Resize a distribution if necessary.
-
-    Parameters
-    ----------
-    rv_out : RandomVariable
-        The RandomVariable to be resized if necessary
-    rv_op : RandomVariable.__class__
-        The RandomVariable class to recreate it
-    dist_params : dict
-        Input parameters to recreate the RandomVariable
-    ndim_expected : int
-        Number of dimensions expected after distribution was created
-    ndim_batch : int
-        Number of batch dimensions
-    ndim_supp : int
-        The support dimension of the distribution.
-        0 if a univariate distribution, 1 if a multivariate distribution.
-    shape : tuple
-        A tuple specifying the final shape of a distribution
-    size : tuple
-        A tuple specifying the size of a distribution
-    change_rv_size_fn: callable
-        A function that returns an equivalent RV with a different size
-
-    Returns
-    -------
-    rv_out : int
-        The size argument to be passed to the distribution
-    """
-    ndim_actual = rv_out.ndim
-    ndims_unexpected = ndim_actual != ndim_expected
-
-    if shape is not None and ndims_unexpected:
-        if Ellipsis in shape:
-            # Resize and we're done!
-            rv_out = change_rv_size_fn(rv_var=rv_out, new_size=shape[:-1])
-        else:
-            # This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
-            # Recreate the RV without passing `size` to created it with just the implied dimensions.
-            rv_out = rv_op(*dist_params, size=None, **kwargs)
-
-            # Now resize by any remaining "extra" dimensions that were not implied from support and parameters
-            if rv_out.ndim < ndim_expected:
-                expand_shape = shape[: ndim_expected - rv_out.ndim]
-                rv_out = change_rv_size_fn(rv_var=rv_out, new_size=expand_shape)
-            if not rv_out.ndim == ndim_expected:
-                raise ShapeError(
-                    f"Failed to create the RV with the expected dimensionality. "
-                    f"This indicates a severe problem. Please open an issue.",
-                    actual=ndim_actual,
-                    expected=ndim_batch + ndim_supp,
-                )
-
-    # Warn about the edge cases where the RV Op creates more dimensions than
-    # it should based on `size` and `RVOp.ndim_supp`.
-    if size is not None and ndims_unexpected:
-        warnings.warn(
-            f"You may have expected a ({len(tuple(size))}+{ndim_supp})-dimensional RV, but the resulting RV will be {ndim_actual}-dimensional."
-            ' To silence this warning use `warnings.simplefilter("ignore", pm.ShapeWarning)`.',
-            ShapeWarning,
-        )
-
-    return rv_out
-
-
 def rv_size_is_none(size: Variable) -> bool:
     """Check wether an rv size is None (ie., at.Constant([]))"""
     return isinstance(size, Constant) and size.data.size == 0
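`maybe_resize` existed to patch over mismatches between `size`, `ndim_supp`, and batched parameters, and to emit `ShapeWarning` in the ambiguous cases. With Aesara 2.4.0 handling batched parameters natively, `size` behaves exactly like in Aesara/NumPy, so the whole helper can go; the updated `test_mvnormal_shape_size_difference` below captures the new contract:

    import numpy as np
    import pymc as pm

    # Previously warned and produced a (5, 4, 5, 4, 3) RV; now `size` simply
    # names the batch dimensions of the broadcast parameters
    rv = pm.MvNormal.dist(mu=np.ones((5, 4, 3)), cov=np.eye(3), size=(5, 4))
    assert tuple(rv.shape.eval()) == (5, 4, 3)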
diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py
index bad5f444e9..67a9137620 100644
--- a/pymc/tests/test_distributions.py
+++ b/pymc/tests/test_distributions.py
@@ -21,6 +21,7 @@
 import numpy.random as nr
 
 from aeppl.logprob import ParameterValueError
+from aesara.tensor.random.utils import broadcast_params
 
 from pymc.distributions.continuous import get_tau_sigma
 from pymc.util import UNSET
@@ -2122,40 +2123,6 @@ def test_dirichlet_invalid(self):
         valid_dist = Dirichlet.dist(a=[1, 1, 1])
         assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
 
-    @pytest.mark.parametrize(
-        "value,alpha,K,logp",
-        [
-            (np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439),
-            (np.tile(1, 13) / 13, 2, 12, 13.980045245672827),
-            (np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723),
-            (np.append(0.5 ** np.arange(1, 20), 0.5**20), 5, 19, 94.20462772778092),
-            (
-                (np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])),
-                2.5,
-                3,
-                np.array([1.29317672, 1.50126157]),
-            ),
-        ],
-    )
-    def test_stickbreakingweights_logp(self, value, alpha, K, logp):
-        with Model() as model:
-            sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
-        pt = {"sbw": value}
-        assert_almost_equal(
-            pm.logp(sbw, value).eval(),
-            logp,
-            decimal=select_by_precision(float64=6, float32=2),
-            err_msg=str(pt),
-        )
-
-    def test_stickbreakingweights_invalid(self):
-        sbw = pm.StickBreakingWeights.dist(3.0, 3)
-        sbw_wrong_K = pm.StickBreakingWeights.dist(3.0, 7)
-        assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, 0.15])).eval() == -np.inf
-        assert pm.logp(sbw, np.array([1.1, 0.3, 0.2, 0.1])).eval() == -np.inf
-        assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
-        assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf
-
     @pytest.mark.parametrize(
         "a",
         [
@@ -2164,9 +2131,10 @@ def test_stickbreakingweights_invalid(self):
             (np.abs(np.random.randn(2, 2, 4)) + 1),
         ],
     )
-    @pytest.mark.parametrize("size", [2, (1, 2), (2, 4, 3)])
-    def test_dirichlet_vectorized(self, a, size):
+    @pytest.mark.parametrize("extra_size", [(2,), (1, 2), (2, 4, 3)])
+    def test_dirichlet_vectorized(self, a, extra_size):
         a = floatX(np.array(a))
+        size = extra_size + a.shape[:-1]
 
         dir = pm.Dirichlet.dist(a=a, size=size)
         vals = dir.eval()
@@ -2234,12 +2202,15 @@ def test_multinomial_p_not_normalized_symbolic(self):
             (np.abs(np.random.randn(2, 2, 4))),
         ],
     )
-    @pytest.mark.parametrize("size", [1, 2, (2, 3)])
-    def test_multinomial_vectorized(self, n, p, size):
+    @pytest.mark.parametrize("extra_size", [(1,), (2,), (2, 3)])
+    def test_multinomial_vectorized(self, n, p, extra_size):
         n = intX(np.array(n))
         p = floatX(np.array(p))
         p /= p.sum(axis=-1, keepdims=True)
 
+        _, bcast_p = broadcast_params([n, p], ndims_params=[0, 1])
+        size = extra_size + bcast_p.shape[:-1]
+
         mn = pm.Multinomial.dist(n=n, p=p, size=size)
         vals = mn.eval()
 
@@ -2303,11 +2274,14 @@ def test_dirichlet_multinomial_matches_beta_binomial(self):
             (np.abs(np.random.randn(2, 2, 4))),
         ],
     )
-    @pytest.mark.parametrize("size", [1, 2, (2, 3)])
-    def test_dirichlet_multinomial_vectorized(self, n, a, size):
+    @pytest.mark.parametrize("extra_size", [(1,), (2,), (2, 3)])
+    def test_dirichlet_multinomial_vectorized(self, n, a, extra_size):
         n = intX(np.array(n))
         a = floatX(np.array(a))
 
+        _, bcast_a = broadcast_params([n, a], ndims_params=[0, 1])
+        size = extra_size + bcast_a.shape[:-1]
+
         dm = pm.DirichletMultinomial.dist(n=n, a=a, size=size)
         vals = dm.eval()
 
@@ -2318,6 +2292,40 @@ def test_dirichlet_multinomial_vectorized(self, n, a, extra_size):
             err_msg=f"vals={vals}",
         )
 
+    @pytest.mark.parametrize(
+        "value,alpha,K,logp",
+        [
+            (np.array([5, 4, 3, 2, 1]) / 15, 0.5, 4, 1.5126301307277439),
+            (np.tile(1, 13) / 13, 2, 12, 13.980045245672827),
+            (np.array([0.001] * 10 + [0.99]), 0.1, 10, -22.971662448814723),
+            (np.append(0.5 ** np.arange(1, 20), 0.5**20), 5, 19, 94.20462772778092),
+            (
+                (np.array([[7, 5, 3, 2], [19, 17, 13, 11]]) / np.array([[17], [60]])),
+                2.5,
+                3,
+                np.array([1.29317672, 1.50126157]),
+            ),
+        ],
+    )
+    def test_stickbreakingweights_logp(self, value, alpha, K, logp):
+        with Model() as model:
+            sbw = StickBreakingWeights("sbw", alpha=alpha, K=K, transform=None)
+        pt = {"sbw": value}
+        assert_almost_equal(
+            pm.logp(sbw, value).eval(),
+            logp,
+            decimal=select_by_precision(float64=6, float32=2),
+            err_msg=str(pt),
+        )
+
+    def test_stickbreakingweights_invalid(self):
+        sbw = pm.StickBreakingWeights.dist(3.0, 3)
+        sbw_wrong_K = pm.StickBreakingWeights.dist(3.0, 7)
+        assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, 0.15])).eval() == -np.inf
+        assert pm.logp(sbw, np.array([1.1, 0.3, 0.2, 0.1])).eval() == -np.inf
+        assert pm.logp(sbw, np.array([0.4, 0.3, 0.2, -0.1])).eval() == -np.inf
+        assert pm.logp(sbw_wrong_K, np.array([0.4, 0.3, 0.2, 0.1])).eval() == -np.inf
+
     @aesara.config.change_flags(compute_test_value="raise")
     def test_categorical_bounds(self):
         with Model():
@@ -3349,6 +3357,15 @@ def test_censored_invalid_dist(self):
         ):
             x = pm.Censored("x", registered_dist, lower=None, upper=None)
 
+    def test_change_size(self):
+        base_dist = pm.Censored.dist(pm.Normal.dist(), -1, 1, size=(3, 2))
+
+        new_dist = pm.Censored.change_size(base_dist, (4,))
+        assert new_dist.eval().shape == (4,)
+
+        new_dist = pm.Censored.change_size(base_dist, (4,), expand=True)
+        assert new_dist.eval().shape == (4, 3, 2)
+
 
 class TestLKJCholeskCov:
     def test_dist(self):
diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py
index fd38e43b0f..f9def30e9c 100644
--- a/pymc/tests/test_distributions_moments.py
+++ b/pymc/tests/test_distributions_moments.py
@@ -161,7 +161,8 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):
     except NotImplementedError:
         random_draw = moment
 
-    assert moment.shape == expected.shape == random_draw.shape
+    assert moment.shape == expected.shape
+    assert expected.shape == random_draw.shape
     assert np.allclose(moment, expected)
 
     if check_finite_logp:
@@ -788,7 +789,7 @@ def test_discrete_weibull_moment(q, beta, size, expected):
         ),
         (
             np.array([[1, 2, 3], [5, 6, 7]]),
-            7,
+            (7, 2),
             np.apply_along_axis(
                 lambda x: np.divide(x, np.array([6, 18])),
                 1,
@@ -797,7 +798,7 @@
         ),
         (
             np.full(shape=np.array([7, 3]), fill_value=np.array([13, 17, 19])),
-            (11, 5),
+            (11, 5, 7),
             np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]) / 49,
         ),
     ],
@@ -939,7 +940,7 @@ def test_interpolated_moment(x_points, pdf_points, size, expected):
         (
             np.array([[3.0, 5], [1, 4]]),
             np.identity(2),
-            (4, 5),
+            (4, 5, 2),
             np.full((4, 5, 2, 2), [[3.0, 5], [1, 4]]),
         ),
     ],
 )
@@ -964,7 +965,7 @@ def test_mv_normal_moment(mu, cov, size, expected):
         (np.array([1, 0, 3.0, 4]), (5, 3), np.full((5, 3, 4), [1, 0, 3.0, 4])),
         (
             np.array([[3.0, 5, 2, 1], [1, 4, 0.5, 9]]),
-            (4, 5),
+            (4, 5, 2),
             np.full((4, 5, 2, 4), [[3.0, 5, 2, 1], [1, 4, 0.5, 9]]),
         ),
     ],
@@ -1007,8 +1008,8 @@ def test_moyal_moment(mu, sigma, size, expected):
         (2, rand1d, np.eye(2), 2, np.full((2, 2), rand1d)),
         (2, rand1d, np.eye(2), (2, 5), np.full((2, 5, 2), rand1d)),
         (2, rand2d, np.eye(3), None, rand2d),
-        (2, rand2d, np.eye(3), 2, np.full((2, 2, 3), rand2d)),
-        (2, rand2d, np.eye(3), (2, 5), np.full((2, 5, 2, 3), rand2d)),
+        (2, rand2d, np.eye(3), (2, 2), np.full((2, 2, 3), rand2d)),
+        (2, rand2d, np.eye(3), (2, 5, 2), np.full((2, 5, 2, 3), rand2d)),
     ],
 )
 def test_mvstudentt_moment(nu, mu, cov, size, expected):
@@ -1019,11 +1020,6 @@ def test_mvstudentt_moment(nu, mu, cov, size, expected):
     assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
 
 
-def check_matrixnormal_moment(mu, rowchol, colchol, size, expected):
-    with Model() as model:
-        MatrixNormal("x", mu=mu, rowchol=rowchol, colchol=colchol, size=size)
-
-
 @pytest.mark.parametrize(
     "alpha, mu, sigma, size, expected",
     [
@@ -1091,11 +1087,12 @@ def test_asymmetriclaplace_moment(b, kappa, mu, size, expected):
     ],
 )
 def test_matrixnormal_moment(mu, rowchol, colchol, size, expected):
-    if size is None:
-        check_matrixnormal_moment(mu, rowchol, colchol, size, expected)
-    else:
-        with pytest.raises(NotImplementedError):
-            check_matrixnormal_moment(mu, rowchol, colchol, size, expected)
+    with Model() as model:
+        x = MatrixNormal("x", mu=mu, rowchol=rowchol, colchol=colchol, size=size)
+
+    # MatrixNormal logp is only implemented for 2d values
+    check_logp = x.ndim == 2
+    assert_moment_is_expected(model, expected, check_finite_logp=check_logp)
@@ -1325,7 +1322,7 @@ def test_polyagamma_moment(h, z, size, expected):
         (
             np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
             np.array([1, 10]),
-            2,
+            (2, 2),
             np.full((2, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
         ),
     ],
@@ -1418,7 +1415,7 @@ def normal_sim(rng, mu, sigma, size):
         ),
     ],
 )
-def test_kronecker_normal_moments(mu, covs, size, expected):
+def test_kronecker_normal_moment(mu, covs, size, expected):
     with Model() as model:
         KroneckerNormal("x", mu=mu, covs=covs, size=size)
     assert_moment_is_expected(model, expected)
@@ -1475,7 +1472,7 @@ def test_lkjcholeskycov_moment(n, eta, size, expected):
         (
             np.array([[26, 26, 26, 22]]),  # Dim: 1 x 4
             np.array([[1], [10]]),  # Dim: 2 x 1
-            (2, 1),
+            (2, 1, 2, 1),
             np.full(
                 (2, 1, 2, 1, 4),
                 np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]),  # Dim: 2 x 1 x 4
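The moment tests change for the same reason: a `size` passed to a multivariate distribution must now include the batch dimensions implied by its parameters, and the moment is the parameter broadcast up to `size + (support,)`. For example, using values from the updated `MvNormal` parametrization above:

    import numpy as np
    import pymc as pm

    mu = np.array([[3.0, 5], [1, 4]])  # batch (2,), support (2,)
    # size repeats the batch dimension implied by mu
    x = pm.MvNormal.dist(mu=mu, cov=np.identity(2), size=(4, 5, 2))
    assert tuple(x.shape.eval()) == (4, 5, 2, 2)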
diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py
index 51207ae0c8..0d9177f8f3 100644
--- a/pymc/tests/test_distributions_random.py
+++ b/pymc/tests/test_distributions_random.py
@@ -375,7 +375,8 @@ def check_rv_size(self):
             pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
             expected_symbolic = tuple(pymc_rv.shape.eval())
             actual = pymc_rv.eval().shape
-            assert actual == expected_symbolic == expected
+            assert actual == expected_symbolic
+            assert expected_symbolic == expected
 
         # test multi-parameters sampling for univariate distributions (with univariate inputs)
         if (
@@ -1196,8 +1197,8 @@ def test_issue_3706(self):
 
 class TestMvStudentTCov(BaseTestDistributionRandom):
     def mvstudentt_rng_fn(self, size, nu, mu, cov, rng):
-        chi2_samples = rng.chisquare(nu, size=size)
         mv_samples = rng.multivariate_normal(np.zeros_like(mu), cov, size=size)
+        chi2_samples = rng.chisquare(nu, size=size)
         return (mv_samples / np.sqrt(chi2_samples[:, None] / nu)) + mu
 
     pymc_dist = pm.MvStudentT
@@ -1309,41 +1310,6 @@ class TestDirichlet(BaseTestDistributionRandom):
     ]
 
 
-class TestStickBreakingWeights(BaseTestDistributionRandom):
-    pymc_dist = pm.StickBreakingWeights
-    pymc_dist_params = {"alpha": 2.0, "K": 19}
-    expected_rv_op_params = {"alpha": 2.0, "K": 19}
-    sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
-    sizes_expected = [
-        (20,),
-        (17, 20),
-        (
-            5,
-            20,
-        ),
-        (11, 5, 20),
-        (3, 13, 5, 20),
-    ]
-    checks_to_run = [
-        "check_pymc_params_match_rv_op",
-        "check_rv_size",
-        "check_basic_properties",
-    ]
-
-    def check_basic_properties(self):
-        default_rng = aesara.shared(np.random.default_rng(1234))
-        draws = pm.StickBreakingWeights.dist(
-            alpha=3.5,
-            K=19,
-            size=(2, 3, 5),
-            rng=default_rng,
-        ).eval()
-
-        assert np.allclose(draws.sum(-1), 1)
-        assert np.all(draws >= 0)
-        assert np.all(draws <= 1)
-
-
 class TestMultinomial(BaseTestDistributionRandom):
     pymc_dist = pm.Multinomial
     pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}
@@ -1379,7 +1345,7 @@ def check_random_draws(self):
         draws = pm.DirichletMultinomial.dist(
             n=np.array([5, 100]),
             a=np.array([[0.001, 0.001, 0.001, 1000], [1000, 1000, 0.001, 0.001]]),
-            size=(2, 3),
+            size=(2, 3, 2),
             rng=default_rng,
         ).eval()
         assert np.all(draws.sum(-1) == np.array([5, 100]))
@@ -1395,11 +1361,46 @@ class TestDirichletMultinomial_1D_n_2D_a(BaseTestDistributionRandom):
         "n": np.array([23, 29]),
         "a": np.array([[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]),
     }
-    sizes_to_check = [None, 1, (4,), (3, 4)]
+    sizes_to_check = [None, (1, 2), (4, 2), (3, 4, 2)]
     sizes_expected = [(2, 4), (1, 2, 4), (4, 2, 4), (3, 4, 2, 4)]
     checks_to_run = ["check_rv_size"]
 
 
+class TestStickBreakingWeights(BaseTestDistributionRandom):
+    pymc_dist = pm.StickBreakingWeights
+    pymc_dist_params = {"alpha": 2.0, "K": 19}
+    expected_rv_op_params = {"alpha": 2.0, "K": 19}
+    sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
+    sizes_expected = [
+        (20,),
+        (17, 20),
+        (
+            5,
+            20,
+        ),
+        (11, 5, 20),
+        (3, 13, 5, 20),
+    ]
+    checks_to_run = [
+        "check_pymc_params_match_rv_op",
+        "check_rv_size",
+        "check_basic_properties",
+    ]
+
+    def check_basic_properties(self):
+        default_rng = aesara.shared(np.random.default_rng(1234))
+        draws = pm.StickBreakingWeights.dist(
+            alpha=3.5,
+            K=19,
+            size=(2, 3, 5),
+            rng=default_rng,
+        ).eval()
+
+        assert np.allclose(draws.sum(-1), 1)
+        assert np.all(draws >= 0)
+        assert np.all(draws <= 1)
+
+
 class TestCategorical(BaseTestDistributionRandom):
     pymc_dist = pm.Categorical
     pymc_dist_params = {"p": np.array([0.28, 0.62, 0.10])}
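The rewritten `sizes_to_check`/`sizes_expected` pairs all follow the same pattern: a valid `size` must end with the parameters' batch shape. A sketch using the `TestDirichletMultinomial_1D_n_2D_a` values above:

    import numpy as np
    import pymc as pm

    n = np.array([23, 29])
    a = np.full((2, 4), 0.25)

    # `size` must terminate with the parameters' batch shape (2,)
    dm = pm.DirichletMultinomial.dist(n=n, a=a, size=(3, 4, 2))
    assert tuple(dm.shape.eval()) == (3, 4, 2, 4)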
@@ -1783,8 +1784,25 @@ def wishart_rng_fn(self, size, nu, V, rng):
         "check_rv_size",
         "check_pymc_params_match_rv_op",
         "check_pymc_draws_match_reference",
+        "check_rv_size_batched_params",
     ]
 
+    def check_rv_size_batched_params(self):
+        for size in (None, (2,), (1, 2), (4, 3, 2)):
+            x = pm.Wishart.dist(nu=4, V=np.stack([np.eye(3), np.eye(3)]), size=size)
+
+            if size is None:
+                expected_shape = (2, 3, 3)
+            else:
+                expected_shape = size + (3, 3)
+
+            assert tuple(x.shape.eval()) == expected_shape
+
+            # RNG does not currently support batched parameters; when it does, this test
+            # should be updated to check that draws also have the expected shape
+            with pytest.raises(ValueError):
+                x.eval()
+
 
 class TestMatrixNormal(BaseTestDistributionRandom):
@@ -1793,13 +1811,15 @@ class TestMatrixNormal(BaseTestDistributionRandom):
 
     mu = np.random.random((3, 3))
     row_cov = np.eye(3)
     col_cov = np.eye(3)
-    shape = None
-    size = None
     pymc_dist_params = {"mu": mu, "rowcov": row_cov, "colcov": col_cov}
     expected_rv_op_params = {"mu": mu, "rowcov": row_cov, "colcov": col_cov}
+    sizes_to_check = (None, (1,), (2, 4))
+    sizes_expected = [(3, 3), (1, 3, 3), (2, 4, 3, 3)]
+
     checks_to_run = [
         "check_pymc_params_match_rv_op",
+        "check_rv_size",
         "check_draws",
         "check_errors",
         "check_random_variable_prior",
@@ -1840,17 +1860,6 @@ def ref_rand(mu, rowcov, colcov):
         assert p > delta
 
     def check_errors(self):
-        msg = "MatrixNormal doesn't support size argument"
-        with pm.Model():
-            with pytest.raises(NotImplementedError, match=msg):
-                matrixnormal = pm.MatrixNormal(
-                    "matnormal",
-                    mu=np.random.random((3, 3)),
-                    rowcov=np.eye(3),
-                    colcov=np.eye(3),
-                    size=15,
-                )
-
         with pm.Model():
             matrixnormal = pm.MatrixNormal(
                 "matnormal",
@@ -1861,16 +1870,6 @@ def check_errors(self):
             with pytest.raises(ValueError):
                 logp(matrixnormal, aesara.tensor.ones((3, 3, 3)))
 
-        with pm.Model():
-            with pytest.warns(FutureWarning):
-                matrixnormal = pm.MatrixNormal(
-                    "matnormal",
-                    mu=np.random.random((3, 3)),
-                    rowcov=np.eye(3),
-                    colcov=np.eye(3),
-                    shape=15,
-                )
-
     def check_random_variable_prior(self):
         """
         This test checks for shape correctness when using MatrixNormal distribution
diff --git a/pymc/tests/test_shape_handling.py b/pymc/tests/test_shape_handling.py
index edcdb7af08..5cfb481f20 100644
--- a/pymc/tests/test_shape_handling.py
+++ b/pymc/tests/test_shape_handling.py
@@ -31,7 +31,6 @@
     shapes_broadcasting,
     to_tuple,
 )
-from pymc.exceptions import ShapeWarning
 
 test_shapes = [
     (tuple(), (1,), (4,), (5, 4)),
@@ -323,7 +322,7 @@ def test_simultaneous_size_and_dims(self, with_dims_ellipsis):
         assert "ddata" in pmodel.dim_lengths
 
         # Size does not include support dims, so this test must use a dist with support dims.
-        kwargs = dict(name="y", size=2, mu=at.ones((3, 4)), cov=at.eye(4))
+        kwargs = dict(name="y", size=(2, 3), mu=at.ones((3, 4)), cov=at.eye(4))
         if with_dims_ellipsis:
             y = pm.MvNormal(**kwargs, dims=("dsize", ...))
             assert pmodel.RV_dims["y"] == ("dsize", None, None)
@@ -434,17 +433,11 @@ def test_mvnormal_shape_size_difference(self):
         assert rv.ndim == 5
         assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)
 
-        with pytest.warns(None):
-            rv = pm.MvNormal.dist(mu=[1, 2, 3], cov=np.eye(3), size=(5, 4))
-            assert tuple(rv.shape.eval()) == (5, 4, 3)
+        rv = pm.MvNormal.dist(mu=[1, 2, 3], cov=np.eye(3), size=(5, 4))
+        assert tuple(rv.shape.eval()) == (5, 4, 3)
 
-        # When using `size` the API behaves like Aesara/NumPy
-        with pytest.warns(
-            ShapeWarning,
-            match=r"You may have expected a \(2\+1\)-dimensional RV, but the resulting RV will be 5-dimensional",
-        ):
-            rv = pm.MvNormal.dist(mu=np.ones((5, 4, 3)), cov=np.eye(3), size=(5, 4))
-            assert tuple(rv.shape.eval()) == (5, 4, 5, 4, 3)
+        rv = pm.MvNormal.dist(mu=np.ones((5, 4, 3)), cov=np.eye(3), size=(5, 4))
+        assert tuple(rv.shape.eval()) == (5, 4, 3)
 
     def test_convert_dims(self):
         assert convert_dims(dims="town") == ("town",)
diff --git a/pymc/util.py b/pymc/util.py
index f218368854..0e2765b61d 100644
--- a/pymc/util.py
+++ b/pymc/util.py
@@ -14,7 +14,7 @@
 
 import functools
 
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, cast
 
 import arviz
 import cloudpickle
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 55568d4a27..904ecaf37e 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,7 +2,7 @@
 # See that file for comments about the need/usage of each dependency.
 
 aeppl==0.0.26
-aesara==2.3.8
+aesara==2.4.0
 arviz>=0.11.4
 cachetools>=4.2.1
 cloudpickle
@@ -18,7 +18,7 @@ pre-commit>=2.8.0
 pydata-sphinx-theme
 pytest-cov>=2.5
 pytest>=3.0
-scipy>=1.4.1,<1.8.0
+scipy>=1.4.1
 sphinx-copybutton
 sphinx-design
 sphinx-notfound-page
diff --git a/requirements.txt b/requirements.txt
index 534d3270a6..59ffb7a331 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
 aeppl==0.0.26
-aesara==2.3.8
+aesara==2.4.0
 arviz>=0.11.4
 cachetools>=4.2.1
 cloudpickle
 fastprogress>=0.2.0
 numpy>=1.15.0
 pandas>=0.24.0
-scipy>=1.4.1,<1.8.0
+scipy>=1.4.1
 typing-extensions>=3.7.4