Skip to content

Commit cf95a78

Browse files
Update posterior predictive sampling and notebook (#5268)
* improve idata-sample_posterior_predictive integration and add thinning example to docstring
* remove plot_posterior_predictive_glm (arviz.plot_lm should be used now)

Co-authored-by: Michael Osthege <[email protected]>
1 parent c200802 commit cf95a78

15 files changed

+4092
-510
lines changed

.github/workflows/pytest.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ jobs:
4646
--ignore=pymc/tests/test_dist_math.py
4747
--ignore=pymc/tests/test_minibatches.py
4848
--ignore=pymc/tests/test_pickling.py
49-
--ignore=pymc/tests/test_plots.py
5049
--ignore=pymc/tests/test_updates.py
5150
--ignore=pymc/tests/test_gp.py
5251
--ignore=pymc/tests/test_model.py
@@ -68,7 +67,6 @@ jobs:
6867
pymc/tests/test_dist_math.py
6968
pymc/tests/test_minibatches.py
7069
pymc/tests/test_pickling.py
71-
pymc/tests/test_plots.py
7270
pymc/tests/test_updates.py
7371
pymc/tests/test_transforms.py
7472

docs/source/learn/examples/posterior_predictive.ipynb

Lines changed: 3939 additions & 222 deletions
Large diffs are not rendered by default.

pymc/backends/arviz.py

Lines changed: 29 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ def __init__(
138138
dims: Optional[DimSpec] = None,
139139
model=None,
140140
save_warmup: Optional[bool] = None,
141-
density_dist_obs: bool = True,
142141
):
143142

144143
self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
@@ -175,29 +174,11 @@ def __init__(
175174
self.log_likelihood = log_likelihood
176175
self.predictions = predictions
177176

178-
def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
179-
return next(iter(dct.values()))
180-
181-
if trace is None:
182-
# if you have a posterior_predictive built with keep_dims,
183-
# you'll lose here, but there's nothing I can do about that.
184-
self.nchains = 1
185-
get_from = None
186-
if predictions is not None:
187-
get_from = predictions
188-
elif posterior_predictive is not None:
189-
get_from = posterior_predictive
190-
elif prior is not None:
191-
get_from = prior
192-
if get_from is None:
193-
# pylint: disable=line-too-long
194-
raise ValueError(
195-
"When constructing InferenceData must have at least"
196-
" one of trace, prior, posterior_predictive or predictions."
197-
)
198-
199-
aelem = arbitrary_element(get_from)
200-
self.ndraws = aelem.shape[0]
177+
if all(elem is None for elem in (trace, predictions, posterior_predictive, prior)):
178+
raise ValueError(
179+
"When constructing InferenceData you must pass at least"
180+
" one of trace, prior, posterior_predictive or predictions."
181+
)
201182

202183
self.coords = {**self.model.coords, **(coords or {})}
203184
self.coords = {
@@ -214,7 +195,6 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
214195
}
215196
self.dims = {**model_dims, **self.dims}
216197

217-
self.density_dist_obs = density_dist_obs
218198
self.observations = find_observations(self.model)
219199

220200
def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
@@ -396,34 +376,35 @@ def log_likelihood_to_xarray(self):
396376
),
397377
)
398378

399-
def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
379+
def translate_posterior_predictive_dict_to_xarray(self, dct, kind) -> xr.Dataset:
400380
"""Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
401381
data = {}
382+
warning_vars = []
402383
for k, ary in dct.items():
403-
shape = ary.shape
404-
if shape[0] == self.nchains and shape[1] == self.ndraws:
384+
if (ary.shape[0] == self.nchains) and (ary.shape[1] == self.ndraws):
405385
data[k] = ary
406-
elif shape[0] == self.nchains * self.ndraws:
407-
data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
408386
else:
409387
data[k] = np.expand_dims(ary, 0)
410-
# pylint: disable=line-too-long
411-
_log.warning(
412-
"posterior predictive variable %s's shape not compatible with number of chains and draws. "
413-
"This can mean that some draws or even whole chains are not represented.",
414-
k,
415-
)
388+
warning_vars.append(k)
389+
warnings.warn(
390+
f"The shape of variables {', '.join(warning_vars)} in {kind} group is not compatible "
391+
"with number of chains and draws. The automatic dimension naming might not have worked. "
392+
"This can also mean that some draws or even whole chains are not represented.",
393+
UserWarning,
394+
)
416395
return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims)
417396

418397
@requires(["posterior_predictive"])
419398
def posterior_predictive_to_xarray(self):
420399
"""Convert posterior_predictive samples to xarray."""
421-
return self.translate_posterior_predictive_dict_to_xarray(self.posterior_predictive)
400+
return self.translate_posterior_predictive_dict_to_xarray(
401+
self.posterior_predictive, "posterior_predictive"
402+
)
422403

423404
@requires(["predictions"])
424405
def predictions_to_xarray(self):
425406
"""Convert predictions (out of sample predictions) to xarray."""
426-
return self.translate_posterior_predictive_dict_to_xarray(self.predictions)
407+
return self.translate_posterior_predictive_dict_to_xarray(self.predictions, "predictions")
427408

428409
def priors_to_xarray(self):
429410
"""Convert prior samples (and if possible prior predictive too) to xarray."""
@@ -545,7 +526,6 @@ def to_inference_data(
545526
dims: Optional[DimSpec] = None,
546527
model: Optional["Model"] = None,
547528
save_warmup: Optional[bool] = None,
548-
density_dist_obs: bool = True,
549529
) -> InferenceData:
550530
"""Convert pymc data into an InferenceData object.
551531
@@ -578,9 +558,6 @@ def to_inference_data(
578558
save_warmup : bool, optional
579559
Save warmup iterations InferenceData object. If not defined, use default
580560
defined by the rcParams.
581-
density_dist_obs : bool, default True
582-
Store variables passed with ``observed`` arg to
583-
:class:`~pymc.distributions.DensityDist` in the generated InferenceData.
584561
585562
Returns
586563
-------
@@ -598,7 +575,6 @@ def to_inference_data(
598575
dims=dims,
599576
model=model,
600577
save_warmup=save_warmup,
601-
density_dist_obs=density_dist_obs,
602578
).to_inference_data()
603579

604580

@@ -620,6 +596,7 @@ def predictions_to_inference_data(
620596
predictions: Dict[str, np.ndarray]
621597
The predictions are the return value of :func:`~pymc.sample_posterior_predictive`,
622598
a dictionary of strings (variable names) to numpy ndarrays (draws).
599+
Requires the arrays to follow the convention ``chain, draw, *shape``.
623600
posterior_trace: MultiTrace
624601
This should be a trace that has been thinned appropriately for
625602
``pymc.sample_posterior_predictive``. Specifically, any variable whose shape is
@@ -648,14 +625,21 @@ def predictions_to_inference_data(
648625
raise ValueError(
649626
"Do not pass True for inplace unless passing" "an existing InferenceData as idata_orig"
650627
)
651-
new_idata = InferenceDataConverter(
628+
converter = InferenceDataConverter(
652629
trace=posterior_trace,
653630
predictions=predictions,
654631
model=model,
655632
coords=coords,
656633
dims=dims,
657634
log_likelihood=False,
658-
).to_inference_data()
635+
)
636+
if hasattr(idata_orig, "posterior"):
637+
converter.nchains = idata_orig.posterior.dims["chain"]
638+
converter.ndraws = idata_orig.posterior.dims["draw"]
639+
else:
640+
aelem = next(iter(predictions.values()))
641+
converter.nchains, converter.ndraws = aelem.shape[:2]
642+
new_idata = converter.to_inference_data()
659643
if idata_orig is None:
660644
return new_idata
661645
elif inplace:

pymc/plots/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ def wrapped(*args, **kwargs):
5656
compareplot = alias_deprecation(az.plot_compare, alias="compareplot")
5757

5858

59-
from pymc.plots.posteriorplot import plot_posterior_predictive_glm
60-
6159
__all__ = tuple(az.plots.__all__) + (
6260
"autocorrplot",
6361
"compareplot",
@@ -67,5 +65,4 @@ def wrapped(*args, **kwargs):
6765
"energyplot",
6866
"densityplot",
6967
"pairplot",
70-
"plot_posterior_predictive_glm",
7168
)

pymc/plots/posteriorplot.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

0 commit comments

Comments
 (0)