Skip to content

Commit 4acd98e

Browse files
committed
Remove Model auto_deterministics
This property was initially added just to handle deterministics created by automatic imputation, in order to ensure the combined tensor of missing and observed components showed up in prior and posterior predictive sampling. At the same time, it allowed hiding the deterministic during mcmc sampling, saving memory use for large datasets. This last benefit is lost for the sake of simplicity. If a user is concerned, they can manually split the observed and missing components of a dataset when defining their model.
1 parent 48870ad commit 4acd98e

File tree

4 files changed

+88
-27
lines changed

4 files changed

+88
-27
lines changed

pymc/model.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,6 @@ def __init__(
560560
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
561561
self.free_RVs = treelist(parent=self.parent.free_RVs)
562562
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
563-
self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
564563
self.deterministics = treelist(parent=self.parent.deterministics)
565564
self.potentials = treelist(parent=self.parent.potentials)
566565
self._coords = self.parent._coords
@@ -575,7 +574,6 @@ def __init__(
575574
self.rvs_to_initial_values = treedict()
576575
self.free_RVs = treelist()
577576
self.observed_RVs = treelist()
578-
self.auto_deterministics = treelist()
579577
self.deterministics = treelist()
580578
self.potentials = treelist()
581579
self._coords = {}
@@ -1435,10 +1433,11 @@ def make_obs_var(
14351433
self.observed_RVs.append(observed_rv_var)
14361434

14371435
# Create deterministic that combines observed and missing
1436+
# Note: This can greatly increase memory consumption during sampling for large datasets
14381437
rv_var = at.zeros(data.shape)
14391438
rv_var = at.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
14401439
rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var)
1441-
rv_var = Deterministic(name, rv_var, self, dims, auto=True)
1440+
rv_var = Deterministic(name, rv_var, self, dims)
14421441

14431442
else:
14441443
if sps.issparse(data):
@@ -1911,7 +1910,7 @@ def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
19111910
}
19121911

19131912

1914-
def Deterministic(name, var, model=None, dims=None, auto=False):
1913+
def Deterministic(name, var, model=None, dims=None):
19151914
"""Create a named deterministic variable.
19161915
19171916
Deterministic nodes are only deterministic given all of their inputs, i.e.
@@ -1974,10 +1973,7 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
19741973
"""
19751974
model = modelcontext(model)
19761975
var = var.copy(model.name_for(name))
1977-
if auto:
1978-
model.auto_deterministics.append(var)
1979-
else:
1980-
model.deterministics.append(var)
1976+
model.deterministics.append(var)
19811977
model.add_named_variable(var, dims)
19821978

19831979
from pymc.printing import str_for_potential_or_deterministic

pymc/sampling/forward.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,14 @@
3535
import xarray
3636

3737
from aesara import tensor as at
38-
from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk
38+
from aesara.graph.basic import (
39+
Apply,
40+
Constant,
41+
Variable,
42+
ancestors,
43+
general_toposort,
44+
walk,
45+
)
3946
from aesara.graph.fg import FunctionGraph
4047
from aesara.tensor.random.var import (
4148
RandomGeneratorSharedVariable,
@@ -324,6 +331,18 @@ def draw(
324331
return [np.stack(v) for v in drawn_values]
325332

326333

334+
def observed_dependent_deterministics(model: Model):
335+
"""Find deterministics that depend directly on observed variables"""
336+
deterministics = model.deterministics
337+
observed_rvs = set(model.observed_RVs)
338+
blockers = model.basic_RVs
339+
return [
340+
deterministic
341+
for deterministic in deterministics
342+
if observed_rvs & set(ancestors([deterministic], blockers=blockers))
343+
]
344+
345+
327346
def sample_prior_predictive(
328347
samples: int = 500,
329348
model: Optional[Model] = None,
@@ -371,10 +390,7 @@ def sample_prior_predictive(
371390
)
372391

373392
if var_names is None:
374-
vars_: Set[str] = {
375-
var.name
376-
for var in model.basic_RVs + model.deterministics + model.auto_deterministics
377-
}
393+
vars_: Set[str] = {var.name for var in model.basic_RVs + model.deterministics}
378394
else:
379395
vars_ = set(var_names)
380396

@@ -570,7 +586,7 @@ def sample_posterior_predictive(
570586
if var_names is not None:
571587
vars_ = [model[x] for x in var_names]
572588
else:
573-
vars_ = model.observed_RVs + model.auto_deterministics
589+
vars_ = model.observed_RVs + observed_dependent_deterministics(model)
574590

575591
indices = np.arange(samples)
576592
if progressbar:

pymc/tests/sampling/test_forward.py

+17
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from pymc.sampling.forward import (
3939
compile_forward_sampling_function,
4040
get_vars_in_point_list,
41+
observed_dependent_deterministics,
4142
)
4243
from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode
4344

@@ -1621,3 +1622,19 @@ def test_get_vars_in_point_list():
16211622
trace = MultiTrace([strace])
16221623
vars_in_trace = get_vars_in_point_list(trace, modelB)
16231624
assert set(vars_in_trace) == {a}
1625+
1626+
1627+
def test_observed_dependent_deterministics():
1628+
with pm.Model() as m:
1629+
free = pm.Normal("free")
1630+
obs = pm.Normal("obs", observed=1)
1631+
1632+
det_free = pm.Deterministic("det_free", free + 1)
1633+
det_free2 = pm.Deterministic("det_free2", det_free + 1)
1634+
1635+
det_obs = pm.Deterministic("det_obs", obs + 1)
1636+
det_obs2 = pm.Deterministic("det_obs2", det_obs + 1)
1637+
1638+
det_mixed = pm.Deterministic("det_mixed", free + obs)
1639+
1640+
assert set(observed_dependent_deterministics(m)) == {det_obs, det_obs2, det_mixed}

pymc/tests/test_model.py

+45-13
Original file line numberDiff line numberDiff line change
@@ -1195,21 +1195,29 @@ def test_missing_dual_observations(self):
11951195
trace = pm.sample(chains=1, tune=5, draws=50)
11961196

11971197
def test_interval_missing_observations(self):
1198+
rng = np.random.default_rng(1198)
1199+
11981200
with pm.Model() as model:
11991201
obs1 = np.ma.masked_values([1, 2, -1, 4, -1], value=-1)
12001202
obs2 = np.ma.masked_values([-1, -1, 6, -1, 8], value=-1)
12011203

1202-
rng = aesara.shared(np.random.RandomState(2323), borrow=True)
1203-
12041204
with pytest.warns(ImputationWarning):
1205-
theta1 = pm.Uniform("theta1", 0, 5, observed=obs1, rng=rng)
1205+
theta1 = pm.Uniform("theta1", 0, 5, observed=obs1)
12061206
with pytest.warns(ImputationWarning):
1207-
theta2 = pm.Normal("theta2", mu=theta1, observed=obs2, rng=rng)
1207+
theta2 = pm.Normal("theta2", mu=theta1, observed=obs2)
12081208

12091209
assert isinstance(model.rvs_to_transforms[model["theta1_missing"]], IntervalTransform)
12101210
assert model.rvs_to_transforms[model["theta1_observed"]] is None
12111211

1212-
prior_trace = pm.sample_prior_predictive(return_inferencedata=False)
1212+
prior_trace = pm.sample_prior_predictive(random_seed=rng, return_inferencedata=False)
1213+
assert set(prior_trace.keys()) == {
1214+
"theta1",
1215+
"theta1_observed",
1216+
"theta1_missing",
1217+
"theta2",
1218+
"theta2_observed",
1219+
"theta2_missing",
1220+
}
12131221

12141222
# Make sure the observed + missing combined deterministics have the
12151223
# same shape as the original observations vectors
@@ -1237,23 +1245,47 @@ def test_interval_missing_observations(self):
12371245
== 0.0
12381246
)
12391247

1240-
assert {"theta1", "theta2"} <= set(prior_trace.keys())
1241-
12421248
trace = pm.sample(
1243-
chains=1, draws=50, compute_convergence_checks=False, return_inferencedata=False
1249+
chains=1,
1250+
draws=50,
1251+
compute_convergence_checks=False,
1252+
return_inferencedata=False,
1253+
random_seed=rng,
12441254
)
1255+
assert set(trace.varnames) == {
1256+
"theta1",
1257+
"theta1_missing",
1258+
"theta1_missing_interval__",
1259+
"theta2",
1260+
"theta2_missing",
1261+
}
12451262

1263+
# Make sure that the missing values are newly generated samples and that
1264+
# the observed and deterministic match
12461265
assert np.all(0 < trace["theta1_missing"].mean(0))
12471266
assert np.all(0 < trace["theta2_missing"].mean(0))
1248-
assert "theta1" not in trace.varnames
1249-
assert "theta2" not in trace.varnames
1267+
assert np.isclose(np.mean(trace["theta1"][:, obs1.mask] - trace["theta1_missing"]), 0)
1268+
assert np.isclose(np.mean(trace["theta2"][:, obs2.mask] - trace["theta2_missing"]), 0)
12501269

1251-
# Make sure that the observed values are newly generated samples and that
1252-
# the observed and deterministic matche
1253-
pp_idata = pm.sample_posterior_predictive(trace)
1270+
# Make sure that the observed values are unchanged
1271+
assert np.allclose(np.var(trace["theta1"][:, ~obs1.mask], 0), 0.0)
1272+
assert np.allclose(np.var(trace["theta2"][:, ~obs2.mask], 0), 0.0)
1273+
np.testing.assert_array_equal(trace["theta1"][0][~obs1.mask], obs1[~obs1.mask])
1274+
np.testing.assert_array_equal(trace["theta2"][0][~obs2.mask], obs2[~obs2.mask])
1275+
1276+
pp_idata = pm.sample_posterior_predictive(trace, random_seed=rng)
12541277
pp_trace = pp_idata.posterior_predictive.stack(sample=["chain", "draw"]).transpose(
12551278
"sample", ...
12561279
)
1280+
assert set(pp_trace.keys()) == {
1281+
"theta1",
1282+
"theta1_observed",
1283+
"theta2",
1284+
"theta2_observed",
1285+
}
1286+
1287+
# Make sure that the observed values are newly generated samples and that
1288+
# the observed and deterministic match
12571289
assert np.all(np.var(pp_trace["theta1"], 0) > 0.0)
12581290
assert np.all(np.var(pp_trace["theta2"], 0) > 0.0)
12591291
assert np.isclose(

0 commit comments

Comments
 (0)