Commit c38eea0

Fix shape inference for multivariate random Ops
When `size` is not provided, the batch shapes of the parameters were being broadcast twice. The second broadcast was wrong because it mixed the static shapes of the original parameters with the (potentially larger) shapes of the already-broadcast parameters.
1 parent 205da7f commit c38eea0
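The sketch below (plain NumPy, not PyTensor's actual implementation) illustrates the shape composition the fix restores when `size` is not given: broadcast only the parameters' batch dimensions, once, and then append the support shape. The parameter shapes mirror one of the new test cases.

import numpy as np

# Core (support) ndim of each parameter: a vector-valued and a matrix-valued parameter.
ndims_params = (1, 2)
param_shapes = [(10, 1, 3), (2, 3, 3)]

# Batch dims are everything left of each parameter's core dims; broadcast them once.
batch_shapes = [s[: len(s) - nd] for s, nd in zip(param_shapes, ndims_params)]
batch_shape = np.broadcast_shapes(*batch_shapes)  # (10, 2)

# The support shape comes from the first parameter's trailing core dim.
supp_shape = (param_shapes[0][-1],)  # (3,)

print(batch_shape + supp_shape)  # (10, 2, 3)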

File tree

2 files changed: +52 −28 lines changed

pytensor/tensor/random/op.py

Lines changed: 16 additions & 28 deletions
@@ -20,11 +20,7 @@
     infer_static_shape,
 )
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
-from pytensor.tensor.random.utils import (
-    broadcast_params,
-    normalize_size_param,
-    params_broadcast_shapes,
-)
+from pytensor.tensor.random.utils import broadcast_params, normalize_size_param
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.type import TensorType, all_dtypes
 from pytensor.tensor.type_other import NoneConst
@@ -156,6 +152,13 @@ def _infer_shape(
 
         from pytensor.tensor.extra_ops import broadcast_shape_iter
 
+        if self.ndim_supp == 0:
+            supp_shape = ()
+        else:
+            supp_shape = tuple(
+                self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
+            )
+
         size_len = get_vector_length(size)
 
         if size_len > 0:
@@ -171,30 +174,22 @@ def _infer_shape(
                         f"Size length must be 0 or >= {param_batched_dims}"
                     )
 
-            if self.ndim_supp == 0:
-                return size
-            else:
-                supp_shape = self._supp_shape_from_params(
-                    dist_params, param_shapes=param_shapes
-                )
-                return tuple(size) + tuple(supp_shape)
-
-        # Broadcast the parameters
-        param_shapes = params_broadcast_shapes(
-            param_shapes or [shape_tuple(p) for p in dist_params],
-            self.ndims_params,
-        )
+            return tuple(size) + supp_shape
+
+        # Size was not provided, we must infer it from the shape of the parameters
+        if param_shapes is None:
+            param_shapes = [shape_tuple(p) for p in dist_params]
 
         def extract_batch_shape(p, ps, n):
             shape = tuple(ps)
 
             if n == 0:
                 return shape
 
-            batch_shape = [
+            batch_shape = tuple(
                 s if not b else constant(1, "int64")
                 for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
-            ]
+            )
             return batch_shape
 
         # These are versions of our actual parameters with the anticipated
@@ -218,15 +213,8 @@ def extract_batch_shape(p, ps, n):
             # Distribution has no parameters
             batch_shape = ()
 
-        if self.ndim_supp == 0:
-            supp_shape = ()
-        else:
-            supp_shape = self._supp_shape_from_params(
-                dist_params,
-                param_shapes=param_shapes,
-            )
+        shape = batch_shape + supp_shape
 
-        shape = tuple(batch_shape) + tuple(supp_shape)
         if not shape:
             shape = constant([], dtype="int64")
 
tests/tensor/random/test_op.py

Lines changed: 36 additions & 0 deletions
@@ -206,6 +206,42 @@ def test_RandomVariable_incompatible_size():
         rv_op(np.zeros((2, 4, 3)), 1, size=(4,))
 
 
+class MultivariateRandomVariable(RandomVariable):
+    name = "MultivariateRandomVariable"
+    ndim_supp = 1
+    ndims_params = (1, 2)
+    dtype = "floatX"
+
+    def _supp_shape_from_params(self, dist_params, param_shapes=None):
+        return [dist_params[0].shape[-1]]
+
+
+@config.change_flags(compute_test_value="off")
+def test_multivariate_rv_infer_static_shape():
+    """Test that infer shape for multivariate random variable works when a parameter must be broadcasted."""
+    mv_op = MultivariateRandomVariable()
+
+    param1 = tensor(shape=(10, 2, 3))
+    param2 = tensor(shape=(10, 2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(2, 3))
+    param2 = tensor(shape=(10, 2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(10, 2, 3))
+    param2 = tensor(shape=(2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(10, 1, 3))
+    param2 = tensor(shape=(2, 3, 3))
+    assert mv_op(param1, param2).type.shape == (10, 2, 3)
+
+    param1 = tensor(shape=(2, 3))
+    param2 = tensor(shape=(2, 3, 3))
+    assert mv_op(param1, param2, size=(10, 2)).type.shape == (10, 2, 3)
+
+
 def test_vectorize_node():
     vec = tensor(shape=(None,))
     vec.tag.test_value = [0, 0, 0]
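
As a usage-level sketch that is not part of this commit: with the fix, a concrete multivariate Op whose parameters need broadcasting should report the correct static output shape. This assumes multivariate_normal from pytensor.tensor.random.basic (ndim_supp = 1, parameter core ndims matching the test above) and reuses the (10, 1, 3) / (2, 3, 3) shapes from one of the new test cases.

from pytensor.tensor import tensor
from pytensor.tensor.random.basic import multivariate_normal

# Symbolic parameters with known static shapes but no data.
mean = tensor(shape=(10, 1, 3))
cov = tensor(shape=(2, 3, 3))

draw = multivariate_normal(mean, cov)
# Batch dims (10, 1) and (2,) broadcast to (10, 2); the support dim 3 is appended.
print(draw.type.shape)  # expected: (10, 2, 3)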
