Skip to content

Commit 1588e1f

Browse files
code review
1 parent 7f6abd4 commit 1588e1f

File tree

3 files changed

+2
-28
lines changed

3 files changed

+2
-28
lines changed

pymc/backends/arviz.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,6 @@ def __init__(
169169
self.nchains = self.ndraws = 0
170170

171171
self.prior = prior
172-
self.prior_predictions = None
173-
if self.prior is not None:
174-
self.prior_predictions = True
175-
176172
self.posterior_predictive = posterior_predictive
177173
self.log_likelihood = log_likelihood
178174
self.predictions = predictions
@@ -468,7 +464,6 @@ def observed_data_to_xarray(self):
468464
default_dims=[],
469465
)
470466

471-
@requires(["trace", "predictions", "prior_predictions"])
472467
@requires("model")
473468
def constant_data_to_xarray(self):
474469
"""Convert constant data to xarray."""

pymc/tests/test_idata_conversion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def test_no_trace(self):
502502

503503
@pytest.mark.parametrize("use_context", [True, False])
504504
def test_priors_separation(self, use_context):
505-
"""Test model is enough to get prior, prior predictive and observed_data."""
505+
"""Test model is enough to get prior, prior predictive, constant_data and observed_data."""
506506
with pm.Model() as model:
507507
x = pm.MutableData("x", [1.0, 2.0, 3.0])
508508
y = pm.ConstantData("y", [1.0, 2.0, 3.0])
@@ -514,6 +514,7 @@ def test_priors_separation(self, use_context):
514514
"prior": ["beta", "~obs"],
515515
"observed_data": ["obs"],
516516
"prior_predictive": ["obs"],
517+
"constant_data": ["x", "y"],
517518
}
518519
if use_context:
519520
with model:

pymc/tests/test_sampling.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,28 +1042,6 @@ def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]
10421042

10431043

10441044
class TestSamplePriorPredictive(SeededTest):
1045-
def test_idata_output(self):
1046-
"""This test controls that returned idata
1047-
contains all expected groups"""
1048-
1049-
with pm.Model() as model:
1050-
x = pm.MutableData("x", [1, 2, 3])
1051-
y = pm.MutableData("y", [1.1, 1.9, 3.1])
1052-
a = pm.Normal("a", mu=1, sigma=1)
1053-
b = pm.Normal("b", mu=0, sigma=1)
1054-
mu = pm.Deterministic("mu", var=a * x + b)
1055-
obs = pm.Normal("obs", mu=mu, sigma=1, observed=y)
1056-
idata = pm.sample_prior_predictive(samples=10)
1057-
1058-
test_dict = {
1059-
"prior": ["a", "b", "mu"],
1060-
"prior_predictive": ["obs"],
1061-
"observed_data": ["obs"],
1062-
"constant_data": ["x", "y"],
1063-
}
1064-
fails = check_multiple_attrs(test_dict, idata)
1065-
assert not fails
1066-
10671045
def test_ignores_observed(self):
10681046
observed = np.random.normal(10, 1, size=200)
10691047
with pm.Model():

0 commit comments

Comments
 (0)