Skip to content

Commit c937001

Browse files
committed
👌 use broadcast_dist_samples_shape
1 parent 8c12894 commit c937001

File tree

4 files changed

+114
-33
lines changed

4 files changed

+114
-33
lines changed

docs/source/api/shape_utils.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@ This module introduces functions that are made aware of the requested `size_tupl
1414
:toctree: generated/
1515

1616
to_tuple
17+
broadcast_dist_samples_shape
1718
rv_size_is_none
1819
change_dist_size

pymc/distributions/multivariate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
)
6060
from pymc.distributions.shape_utils import (
6161
_change_dist_size,
62+
broadcast_dist_samples_shape,
6263
change_dist_size,
6364
get_support_shape,
6465
rv_size_is_none,
@@ -1651,8 +1652,8 @@ def rng_fn(cls, rng, mu, rowchol, colchol, size=None):
16511652

16521653
# Broadcasting all parameters
16531654
shapes = [mu.shape, output_shape]
1654-
sp_shapes = [s[len(size) :] if size == s[: min([len(size), len(s)])] else s for s in shapes]
1655-
mu = np.broadcast_to(mu, shape=np.broadcast_shapes(*sp_shapes))
1655+
broadcastable_shape = broadcast_dist_samples_shape(shapes, size=size)
1656+
mu = np.broadcast_to(mu, shape=broadcastable_shape)
16561657
rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
16571658

16581659
colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])

pymc/distributions/shape_utils.py

Lines changed: 72 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from pymc.pytensorf import convert_observed_data
3939

4040
__all__ = [
41+
"broadcast_dist_samples_shape",
4142
"to_tuple",
4243
"rv_size_is_none",
4344
"change_dist_size",
@@ -86,45 +87,85 @@ def _check_shape_type(shape):
8687
return tuple(out)
8788

8889

89-
def shapes_broadcasting(*args, raise_exception=False):
90-
"""Return the shape resulting from broadcasting multiple shapes.
91-
Represents numpy's broadcasting rules.
90+
def broadcast_dist_samples_shape(shapes, size=None):
91+
"""Apply shape broadcasting to shape tuples, assuming that the shapes
92+
correspond to draws from random variables, with the `size` tuple possibly
93+
prepended to them. The `size` prepend is ignored when deciding whether the
94+
supplied `shapes` can broadcast or not. It is prepended to the resulting
95+
broadcast shape, if any of the shape tuples had the `size` prepend.
9296
9397
Parameters
9498
----------
95-
*args: array-like of int
96-
Tuples or arrays or lists representing the shapes of arrays to be
97-
broadcast.
98-
raise_exception: bool (optional)
99-
Controls whether to raise an exception or simply return `None` if
100-
the broadcasting fails.
99+
shapes: Iterable of tuples holding the shapes of the distribution samples
100+
size: None, int or tuple (optional)
101+
size of the sample set requested.
101102
102103
Returns
103104
-------
104-
Resulting shape. If broadcasting is not possible and `raise_exception` is
105-
False, then `None` is returned. If `raise_exception` is `True`, a
106-
`ValueError` is raised.
105+
tuple of the resulting shape
106+
107+
Examples
108+
--------
109+
.. code-block:: python
110+
size = 100
111+
shape0 = (size,)
112+
shape1 = (size, 5)
113+
shape2 = (size, 4, 5)
114+
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
115+
size=size)
116+
assert out == (size, 4, 5)
117+
.. code-block:: python
118+
size = 100
119+
shape0 = (size,)
120+
shape1 = (5,)
121+
shape2 = (4, 5)
122+
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
123+
size=size)
124+
assert out == (size, 4, 5)
125+
.. code-block:: python
126+
size = 100
127+
shape0 = (1,)
128+
shape1 = (5,)
129+
shape2 = (4, 5)
130+
out = broadcast_dist_samples_shape([shape0, shape1, shape2],
131+
size=size)
132+
assert out == (4, 5)
107133
"""
108-
x = list(_check_shape_type(args[0])) if args else ()
109-
for arg in args[1:]:
110-
y = list(_check_shape_type(arg))
111-
if len(x) < len(y):
112-
x, y = y, x
113-
if len(y) > 0:
114-
x[-len(y) :] = [
115-
j if i == 1 else i if j == 1 else i if i == j else 0
116-
for i, j in zip(x[-len(y) :], y)
117-
]
118-
if not all(x):
119-
if raise_exception:
120-
raise ValueError(
121-
"Supplied shapes {} do not broadcast together".format(
122-
", ".join([f"{a}" for a in args])
123-
)
134+
if size is None:
135+
broadcasted_shape = np.broadcast_shapes(*shapes)
136+
if broadcasted_shape is None:
137+
raise ValueError(
138+
"Cannot broadcast provided shapes {} given size: {}".format(
139+
", ".join([f"{s}" for s in shapes]), size
124140
)
125-
else:
126-
return None
127-
return tuple(x)
141+
)
142+
return broadcasted_shape
143+
shapes = [_check_shape_type(s) for s in shapes]
144+
_size = to_tuple(size)
145+
# samples shapes without the size prepend
146+
sp_shapes = [s[len(_size) :] if _size == s[: min([len(_size), len(s)])] else s for s in shapes]
147+
try:
148+
broadcast_shape = np.broadcast_shapes(*sp_shapes)
149+
except ValueError:
150+
raise ValueError(
151+
"Cannot broadcast provided shapes {} given size: {}".format(
152+
", ".join([f"{s}" for s in shapes]), size
153+
)
154+
)
155+
broadcastable_shapes = []
156+
for shape, sp_shape in zip(shapes, sp_shapes):
157+
if _size == shape[: len(_size)]:
158+
# If size prepends the shape, then we have to add broadcasting axis
159+
# in the middle
160+
p_shape = (
161+
shape[: len(_size)]
162+
+ (1,) * (len(broadcast_shape) - len(sp_shape))
163+
+ shape[len(_size) :]
164+
)
165+
else:
166+
p_shape = shape
167+
broadcastable_shapes.append(p_shape)
168+
return np.broadcast_shapes(*broadcastable_shapes)
128169

129170

130171
# User-provided can be lazily specified as scalars

tests/distributions/test_shape_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from pymc import ShapeError
3131
from pymc.distributions.shape_utils import (
32+
broadcast_dist_samples_shape,
3233
change_dist_size,
3334
convert_dims,
3435
convert_shape,
@@ -85,6 +86,43 @@ def fixture_exception_handling(request):
8586
return request.param
8687

8788

89+
class TestShapesBroadcasting:
90+
def test_broadcasting(self, fixture_shapes):
91+
shapes = fixture_shapes
92+
try:
93+
expected_out = np.broadcast(*(np.empty(s) for s in shapes)).shape
94+
except ValueError:
95+
expected_out = None
96+
if expected_out is None:
97+
with pytest.raises(ValueError):
98+
np.broadcast_shapes(*shapes)
99+
else:
100+
out = np.broadcast_shapes(*shapes)
101+
assert out == expected_out
102+
103+
def test_broadcast_dist_samples_shape(self, fixture_sizes, fixture_shapes):
104+
size = fixture_sizes
105+
shapes = fixture_shapes
106+
size_ = to_tuple(size)
107+
shapes_ = [
108+
s if s[: min([len(size_), len(s)])] != size_ else s[len(size_) :] for s in shapes
109+
]
110+
try:
111+
expected_out = np.broadcast(*(np.empty(s) for s in shapes_)).shape
112+
except ValueError:
113+
expected_out = None
114+
if expected_out is not None and any(
115+
s[: min([len(size_), len(s)])] == size_ for s in shapes
116+
):
117+
expected_out = size_ + expected_out
118+
if expected_out is None:
119+
with pytest.raises(ValueError):
120+
broadcast_dist_samples_shape(shapes, size=size)
121+
else:
122+
out = broadcast_dist_samples_shape(shapes, size=size)
123+
assert out == expected_out
124+
125+
88126
class TestSizeShapeDimsObserved:
89127
@pytest.mark.parametrize("param_shape", [(), (2,)])
90128
@pytest.mark.parametrize("batch_shape", [(), (3,)])

0 commit comments

Comments
 (0)