Skip to content

Commit 1a06d50

Browse files
committed
Raise NotImplementedError for multivariate CustomDists
1 parent 2303bf9 commit 1a06d50

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

pymc/distributions/distribution.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,11 @@ def dist(
488488
class_name: str = "CustomDist",
489489
**kwargs,
490490
):
491+
if ndim_supp > 0:
492+
raise NotImplementedError(
493+
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
494+
)
495+
491496
dist_params = [as_tensor_variable(param) for param in dist_params]
492497

493498
# Assume scalar ndims_params

tests/distributions/test_distribution.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ def test_custom_dist_without_random(self):
217217
with pytest.raises(NotImplementedError):
218218
pm.sample_posterior_predictive(idata, model=model)
219219

220+
@pytest.mark.xfail(
221+
NotImplementedError,
222+
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
223+
)
220224
@pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
221225
def test_custom_dist_with_random_multivariate(self, size):
222226
supp_shape = 5
@@ -264,6 +268,10 @@ def test_custom_dist_old_api_error(self):
264268
):
265269
CustomDist("a", lambda x: x)
266270

271+
@pytest.mark.xfail(
272+
NotImplementedError,
273+
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
274+
)
267275
@pytest.mark.parametrize("size", [None, (), (2,)], ids=str)
268276
def test_custom_dist_multivariate_logp(self, size):
269277
supp_shape = 5
@@ -314,6 +322,10 @@ def density_moment(rv, size, mu):
314322
assert evaled_moment.shape == to_tuple(size)
315323
assert np.all(evaled_moment == mu_val)
316324

325+
@pytest.mark.xfail(
326+
NotImplementedError,
327+
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
328+
)
317329
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
318330
def test_custom_dist_custom_moment_multivariate(self, size):
319331
def density_moment(rv, size, mu):
@@ -328,6 +340,10 @@ def density_moment(rv, size, mu):
328340
assert evaled_moment.shape == to_tuple(size) + (5,)
329341
assert np.all(evaled_moment == mu_val)
330342

343+
@pytest.mark.xfail(
344+
NotImplementedError,
345+
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
346+
)
331347
@pytest.mark.parametrize(
332348
"with_random, size",
333349
[

0 commit comments

Comments
 (0)