Skip to content

Update more samplers for v4 compatibility #4559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
21ead5e
Make Metropolis, Slice, PGBART, MetropolisMLDA use point values
brandonwillard Mar 23, 2021
8e3f4f8
Re-enable disabled tests
brandonwillard Mar 23, 2021
7b417d5
Use value vars to determine steps
brandonwillard Mar 23, 2021
e7d9051
Use size instead of shape in pymc3.tests.sampler_fixtures
brandonwillard Mar 23, 2021
87ff067
Set model-level RandomVariable seeds during sampling
brandonwillard Mar 23, 2021
0374530
Check shapes by evaluating graph with start values
brandonwillard Mar 24, 2021
4c0f9e8
Fix logpt so that transforms are always applied, when enabled
brandonwillard Mar 24, 2021
d5bcd33
Set default transform for Dirichlet
brandonwillard Mar 24, 2021
3963278
Normalize Multinomial argument
brandonwillard Mar 24, 2021
2f68e4a
Use no_transform_object in Distribution.__new__
brandonwillard Mar 24, 2021
fc123ed
Fix Interval.jacobian_det
brandonwillard Mar 24, 2021
22d36df
Fix Stickbreaking scalar condition
brandonwillard Mar 24, 2021
5ca03b9
Make Model.test_point generation transform existing RV test value
brandonwillard Mar 24, 2021
4bc94ec
Add transformed value variables to Model.named_vars
brandonwillard Mar 24, 2021
9106d94
Remove DeterministicWrapper from Deterministic
brandonwillard Mar 24, 2021
04dade5
Make sure sample_posterior_predictive doesn't use trace values for sa…
brandonwillard Mar 24, 2021
fe384ec
Set seed after loading trace in TestSaveLoad
brandonwillard Mar 24, 2021
1899b35
Make v4 compatibility changes to pymc3.tests.test_sampling
brandonwillard Mar 24, 2021
ffe978a
Make pymc3.tests.test_transforms work with None RV variables
brandonwillard Mar 24, 2021
2cbd6dc
Make find_MAP work with RaveledVars
brandonwillard Mar 24, 2021
fef06a2
Make sure start values are NumPy arrays
brandonwillard Mar 24, 2021
c65baf7
Raise NotImplementedError in Group.__init__
brandonwillard Mar 24, 2021
1c68482
Add type hints to astep methods
brandonwillard Mar 24, 2021
9f8d459
Use untransformed samples and xfail Arviz tests in BaseSampler
brandonwillard Mar 24, 2021
8499a7d
Make sure forward transformed input is a TensorVariable in TestMatche…
brandonwillard Mar 24, 2021
6ceebdf
Adjust flaky last-digit numerical requirements in TestMatchesScipy
brandonwillard Mar 24, 2021
e65aad2
Fix MvNormal quaddist_matrix parameter order
brandonwillard Mar 25, 2021
bf6cce0
Enable MvNormal tests in test_distributions
brandonwillard Mar 25, 2021
05c40b7
Factor out parameter pre-processing in TestMatchesScipy
brandonwillard Mar 25, 2021
1414e97
Apply recent xfail updates from master branch
brandonwillard Mar 25, 2021
8654f7e
Prevent dtype conversion in Aesara during testing to avoid a bug
brandonwillard Mar 25, 2021
a61c937
Fix NegativeBinomial parameterization and enable its tests
brandonwillard Mar 25, 2021
336be93
Prevent SciPy error by using float64 point in test_dirichlet_with_bat…
brandonwillard Mar 25, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,41 +27,32 @@ jobs:
# 6th block: These have some XFAILs
- |
--ignore=pymc3/tests/test_distribution_defaults.py
--ignore=pymc3/tests/test_distributions.py
--ignore=pymc3/tests/test_distributions_random.py
--ignore=pymc3/tests/test_distributions_timeseries.py
--ignore=pymc3/tests/test_missing.py
--ignore=pymc3/tests/test_mixture.py
--ignore=pymc3/tests/test_model_graph.py
--ignore=pymc3/tests/test_modelcontext.py
--ignore=pymc3/tests/test_models_linear.py
--ignore=pymc3/tests/test_ndarray_backend.py
--ignore=pymc3/tests/test_parallel_sampling.py
--ignore=pymc3/tests/test_posterior_predictive.py
--ignore=pymc3/tests/test_posteriors.py
--ignore=pymc3/tests/test_profile.py
--ignore=pymc3/tests/test_random.py
--ignore=pymc3/tests/test_sampling.py
--ignore=pymc3/tests/test_shared.py
--ignore=pymc3/tests/test_smc.py
--ignore=pymc3/tests/test_starting.py
--ignore=pymc3/tests/test_step.py
--ignore=pymc3/tests/test_tracetab.py
--ignore=pymc3/tests/test_transforms.py
--ignore=pymc3/tests/test_tuning.py
--ignore=pymc3/tests/test_types.py
--ignore=pymc3/tests/test_util.py
--ignore=pymc3/tests/test_variational_inference.py

--ignore=pymc3/tests/test_sampling_jax.py

--ignore=pymc3/tests/test_dist_math.py
--ignore=pymc3/tests/test_minibatches.py
--ignore=pymc3/tests/test_pickling.py
--ignore=pymc3/tests/test_plots.py
--ignore=pymc3/tests/test_special_functions.py
--ignore=pymc3/tests/test_updates.py

--ignore=pymc3/tests/test_dist_math.py
--ignore=pymc3/tests/test_examples.py
--ignore=pymc3/tests/test_glm.py
Expand Down
30 changes: 21 additions & 9 deletions pymc3/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List

import aesara
import numpy as np
Expand Down Expand Up @@ -222,7 +223,7 @@ def __hash__(self):
return hash(type(self))


def make_shared_replacements(vars, model):
def make_shared_replacements(point, vars, model):
"""
Makes shared replacements for all *other* variables than the ones passed.

Expand All @@ -231,6 +232,7 @@ def make_shared_replacements(vars, model):

Parameters
----------
point: dictionary mapping variable names to sample values
vars: list of variables not to make shared
model: model

Expand All @@ -239,15 +241,22 @@ def make_shared_replacements(vars, model):
Dict of variable -> new shared variable
"""
othervars = set(model.vars) - set(vars)
return {var: aesara.shared(var.tag.test_value, var.name + "_shared") for var in othervars}
return {var: aesara.shared(point[var.name], var.name + "_shared") for var in othervars}


def join_nonshared_inputs(xs, vars, shared, make_shared=False):
def join_nonshared_inputs(
point: Dict[str, np.ndarray],
xs: List[TensorVariable],
vars: List[TensorVariable],
shared,
make_shared: bool = False,
):
"""
Takes a list of aesara Variables and joins their non shared inputs into a single input.

Parameters
----------
point: a sample point
xs: list of aesara tensors
vars: list of variables to join

Expand All @@ -266,17 +275,20 @@ def join_nonshared_inputs(xs, vars, shared, make_shared=False):
tensor_type = joined.type
inarray = tensor_type("inarray")
else:
inarray = aesara.shared(joined.tag.test_value, "inarray")
if point is None:
raise ValueError("A point is required when `make_shared` is True")
joined_values = np.concatenate([point[var.name].ravel() for var in vars])
inarray = aesara.shared(joined_values, "inarray")

inarray.tag.test_value = joined.tag.test_value
if aesara.config.compute_test_value != "off":
inarray.tag.test_value = joined.tag.test_value

replace = {}
last_idx = 0
for var in vars:
arr_len = aet.prod(var.shape)
replace[var] = reshape_t(inarray[last_idx : last_idx + arr_len], var.shape).astype(
var.dtype
)
shape = point[var.name].shape
arr_len = np.prod(shape, dtype=int)
replace[var] = reshape_t(inarray[last_idx : last_idx + arr_len], shape).astype(var.dtype)
last_idx += arr_len

replace.update(shared)
Expand Down
20 changes: 11 additions & 9 deletions pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
int, np.ndarray, Tuple[Union[int, Variable], ...], List[Union[int, Variable]], Variable
]

no_transform_object = object()


@singledispatch
def logp_transform(op: Op):
Expand Down Expand Up @@ -340,17 +342,17 @@ def logpt(
else:
logp_var = _logcdf(rv_node.op, rv_var, *dist_params, **kwargs)

transform = getattr(rv_value_var.tag, "transform", None) if rv_value_var else None

if transform and transformed and not cdf:
if transformed and not cdf:
(logp_var,), _ = apply_transforms((logp_var,))

if jacobian:
transformed_jacobian = transform.jacobian_det(rv_var, rv_value)
if transformed_jacobian:
if logp_var.ndim > transformed_jacobian.ndim:
logp_var = logp_var.sum(axis=-1)
logp_var += transformed_jacobian
transform = getattr(rv_value_var.tag, "transform", None) if rv_value_var else None

if transform and transformed and not cdf and jacobian:
transformed_jacobian = transform.jacobian_det(rv_var, rv_value)
if transformed_jacobian:
if logp_var.ndim > transformed_jacobian.ndim:
logp_var = logp_var.sum(axis=-1)
logp_var += transformed_jacobian

# Replace random variables with their value variables
(logp_var,), replaced = rvs_to_value_vars((logp_var,), {rv_var: rv_value})
Expand Down
48 changes: 21 additions & 27 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,35 +727,33 @@ def NegBinom(a, m, x):

@classmethod
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
mu, alpha = cls.get_mu_alpha(mu, alpha, p, n)
mu = aet.as_tensor_variable(floatX(mu))
alpha = aet.as_tensor_variable(floatX(alpha))
# mode = intX(aet.floor(mu))
return super().dist([mu, alpha], *args, **kwargs)
n, p = cls.get_mu_alpha(mu, alpha, p, n)
n = aet.as_tensor_variable(floatX(n))
p = aet.as_tensor_variable(floatX(p))
return super().dist([n, p], *args, **kwargs)

@classmethod
def get_mu_alpha(cls, mu=None, alpha=None, p=None, n=None):
if alpha is None:
if n is not None:
n = aet.as_tensor_variable(intX(n))
alpha = n
if n is None:
if alpha is not None:
n = aet.as_tensor_variable(floatX(alpha))
else:
raise ValueError("Incompatible parametrization. Must specify either alpha or n.")
elif n is not None:
elif alpha is not None:
raise ValueError("Incompatible parametrization. Can't specify both alpha and n.")

if mu is None:
if p is not None:
p = aet.as_tensor_variable(floatX(p))
mu = alpha * (1 - p) / p
if p is None:
if mu is not None:
mu = aet.as_tensor_variable(floatX(mu))
p = n / (mu + n)
else:
raise ValueError("Incompatible parametrization. Must specify either mu or p.")
elif p is not None:
elif mu is not None:
raise ValueError("Incompatible parametrization. Can't specify both mu and p.")

return mu, alpha
return n, p

def logp(value, mu, alpha):
def logp(value, n, p):
r"""
Calculate log-probability of NegativeBinomial distribution at specified value.

Expand All @@ -769,6 +767,8 @@ def logp(value, mu, alpha):
-------
TensorVariable
"""
alpha = n
mu = alpha * (1 - p) / p
negbinom = bound(
binomln(value + alpha - 1, value)
+ logpow(mu / (mu + alpha), value)
Expand All @@ -779,9 +779,9 @@ def logp(value, mu, alpha):
)

# Return Poisson when alpha gets very large.
return aet.switch(aet.gt(alpha, 1e10), Poisson.dist(mu).logp(value), negbinom)
return aet.switch(aet.gt(alpha, 1e10), Poisson.logp(value, mu), negbinom)

def logcdf(value, mu, alpha):
def logcdf(value, n, p):
"""
Compute the log of the cumulative distribution function for NegativeBinomial distribution
at the specified value.
Expand All @@ -801,20 +801,14 @@ def logcdf(value, mu, alpha):
f"NegativeBinomial.logcdf expects a scalar value but received a {np.ndim(value)}-dimensional object."
)

# TODO: avoid `p` recomputation if distribution was defined in terms of `p`
p = alpha / (mu + alpha)

return bound(
aet.log(incomplete_beta(alpha, aet.floor(value) + 1, p)),
aet.log(incomplete_beta(n, aet.floor(value) + 1, p)),
0 <= value,
0 < alpha,
0 < n,
0 <= p,
p <= 1,
)

def _distr_parameters_for_repr(self):
return self._param_type


class Geometric(Discrete):
R"""
Expand Down
4 changes: 2 additions & 2 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from aesara.tensor.random.op import RandomVariable

from pymc3.distributions import _logcdf, _logp
from pymc3.distributions import _logcdf, _logp, no_transform_object

if TYPE_CHECKING:
from typing import Optional, Callable
Expand Down Expand Up @@ -161,7 +161,7 @@ def __new__(cls, name, *args, **kwargs):
if "shape" in kwargs:
raise DeprecationWarning("The `shape` keyword is deprecated; use `size`.")

transform = kwargs.pop("transform", None)
transform = kwargs.pop("transform", no_transform_object)

rv_out = cls.dist(*args, rng=rng, **kwargs)

Expand Down
10 changes: 7 additions & 3 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class MvNormal(Continuous):
@classmethod
def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
mu = aet.as_tensor_variable(mu)
cov = quaddist_matrix(cov, tau, chol, lower)
cov = quaddist_matrix(cov, chol, tau, lower)
return super().dist([mu, cov], **kwargs)

def logp(value, mu, cov):
Expand Down Expand Up @@ -386,8 +386,12 @@ class Dirichlet(Continuous):

rv_op = dirichlet

def __new__(cls, name, *args, **kwargs):
kwargs.setdefault("transform", transforms.stick_breaking)
return super().__new__(cls, name, *args, **kwargs)

@classmethod
def dist(cls, a, transform=transforms.stick_breaking, **kwargs):
def dist(cls, a, **kwargs):

a = aet.as_tensor_variable(a)
# mean = a / aet.sum(a)
Expand Down Expand Up @@ -483,7 +487,7 @@ class Multinomial(Discrete):
@classmethod
def dist(cls, n, p, *args, **kwargs):

# p = p / aet.sum(p, axis=-1, keepdims=True)
p = p / aet.sum(p, axis=-1, keepdims=True)
n = aet.as_tensor_variable(n)
p = aet.as_tensor_variable(p)

Expand Down
8 changes: 4 additions & 4 deletions pymc3/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def jacobian_det(self, rv_var, rv_value):
s = aet.nnet.softplus(-rv_value)
return aet.log(b - a) - 2 * s - rv_value
else:
return aet.ones_like(rv_value)
return rv_value


interval = Interval
Expand Down Expand Up @@ -286,7 +286,7 @@ class StickBreaking(Transform):
name = "stickbreaking"

def forward(self, rv_var, rv_value):
if rv_var.ndim == 1 or rv_var.broadcastable[-1]:
if rv_var.broadcastable[-1]:
# If this variable is just a bunch of scalars/degenerate
# Dirichlets, we can't transform it
return rv_value
Expand All @@ -299,7 +299,7 @@ def forward(self, rv_var, rv_value):
return floatX(y.T)

def backward(self, rv_var, rv_value):
if rv_var.ndim == 1 or rv_var.broadcastable[-1]:
if rv_var.broadcastable[-1]:
# If this variable is just a bunch of scalars/degenerate
# Dirichlets, we can't transform it
return rv_value
Expand All @@ -312,7 +312,7 @@ def backward(self, rv_var, rv_value):
return floatX(x.T)

def jacobian_det(self, rv_var, rv_value):
if rv_var.ndim == 1 or rv_var.broadcastable[-1]:
if rv_var.broadcastable[-1]:
# If this variable is just a bunch of scalars/degenerate
# Dirichlets, we can't transform it
return aet.ones_like(rv_value)
Expand Down
Loading