Skip to content

Upgrade Aesara version pin and unpin SciPy upper limit #5474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 27, 2022
4 changes: 2 additions & 2 deletions conda-envs/environment-dev-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.26
- aesara=2.3.8
- aesara=2.4.0
- arviz>=0.11.4
- blas
- cachetools>=4.2.1
Expand All @@ -24,7 +24,7 @@ dependencies:
- pytest>=3.0
- python-graphviz
- python=3.7
- scipy>=1.4.1,<1.8.0
- scipy>=1.4.1
- sphinx-copybutton
- sphinx-notfound-page
- sphinx>=1.5
Expand Down
4 changes: 2 additions & 2 deletions conda-envs/environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.26
- aesara=2.3.8
- aesara=2.4.0
- arviz>=0.11.4
- blas
- cachetools>=4.2.1
Expand All @@ -24,7 +24,7 @@ dependencies:
- pytest>=3.0
- python-graphviz
- python=3.8
- scipy>=1.4.1,<1.8.0
- scipy>=1.4.1
- sphinx-copybutton
- sphinx-notfound-page
- sphinx>=1.5
Expand Down
4 changes: 2 additions & 2 deletions conda-envs/environment-dev-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.26
- aesara=2.3.8
- aesara=2.4.0
- arviz>=0.11.4
- blas
- cachetools>=4.2.1
Expand All @@ -24,7 +24,7 @@ dependencies:
- pytest>=3.0
- python-graphviz
- python=3.9
- scipy>=1.4.1,<1.8.0
- scipy>=1.4.1
- sphinx-copybutton
- sphinx-notfound-page
- sphinx>=1.5
Expand Down
4 changes: 2 additions & 2 deletions conda-envs/environment-test-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.26
- aesara=2.3.8
- aesara=2.4.0
- arviz>=0.11.4
- blas
- cachetools>=4.2.1
Expand All @@ -22,5 +22,5 @@ dependencies:
- pytest>=3.0
- python-graphviz
- python=3.7
- scipy>=1.4.1,<1.8.0
- scipy>=1.4.1
- typing-extensions>=3.7.4
4 changes: 2 additions & 2 deletions conda-envs/environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.26
- aesara=2.3.8
- aesara=2.4.0
- arviz>=0.11.4
- blas
- cachetools>=4.2.1
Expand All @@ -22,5 +22,5 @@ dependencies:
- pytest>=3.0
- python-graphviz
- python=3.8
- scipy>=1.4.1,<1.8.0
- scipy>=1.4.1
- typing-extensions>=3.7.4
4 changes: 2 additions & 2 deletions conda-envs/environment-test-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- defaults
dependencies:
- aeppl=0.0.26
- aesara=2.3.8
- aesara=2.4.0
- arviz>=0.11.4
- blas
- cachetools>=4.2.1
Expand All @@ -22,5 +22,5 @@ dependencies:
- pytest>=3.0
- python-graphviz
- python=3.9
- scipy>=1.4.1,<1.8.0
- scipy>=1.4.1
- typing-extensions>=3.7.4
4 changes: 2 additions & 2 deletions conda-envs/windows-environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
dependencies:
# base dependencies (see install guide for Windows)
- aeppl=0.0.26
- aesara=2.3.8
- aesara=2.4.0
- arviz>=0.11.4
- blas
- cachetools>=4.2.1
Expand All @@ -17,7 +17,7 @@ dependencies:
- pip
- python=3.8
- python-graphviz
- scipy>=1.4.1,<1.8.0
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra stuff for dev, testing and docs build
- ipython>=7.16
Expand Down
4 changes: 2 additions & 2 deletions conda-envs/windows-environment-test-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
dependencies:
# base dependencies (see install guide for Windows)
- aeppl=0.0.26
- aesara=2.3.8
- aesara=2.4.0
- arviz>=0.11.4
- blas
- cachetools>=4.2.1
Expand All @@ -21,7 +21,7 @@ dependencies:
- pip
- python=3.8
- python-graphviz
- scipy>=1.4.1,<1.8.0
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra stuff for testing
- ipython>=7.16
Expand Down
6 changes: 3 additions & 3 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def change_rv_size(
tag = rv_var.tag

if expand:
if rv_node.op.ndim_supp == 0 and at.get_vector_length(size) == 0:
size = rv_node.op._infer_shape(size, dist_params)
new_size = tuple(new_size) + tuple(size)
old_shape = tuple(rv_node.op._infer_shape(size, dist_params))
old_size = old_shape[: len(old_shape) - rv_node.op.ndim_supp]
new_size = tuple(new_size) + tuple(old_size)

# Make sure the new size is a tensor. This dtype-aware conversion helps
# to not unnecessarily pick up a `Cast` in some cases (see #4652).
Expand Down
3 changes: 2 additions & 1 deletion pymc/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ def ndim_supp(cls, *dist_params):
return 0

@classmethod
def change_size(cls, rv, new_size):
def change_size(cls, rv, new_size, expand=False):
dist_node = rv.tag.dist.owner
lower = rv.tag.lower
upper = rv.tag.upper
rng, old_size, dtype, *dist_params = dist_node.inputs
new_size = new_size if not expand else tuple(new_size) + tuple(old_size)
new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
return cls.rv_op(new_dist, lower, upper)

Expand Down
36 changes: 11 additions & 25 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from abc import ABCMeta
from functools import singledispatch
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union, cast

import aesara
import numpy as np
Expand All @@ -45,7 +45,6 @@
convert_shape,
convert_size,
find_size,
maybe_resize,
resize_from_dims,
resize_from_observed,
)
Expand Down Expand Up @@ -353,17 +352,11 @@ def dist(
# Create the RV with a `size` right away.
# This is not necessarily the final result.
rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
rv_out = maybe_resize(
rv_out,
cls.rv_op,
dist_params,
ndim_expected,
ndim_batch,
ndim_supp,
shape,
size,
**kwargs,
)

# Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
if shape is not None and Ellipsis in shape:
replicate_shape = cast(StrongShape, shape[:-1])
rv_out = change_rv_size(rv_var=rv_out, new_size=replicate_shape, expand=True)

rng = kwargs.pop("rng", None)
if (
Expand Down Expand Up @@ -589,18 +582,11 @@ def dist(
# Create the RV with a `size` right away.
# This is not necessarily the final result.
graph = cls.rv_op(*dist_params, size=create_size, **kwargs)
graph = maybe_resize(
graph,
cls.rv_op,
dist_params,
ndim_expected,
ndim_batch,
ndim_supp,
shape,
size,
change_rv_size_fn=cls.change_size,
**kwargs,
)

# Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
if shape is not None and Ellipsis in shape:
replicate_shape = cast(StrongShape, shape[:-1])
graph = cls.change_size(rv=graph, new_size=replicate_shape, expand=True)

rngs = kwargs.pop("rngs", None)
if rngs is not None:
Expand Down
Loading