diff --git a/pymc/model.py b/pymc/model.py index c211636b2b..26b8ed6600 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1734,7 +1734,7 @@ def point_logps(self, point=None, round_vals=2): Model._context_class = Model -def set_data(new_data, model=None): +def set_data(new_data, model=None, *, coords=None): """Sets the value of one or more data container variables. Parameters @@ -1771,7 +1771,7 @@ def set_data(new_data, model=None): model = modelcontext(model) for variable_name, new_value in new_data.items(): - model.set_data(variable_name, new_value) + model.set_data(variable_name, new_value, coords=coords) def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs): diff --git a/pymc/tests/test_data_container.py b/pymc/tests/test_data_container.py index 8c5ff71c79..2bcceaf5d7 100644 --- a/pymc/tests/test_data_container.py +++ b/pymc/tests/test_data_container.py @@ -94,6 +94,31 @@ def test_sample_posterior_predictive_after_set_data(self): x_test, y_test.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1 ) + def test_sample_posterior_predictive_after_set_data_with_coords(self): + y = np.array([1.0, 2.0, 3.0]) + with pm.Model() as model: + x = pm.MutableData("x", [1.0, 2.0, 3.0], dims="obs_id") + beta = pm.Normal("beta", 0, 10.0) + pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y, dims="obs_id") + idata = pm.sample( + 10, + tune=100, + chains=1, + return_inferencedata=True, + compute_convergence_checks=False, + ) + # Predict on new data. + with model: + x_test = [5, 6] + pm.set_data(new_data={"x": x_test}, coords={"obs_id": ["a", "b"]}) + pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True) + + assert idata.predictions["obs"].shape == (1, 10, 2) + assert np.all(idata.predictions["obs_id"].values == np.array(["a", "b"])) + np.testing.assert_allclose( + x_test, idata.predictions["obs"].mean(("chain", "draw")), atol=1e-1 + ) + def test_sample_after_set_data(self): with pm.Model() as model: x = pm.MutableData("x", [1.0, 2.0, 3.0])