Skip to content

Commit 4a731c0

Browse files
committed
fix(numba): non-contiguous shared variable
Shared variables (ie pm.Data) that were non-contiguous could lead to incorrect results in the pymc numba backend. We now ensure that they are always c-contiguous by copying if they are not.
1 parent 724e620 commit 4a731c0

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

python/nutpie/compile_pymc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def with_data(self, **updates):
131131
if name not in shared_data:
132132
raise KeyError(f"Unknown shared variable: {name}")
133133
old_val = shared_data[name]
134-
new_val = np.asarray(new_val, dtype=old_val.dtype).copy()
134+
new_val = np.copy(new_val.astype(old_val.dtype), order="C")
135135
new_val.flags.writeable = False
136136
if old_val.ndim != new_val.ndim:
137137
raise ValueError(
@@ -256,7 +256,7 @@ def _compile_pymc_model_numba(
256256
for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]:
257257
if val.name in shared_data and val not in seen:
258258
raise ValueError(f"Shared variables must have unique names: {val.name}")
259-
shared_data[val.name] = val.get_value()
259+
shared_data[val.name] = np.copy(val.get_value(), order="C")
260260
shared_vars[val.name] = val
261261
seen.add(val)
262262

tests/test_pymc.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,42 @@ def test_pymc_model(backend, gradient_backend):
3131
trace.posterior.a # noqa: B018
3232

3333

34+
@pytest.mark.pymc
35+
def test_order_shared():
36+
a_val = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
37+
with pm.Model() as model:
38+
a = pm.Data("a", np.copy(a_val, order="C"))
39+
b = pm.Normal("b", shape=(2, 5))
40+
pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1))
41+
42+
compiled = nutpie.compile_pymc_model(model, backend="numba")
43+
trace = nutpie.sample(compiled)
44+
np.testing.assert_allclose(
45+
(
46+
trace.posterior.b.values[:, :, :, :, None]
47+
* a_val[None, None, :, None, :]
48+
).sum(-1),
49+
trace.posterior.c.values,
50+
)
51+
52+
with pm.Model() as model:
53+
a = pm.Data(
54+
"a", np.copy(a_val, order="F")
55+
)
56+
b = pm.Normal("b", shape=(2, 5))
57+
pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1))
58+
59+
compiled = nutpie.compile_pymc_model(model, backend="numba")
60+
trace = nutpie.sample(compiled)
61+
np.testing.assert_allclose(
62+
(
63+
trace.posterior.b.values[:, :, :, :, None]
64+
* a_val[None, None, :, None, :]
65+
).sum(-1),
66+
trace.posterior.c.values,
67+
)
68+
69+
3470
@pytest.mark.pymc
3571
@parameterize_backends
3672
def test_low_rank(backend, gradient_backend):

0 commit comments

Comments
 (0)