@@ -36,7 +36,7 @@ def test_pathfinder_outputs():
36
36
37
37
assert pathfinder .is_resampled
38
38
39
- assert pathfinder .draws ().shape == (draws , 3 )
39
+ assert pathfinder .draws ().shape == (draws , 4 )
40
40
41
41
42
42
def test_pathfinder_from_csv ():
@@ -159,7 +159,7 @@ def test_pathfinder_no_psis():
159
159
pathfinder = bern_model .pathfinder (data = jdata , psis_resample = False )
160
160
161
161
assert not pathfinder .is_resampled
162
- assert pathfinder .draws ().shape == (4000 , 3 )
162
+ assert pathfinder .draws ().shape == (4000 , 4 )
163
163
164
164
165
165
def test_pathfinder_no_lp_calc ():
@@ -170,7 +170,7 @@ def test_pathfinder_no_lp_calc():
170
170
pathfinder = bern_model .pathfinder (data = jdata , calculate_lp = False )
171
171
172
172
assert not pathfinder .is_resampled
173
- assert pathfinder .draws ().shape == (4000 , 3 )
173
+ assert pathfinder .draws ().shape == (4000 , 4 )
174
174
n_lp_nan = np .sum (np .isnan (pathfinder .method_variables ()['lp__' ]))
175
175
assert n_lp_nan < 4000 # some lp still calculated during pathfinder
176
176
assert n_lp_nan > 3000 # but most are not
@@ -190,4 +190,4 @@ def test_pathfinder_threads():
190
190
stan_file = stan , cpp_options = {'STAN_THREADS' : True }, force_compile = True
191
191
)
192
192
pathfinder = bern_model .pathfinder (data = jdata , num_threads = 4 )
193
- assert pathfinder .draws ().shape == (1000 , 3 )
193
+ assert pathfinder .draws ().shape == (1000 , 4 )
0 commit comments