Skip to content

Commit 9105d74

Browse files
authored
Deprecate sample_posterior_predictive_w (#6254)
Closes #4807
1 parent bcffce2 commit 9105d74

File tree

2 files changed

+4
-186
lines changed

2 files changed

+4
-186
lines changed

pymc/sampling.py

+4 -121
Original file line number | Diff line number | Diff line change
@@ -2065,127 +2065,10 @@ def sample_posterior_predictive_w(
20652065
weighted models (default), or a dictionary with variable names as keys, and samples as
20662066
numpy arrays.
20672067
"""
2068-
raise NotImplementedError(f"sample_posterior_predictive_w has not yet been ported to PyMC 4.0.")
2069-
2070-
if isinstance(traces[0], InferenceData):
2071-
n_samples = [
2072-
trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces
2073-
]
2074-
traces = [dataset_to_point_list(trace.posterior) for trace in traces]
2075-
elif isinstance(traces[0], xarray.Dataset):
2076-
n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces]
2077-
traces = [dataset_to_point_list(trace) for trace in traces]
2078-
else:
2079-
n_samples = [len(i) * i.nchains for i in traces]
2080-
2081-
if models is None:
2082-
models = [modelcontext(models)] * len(traces)
2083-
2084-
if random_seed is not None:
2085-
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
2086-
2087-
for model in models:
2088-
if model.potentials:
2089-
warnings.warn(
2090-
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
2091-
"This is likely to lead to invalid or biased predictive samples.",
2092-
UserWarning,
2093-
stacklevel=2,
2094-
)
2095-
break
2096-
2097-
if weights is None:
2098-
weights = [1] * len(traces)
2099-
2100-
if len(traces) != len(weights):
2101-
raise ValueError("The number of traces and weights should be the same")
2102-
2103-
if len(models) != len(weights):
2104-
raise ValueError("The number of models and weights should be the same")
2105-
2106-
length_morv = len(models[0].observed_RVs)
2107-
if any(len(i.observed_RVs) != length_morv for i in models):
2108-
raise ValueError("The number of observed RVs should be the same for all models")
2109-
2110-
weights = np.asarray(weights)
2111-
p = weights / np.sum(weights)
2112-
2113-
min_tr = min(n_samples)
2114-
2115-
n = (min_tr * p).astype("int")
2116-
# ensure n sum up to min_tr
2117-
idx = np.argmax(n)
2118-
n[idx] = n[idx] + min_tr - np.sum(n)
2119-
trace = []
2120-
for i, j in enumerate(n):
2121-
tr = traces[i]
2122-
len_trace = len(tr)
2123-
try:
2124-
nchain = tr.nchains
2125-
except AttributeError:
2126-
nchain = 1
2127-
2128-
indices = np.random.randint(0, nchain * len_trace, j)
2129-
if nchain > 1:
2130-
chain_idx, point_idx = np.divmod(indices, len_trace)
2131-
for cidx, pidx in zip(chain_idx, point_idx):
2132-
trace.append(tr._straces[cidx].point(pidx))
2133-
else:
2134-
for idx in indices:
2135-
trace.append(tr[idx])
2136-
2137-
obs = [x for m in models for x in m.observed_RVs]
2138-
variables = np.repeat(obs, n)
2139-
2140-
lengths = list({np.atleast_1d(observed).shape for observed in obs})
2141-
2142-
size: List[Optional[Tuple[int, ...]]] = []
2143-
if len(lengths) == 1:
2144-
size = [None] * len(variables)
2145-
elif len(lengths) > 2:
2146-
raise ValueError("Observed variables could not be broadcast together")
2147-
else:
2148-
x = np.zeros(shape=lengths[0])
2149-
y = np.zeros(shape=lengths[1])
2150-
b = np.broadcast(x, y)
2151-
for var in variables:
2152-
# XXX: This needs to be refactored
2153-
shape = None # np.shape(np.atleast_1d(var.distribution.default()))
2154-
if shape != b.shape:
2155-
size.append(b.shape)
2156-
else:
2157-
size.append(None)
2158-
len_trace = len(trace)
2159-
2160-
if samples is None:
2161-
samples = len_trace
2162-
2163-
indices = np.random.randint(0, len_trace, samples)
2164-
2165-
if progressbar:
2166-
indices = progress_bar(indices, total=samples, display=progressbar)
2167-
2168-
try:
2169-
ppcl: Dict[str, list] = defaultdict(list)
2170-
for idx in indices:
2171-
param = trace[idx]
2172-
var = variables[idx]
2173-
# TODO sample_posterior_predictive_w currently only works for models with
2174-
# one observed variable.
2175-
# XXX: This needs to be refactored
2176-
# ppc[var.name].append(draw_values([var], point=param, size=size[idx])[0])
2177-
raise NotImplementedError()
2178-
2179-
except KeyboardInterrupt:
2180-
pass
2181-
else:
2182-
ppcd = {k: np.asarray(v) for k, v in ppcl.items()}
2183-
if not return_inferencedata:
2184-
return ppcd
2185-
ikwargs: Dict[str, Any] = dict(model=models)
2186-
if idata_kwargs:
2187-
ikwargs.update(idata_kwargs)
2188-
return pm.to_inference_data(posterior_predictive=ppcd, **ikwargs)
2068+
raise FutureWarning(
2069+
"The function `sample_posterior_predictive_w` has been removed in PyMC 4.3.0. "
2070+
"Switch to `arviz.stats.weight_predictions`"
2071+
)
21892072

21902073

21912074
def sample_prior_predictive(

pymc/tests/test_sampling.py

-65
Original file line number | Diff line number | Diff line change
@@ -1177,71 +1177,6 @@ def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results,
11771177
caplog.clear()
11781178

11791179

1180-
@pytest.mark.xfail(
1181-
reason="sample_posterior_predictive_w not refactored for v4", raises=NotImplementedError
1182-
)
1183-
class TestSamplePPCW(SeededTest):
1184-
def test_sample_posterior_predictive_w(self):
1185-
data0 = np.random.normal(0, 1, size=50)
1186-
warning_msg = "The number of samples is too small to check convergence reliably"
1187-
1188-
with pm.Model() as model_0:
1189-
mu = pm.Normal("mu", mu=0, sigma=1)
1190-
y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
1191-
with pytest.warns(UserWarning, match=warning_msg):
1192-
trace_0 = pm.sample(10, tune=0, chains=2, return_inferencedata=False)
1193-
idata_0 = pm.to_inference_data(trace_0, log_likelihood=False)
1194-
1195-
with pm.Model() as model_1:
1196-
mu = pm.Normal("mu", mu=0, sigma=1, size=len(data0))
1197-
y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
1198-
with pytest.warns(UserWarning, match=warning_msg):
1199-
trace_1 = pm.sample(10, tune=0, chains=2, return_inferencedata=False)
1200-
idata_1 = pm.to_inference_data(trace_1, log_likelihood=False)
1201-
1202-
with pm.Model() as model_2:
1203-
# Model with no observed RVs.
1204-
mu = pm.Normal("mu", mu=0, sigma=1)
1205-
with pytest.warns(UserWarning, match=warning_msg):
1206-
trace_2 = pm.sample(10, tune=0, return_inferencedata=False)
1207-
1208-
traces = [trace_0, trace_1]
1209-
idatas = [idata_0, idata_1]
1210-
models = [model_0, model_1]
1211-
1212-
ppc = pm.sample_posterior_predictive_w(traces, 100, models)
1213-
assert ppc["y"].shape == (100, 50)
1214-
1215-
ppc = pm.sample_posterior_predictive_w(idatas, 100, models)
1216-
assert ppc["y"].shape == (100, 50)
1217-
1218-
with model_0:
1219-
ppc = pm.sample_posterior_predictive_w([idata_0.posterior], None)
1220-
assert ppc["y"].shape == (20, 50)
1221-
1222-
with pytest.raises(ValueError, match="The number of traces and weights should be the same"):
1223-
pm.sample_posterior_predictive_w([idata_0.posterior], 100, models, weights=[0.5, 0.5])
1224-
1225-
with pytest.raises(ValueError, match="The number of models and weights should be the same"):
1226-
pm.sample_posterior_predictive_w([idata_0.posterior], 100, models)
1227-
1228-
with pytest.raises(
1229-
ValueError, match="The number of observed RVs should be the same for all models"
1230-
):
1231-
pm.sample_posterior_predictive_w([trace_0, trace_2], 100, [model_0, model_2])
1232-
1233-
def test_potentials_warning(self):
1234-
warning_msg = "The effect of Potentials on other parameters is ignored during"
1235-
with pm.Model() as m:
1236-
a = pm.Normal("a", 0, 1)
1237-
p = pm.Potential("p", a + 1)
1238-
obs = pm.Normal("obs", a, 1, observed=5)
1239-
1240-
trace = az_from_dict({"a": np.random.rand(10)})
1241-
with pytest.warns(UserWarning, match=warning_msg):
1242-
pm.sample_posterior_predictive_w(samples=5, traces=[trace, trace], models=[m, m])
1243-
1244-
12451180
def check_exec_nuts_init(method):
12461181
with pm.Model() as model:
12471182
pm.Normal("a", mu=0, sigma=1, size=2)

0 commit comments

Comments
 (0)