Skip to content

Commit f6f1a8e

Browse files
ricardoV94twiecki
authored andcommitted
Convert point_logps values to arrays
When compiling to JAX the returned object is a DeviceArray, which causes problems downstream in check_start_vals
1 parent 01ed358 commit f6f1a8e

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pymc/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1628,7 +1628,9 @@ def point_logps(self, point=None, round_vals=2):
16281628
return Series(
16291629
{
16301630
rv.name: np.round(
1631-
self.fn(logpt_sum(rv, getattr(rv.tag, "observations", None)))(point),
1631+
np.asarray(
1632+
self.fn(logpt_sum(rv, getattr(rv.tag, "observations", None)))(point)
1633+
),
16321634
round_vals,
16331635
)
16341636
for rv in self.basic_RVs

0 commit comments

Comments
 (0)