From e9bf9f9e21c0a8a98c06e70e6ce51fe828b97494 Mon Sep 17 00:00:00 2001 From: Somasree Majumder Date: Sat, 12 Mar 2022 19:54:16 +0530 Subject: [PATCH 1/5] Aadd coords argument to pymc.set_data --- pymc/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index c211636b2b..ac9aed60f7 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1734,7 +1734,7 @@ def point_logps(self, point=None, round_vals=2): Model._context_class = Model -def set_data(new_data, model=None): +def set_data(new_data, coords,model=None): """Sets the value of one or more data container variables. Parameters @@ -1772,7 +1772,7 @@ def set_data(new_data, model=None): for variable_name, new_value in new_data.items(): model.set_data(variable_name, new_value) - + return model.set_data(new_data, coords) def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs): """Compiles an Aesara function which returns ``outs`` and takes values of model From 7a1c71d7407a2a14287a8a011b16e120a395ca8f Mon Sep 17 00:00:00 2001 From: Somasree Majumder Date: Sun, 13 Mar 2022 01:56:07 +0530 Subject: [PATCH 2/5] Adding --- pymc/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index ac9aed60f7..98de66ba85 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1734,7 +1734,7 @@ def point_logps(self, point=None, round_vals=2): Model._context_class = Model -def set_data(new_data, coords,model=None): +def set_data(new_data, model=None, *, coords=None): """Sets the value of one or more data container variables. Parameters @@ -1771,8 +1771,8 @@ def set_data(new_data, coords,model=None): model = modelcontext(model) for variable_name, new_value in new_data.items(): - model.set_data(variable_name, new_value) - return model.set_data(new_data, coords) + model.set_data(variable_name, new_value,coords) + """return model.set_data(new_data, coords)""" def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs): """Compiles an Aesara function which returns ``outs`` and takes values of model From e8fe4a0c815f5a3240d2408c28728b488eea279c Mon Sep 17 00:00:00 2001 From: Somasree Majumder <56045049+soma2000-lang@users.noreply.github.com> Date: Fri, 18 Mar 2022 03:17:17 +0530 Subject: [PATCH 3/5] Update model.py --- pymc/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/model.py b/pymc/model.py index 98de66ba85..f3ad020b01 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1772,7 +1772,7 @@ def set_data(new_data, model=None, *, coords=None): for variable_name, new_value in new_data.items(): model.set_data(variable_name, new_value,coords) - """return model.set_data(new_data, coords)""" + def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs): """Compiles an Aesara function which returns ``outs`` and takes values of model From 855f55c75cbb353b986c2c41042127130d30cc23 Mon Sep 17 00:00:00 2001 From: Somasree Majumder <56045049+soma2000-lang@users.noreply.github.com> Date: Fri, 18 Mar 2022 03:18:35 +0530 Subject: [PATCH 4/5] Update model.py --- pymc/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/model.py b/pymc/model.py index f3ad020b01..0b6efe9529 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1771,7 +1771,7 @@ def set_data(new_data, model=None, *, coords=None): model = modelcontext(model) for variable_name, new_value in new_data.items(): - model.set_data(variable_name, new_value,coords) + model.set_data(variable_name, new_value,coords=coords) def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs): From 654418b27e4ab2a357cb9163212eda25eb4dae0b Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sun, 24 Apr 2022 01:53:10 +0300 Subject: [PATCH 5/5] add test and run black --- pymc/model.py | 4 ++-- pymc/tests/test_data_container.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/pymc/model.py b/pymc/model.py index 0b6efe9529..26b8ed6600 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -1771,8 +1771,8 @@ def set_data(new_data, model=None, *, coords=None): model = modelcontext(model) for variable_name, new_value in new_data.items(): - model.set_data(variable_name, new_value,coords=coords) - + model.set_data(variable_name, new_value, coords=coords) + def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs): """Compiles an Aesara function which returns ``outs`` and takes values of model diff --git a/pymc/tests/test_data_container.py b/pymc/tests/test_data_container.py index 8c5ff71c79..2bcceaf5d7 100644 --- a/pymc/tests/test_data_container.py +++ b/pymc/tests/test_data_container.py @@ -94,6 +94,31 @@ def test_sample_posterior_predictive_after_set_data(self): x_test, y_test.posterior_predictive["obs"].mean(("chain", "draw")), atol=1e-1 ) + def test_sample_posterior_predictive_after_set_data_with_coords(self): + y = np.array([1.0, 2.0, 3.0]) + with pm.Model() as model: + x = pm.MutableData("x", [1.0, 2.0, 3.0], dims="obs_id") + beta = pm.Normal("beta", 0, 10.0) + pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y, dims="obs_id") + idata = pm.sample( + 10, + tune=100, + chains=1, + return_inferencedata=True, + compute_convergence_checks=False, + ) + # Predict on new data. + with model: + x_test = [5, 6] + pm.set_data(new_data={"x": x_test}, coords={"obs_id": ["a", "b"]}) + pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True) + + assert idata.predictions["obs"].shape == (1, 10, 2) + assert np.all(idata.predictions["obs_id"].values == np.array(["a", "b"])) + np.testing.assert_allclose( + x_test, idata.predictions["obs"].mean(("chain", "draw")), atol=1e-1 + ) + def test_sample_after_set_data(self): with pm.Model() as model: x = pm.MutableData("x", [1.0, 2.0, 3.0])