Skip to content

Commit 985d3cd

Browse files
Prevent SciPy error by using float64 point in test_dirichlet_with_batch_shapes
1 parent 955370c commit 985d3cd

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

pymc3/tests/test_distributions.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1920,8 +1920,11 @@ def test_dirichlet_with_batch_shapes(self, dist_shape):
19201920
with pm.Model() as model:
19211921
d = pm.Dirichlet("d", a=a)
19221922

1923+
# Generate sample points to test
19231924
d_value = d.tag.value_var
1924-
d_point = d.eval()
1925+
d_point = d.eval().astype("float64")
1926+
d_point /= d_point.sum(axis=-1)[..., None]
1927+
19251928
if hasattr(d_value.tag, "transform"):
19261929
d_point_trans = d_value.tag.transform.forward(d, aet.as_tensor(d_point)).eval()
19271930
else:

0 commit comments

Comments
 (0)