Skip to content

Commit bae5087

Browse files
Add coords argument to pymc.set_data (#5588)
Co-authored-by: Oriol (ZBook) <[email protected]>
1 parent e4ec363 commit bae5087

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

pymc/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,7 +1734,7 @@ def point_logps(self, point=None, round_vals=2):
17341734
Model._context_class = Model
17351735

17361736

1737-
def set_data(new_data, model=None):
1737+
def set_data(new_data, model=None, *, coords=None):
17381738
"""Sets the value of one or more data container variables.
17391739
17401740
Parameters
@@ -1771,7 +1771,7 @@ def set_data(new_data, model=None):
17711771
model = modelcontext(model)
17721772

17731773
for variable_name, new_value in new_data.items():
1774-
model.set_data(variable_name, new_value)
1774+
model.set_data(variable_name, new_value, coords=coords)
17751775

17761776

17771777
def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs):

pymc/tests/test_data_container.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,31 @@ def test_sample_posterior_predictive_after_set_data(self):
9494
x_test, y_test.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1
9595
)
9696

97+
def test_sample_posterior_predictive_after_set_data_with_coords(self):
98+
y = np.array([1.0, 2.0, 3.0])
99+
with pm.Model() as model:
100+
x = pm.MutableData("x", [1.0, 2.0, 3.0], dims="obs_id")
101+
beta = pm.Normal("beta", 0, 10.0)
102+
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y, dims="obs_id")
103+
idata = pm.sample(
104+
10,
105+
tune=100,
106+
chains=1,
107+
return_inferencedata=True,
108+
compute_convergence_checks=False,
109+
)
110+
# Predict on new data.
111+
with model:
112+
x_test = [5, 6]
113+
pm.set_data(new_data={"x": x_test}, coords={"obs_id": ["a", "b"]})
114+
pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)
115+
116+
assert idata.predictions["obs"].shape == (1, 10, 2)
117+
assert np.all(idata.predictions["obs_id"].values == np.array(["a", "b"]))
118+
np.testing.assert_allclose(
119+
x_test, idata.predictions["obs"].mean(("chain", "draw")), atol=1e-1
120+
)
121+
97122
def test_sample_after_set_data(self):
98123
with pm.Model() as model:
99124
x = pm.MutableData("x", [1.0, 2.0, 3.0])

0 commit comments

Comments
 (0)