@@ -94,6 +94,31 @@ def test_sample_posterior_predictive_after_set_data(self):
94
94
x_test , y_test .posterior_predictive ["obs" ].mean (("chain" , "draw" )), atol = 1e-1
95
95
)
96
96
97
+ def test_sample_posterior_predictive_after_set_data_with_coords (self ):
98
+ y = np .array ([1.0 , 2.0 , 3.0 ])
99
+ with pm .Model () as model :
100
+ x = pm .MutableData ("x" , [1.0 , 2.0 , 3.0 ], dims = "obs_id" )
101
+ beta = pm .Normal ("beta" , 0 , 10.0 )
102
+ pm .Normal ("obs" , beta * x , np .sqrt (1e-2 ), observed = y , dims = "obs_id" )
103
+ idata = pm .sample (
104
+ 10 ,
105
+ tune = 100 ,
106
+ chains = 1 ,
107
+ return_inferencedata = True ,
108
+ compute_convergence_checks = False ,
109
+ )
110
+ # Predict on new data.
111
+ with model :
112
+ x_test = [5 , 6 ]
113
+ pm .set_data (new_data = {"x" : x_test }, coords = {"obs_id" : ["a" , "b" ]})
114
+ pm .sample_posterior_predictive (idata , extend_inferencedata = True , predictions = True )
115
+
116
+ assert idata .predictions ["obs" ].shape == (1 , 10 , 2 )
117
+ assert np .all (idata .predictions ["obs_id" ].values == np .array (["a" , "b" ]))
118
+ np .testing .assert_allclose (
119
+ x_test , idata .predictions ["obs" ].mean (("chain" , "draw" )), atol = 1e-1
120
+ )
121
+
97
122
def test_sample_after_set_data (self ):
98
123
with pm .Model () as model :
99
124
x = pm .MutableData ("x" , [1.0 , 2.0 , 3.0 ])
0 commit comments