diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 7cfe121..c714146 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -131,7 +131,7 @@ def with_data(self, **updates): if name not in shared_data: raise KeyError(f"Unknown shared variable: {name}") old_val = shared_data[name] - new_val = np.asarray(new_val, dtype=old_val.dtype).copy() + new_val = np.array(new_val, dtype=old_val.dtype, order="C", copy=True) new_val.flags.writeable = False if old_val.ndim != new_val.ndim: raise ValueError( @@ -256,7 +256,7 @@ def _compile_pymc_model_numba( for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]: if val.name in shared_data and val not in seen: raise ValueError(f"Shared variables must have unique names: {val.name}") - shared_data[val.name] = val.get_value() + shared_data[val.name] = np.array(val.get_value(), order="C", copy=True) shared_vars[val.name] = val seen.add(val) diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 51e2159..938f36c 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -31,6 +31,38 @@ def test_pymc_model(backend, gradient_backend): trace.posterior.a # noqa: B018 +@pytest.mark.pymc +def test_order_shared(): + a_val = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) + with pm.Model() as model: + a = pm.Data("a", np.copy(a_val, order="C")) + b = pm.Normal("b", shape=(2, 5)) + pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1)) + + compiled = nutpie.compile_pymc_model(model, backend="numba") + trace = nutpie.sample(compiled) + np.testing.assert_allclose( + ( + trace.posterior.b.values[:, :, :, :, None] * a_val[None, None, :, None, :] + ).sum(-1), + trace.posterior.c.values, + ) + + with pm.Model() as model: + a = pm.Data("a", np.copy(a_val, order="F")) + b = pm.Normal("b", shape=(2, 5)) + pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1)) + + compiled = nutpie.compile_pymc_model(model, backend="numba") + trace = nutpie.sample(compiled) + np.testing.assert_allclose( + ( + trace.posterior.b.values[:, :, :, :, None] * a_val[None, None, :, None, :] + ).sum(-1), + trace.posterior.c.values, + ) + + @pytest.mark.pymc @parameterize_backends def test_low_rank(backend, gradient_backend):