Skip to content

Commit f3c29b3

Browse files
authored
Merge pull request #797 from stan-dev/2.37-tests
Update tests for new cmdstan defaults
2 parents 0b08591 + 900299a commit f3c29b3

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

test/test_log_prob.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
["-5.5395901199", "-1.4903938392"],
3131
),
3232
(3, ["-7.02", "-1.19"], ["-5.54", "-1.49"]),
33-
(None, ["-7.02147", "-1.18847"], ["-5.53959", "-1.49039"]),
33+
(None, ["-7.0214668", "-1.1884726"], ["-5.5395901", "-1.4903938"]),
3434
],
3535
)
3636
def test_lp_good(

test/test_pathfinder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_pathfinder_outputs():
3636

3737
assert pathfinder.is_resampled
3838

39-
assert pathfinder.draws().shape == (draws, 3)
39+
assert pathfinder.draws().shape == (draws, 4)
4040

4141

4242
def test_pathfinder_from_csv():
@@ -159,7 +159,7 @@ def test_pathfinder_no_psis():
159159
pathfinder = bern_model.pathfinder(data=jdata, psis_resample=False)
160160

161161
assert not pathfinder.is_resampled
162-
assert pathfinder.draws().shape == (4000, 3)
162+
assert pathfinder.draws().shape == (4000, 4)
163163

164164

165165
def test_pathfinder_no_lp_calc():
@@ -170,7 +170,7 @@ def test_pathfinder_no_lp_calc():
170170
pathfinder = bern_model.pathfinder(data=jdata, calculate_lp=False)
171171

172172
assert not pathfinder.is_resampled
173-
assert pathfinder.draws().shape == (4000, 3)
173+
assert pathfinder.draws().shape == (4000, 4)
174174
n_lp_nan = np.sum(np.isnan(pathfinder.method_variables()['lp__']))
175175
assert n_lp_nan < 4000 # some lp still calculated during pathfinder
176176
assert n_lp_nan > 3000 # but most are not
@@ -190,4 +190,4 @@ def test_pathfinder_threads():
190190
stan_file=stan, cpp_options={'STAN_THREADS': True}, force_compile=True
191191
)
192192
pathfinder = bern_model.pathfinder(data=jdata, num_threads=4)
193-
assert pathfinder.draws().shape == (1000, 3)
193+
assert pathfinder.draws().shape == (1000, 4)

0 commit comments

Comments
 (0)