Skip to content

Commit abad0a4

Browse files
Luke LB and michaelosthege committed
Require pm.MutableData or add_coord(mutable=True) for mutable dims
This increases the safety of working with resizable dimensions. Dims created via `pm.Model(coords=...)` coordinates are now immutable. Only dimension lengths that are created anew from `pm.MutableData`, `pm.Data(mutable=True)`, or the underlying `add_coord(mutable=True)` will become shared variables. Co-authored-by: Michael Osthege <[email protected]>
1 parent e97ad4e commit abad0a4

File tree

4 files changed

+51
-11
lines changed

4 files changed

+51
-11
lines changed

pymc/data.py

+1
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,7 @@ def Data(
704704
# Note: Coordinate values can't be taken from
705705
# the value, because it could be N-dimensional.
706706
values=coords.get(dname, None),
707+
mutable=mutable,
707708
length=xshape[d],
708709
)
709710

pymc/model.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,7 @@ def add_coord(
10711071
self,
10721072
name: str,
10731073
values: Optional[Sequence] = None,
1074+
mutable: bool = False,
10741075
*,
10751076
length: Optional[Union[int, Variable]] = None,
10761077
):
@@ -1084,9 +1085,12 @@ def add_coord(
10841085
values : optional, array-like
10851086
Coordinate values or ``None`` (for auto-numbering).
10861087
If ``None`` is passed, a ``length`` must be specified.
1088+
mutable : bool
1089+
Whether the created dimension should be resizable.
1090+
Default is False.
10871091
length : optional, scalar
10881092
A scalar of the dimension's length.
1089-
Defaults to ``aesara.shared(len(values))``.
1093+
Defaults to ``aesara.shared(len(values))`` if ``mutable`` is True, otherwise ``aesara.tensor.constant(len(values))``.
10901094
"""
10911095
if name in {"draw", "chain", "__sample__"}:
10921096
raise ValueError(
@@ -1111,8 +1115,11 @@ def add_coord(
11111115
if not np.array_equal(values, self.coords[name]):
11121116
raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
11131117
else:
1118+
if mutable:
1119+
self._dim_lengths[name] = length or aesara.shared(len(values))
1120+
else:
1121+
self._dim_lengths[name] = length or aesara.tensor.constant(len(values))
11141122
self._coords[name] = values
1115-
self._dim_lengths[name] = length or aesara.shared(len(values))
11161123

11171124
def add_coords(
11181125
self,
@@ -1192,9 +1199,10 @@ def set_data(
11921199
if isinstance(length_tensor, TensorConstant):
11931200
raise ShapeError(
11941201
f"Resizing dimension '{dname}' is impossible, because "
1195-
f"a 'TensorConstant' stores its length. To be able "
1196-
f"to change the dimension length, 'fixed' in "
1197-
f"'model.add_coord' must be set to `False`."
1202+
"a 'TensorConstant' stores its length. To be able "
1203+
"to change the dimension length, pass `mutable=True` when "
1204+
"registering the dimension via `model.add_coord`, "
1205+
"or define it via a `pm.MutableData` variable."
11981206
)
11991207
else:
12001208
length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]

pymc/tests/test_data_container.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from aesara import shared
2121
from aesara.compile.sharedvalue import SharedVariable
22-
from aesara.tensor.sharedvar import ScalarSharedVariable
22+
from aesara.tensor import TensorConstant
2323
from aesara.tensor.var import TensorVariable
2424

2525
import pymc as pm
@@ -315,21 +315,22 @@ def test_explicit_coords(self):
315315
}
316316
# pass coordinates explicitly, use numpy array in Data container
317317
with pm.Model(coords=coords) as pmodel:
318+
# Dims created from coords are constant by default
319+
assert isinstance(pmodel.dim_lengths["rows"], TensorConstant)
320+
assert isinstance(pmodel.dim_lengths["columns"], TensorConstant)
318321
pm.MutableData("observations", data, dims=("rows", "columns"))
319-
# new data with same shape
322+
# new data with same (!) shape
320323
pm.set_data({"observations": data + 1})
321-
# new data with same shape and coords
324+
# new data with same (!) shape and coords
322325
pm.set_data({"observations": data}, coords=coords)
323326
assert "rows" in pmodel.coords
324327
assert pmodel.coords["rows"] == ("R1", "R2", "R3", "R4", "R5")
325328
assert "rows" in pmodel.dim_lengths
326-
assert isinstance(pmodel.dim_lengths["rows"], ScalarSharedVariable)
327329
assert pmodel.dim_lengths["rows"].eval() == 5
328330
assert "columns" in pmodel.coords
329331
assert pmodel.coords["columns"] == ("C1", "C2", "C3", "C4", "C5", "C6", "C7")
330332
assert pmodel.RV_dims == {"observations": ("rows", "columns")}
331333
assert "columns" in pmodel.dim_lengths
332-
assert isinstance(pmodel.dim_lengths["columns"], ScalarSharedVariable)
333334
assert pmodel.dim_lengths["columns"].eval() == 7
334335

335336
def test_set_coords_through_pmdata(self):
@@ -354,6 +355,7 @@ def test_symbolic_coords(self):
354355
Their lengths are then automatically linked to the corresponding Tensor dimension.
355356
"""
356357
with pm.Model() as pmodel:
358+
# Dims created from MutableData are TensorVariables linked to the SharedVariable.shape
357359
intensity = pm.MutableData("intensity", np.ones((2, 3)), dims=("row", "column"))
358360
assert "row" in pmodel.dim_lengths
359361
assert "column" in pmodel.dim_lengths

pymc/tests/test_model.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import scipy.sparse as sps
2828
import scipy.stats as st
2929

30+
from aesara.tensor import TensorVariable
3031
from aesara.tensor.random.op import RandomVariable
32+
from aesara.tensor.sharedvar import ScalarSharedVariable
3133
from aesara.tensor.var import TensorConstant
3234

3335
import pymc as pm
@@ -729,13 +731,40 @@ def test_valueerror_from_resize_without_coords_update():
729731
without passing new coords raises a ValueError.
730732
"""
731733
with pm.Model() as pmodel:
732-
pmodel.add_coord("shared", [1, 2, 3])
734+
pmodel.add_coord("shared", [1, 2, 3], mutable=True)
733735
pm.MutableData("m", [1, 2, 3], dims=("shared"))
734736
with pytest.raises(ValueError, match="'m' variable already had 3"):
735737
# tries to resize m but without passing coords so raise ValueError
736738
pm.set_data({"m": [1, 2, 3, 4]})
737739

738740

741+
def test_coords_and_constantdata_create_immutable_dims():
742+
"""
743+
When created from `pm.Model(coords=...)` or `pm.ConstantData`
744+
a dimension should not be resizable.
745+
"""
746+
with pm.Model(coords={"group": ["A", "B"]}) as m:
747+
x = pm.ConstantData("x", [0], dims="feature")
748+
y = pm.Normal("y", x, 1, dims=("group", "feature"))
749+
assert isinstance(m._dim_lengths["feature"], TensorConstant)
750+
assert isinstance(m._dim_lengths["group"], TensorConstant)
751+
assert x.eval().shape == (1,)
752+
assert y.eval().shape == (2, 1)
753+
754+
755+
def test_add_coord_mutable_kwarg():
756+
"""
757+
Checks resulting tensor type depending on mutable kwarg in add_coord.
758+
"""
759+
with pm.Model() as m:
760+
m.add_coord("fixed", values=[1], mutable=False)
761+
m.add_coord("mutable1", values=[1, 2], mutable=True)
762+
assert isinstance(m._dim_lengths["fixed"], TensorConstant)
763+
assert isinstance(m._dim_lengths["mutable1"], ScalarSharedVariable)
764+
pm.MutableData("mdata", np.ones((1, 2, 3)), dims=("fixed", "mutable1", "mutable2"))
765+
assert isinstance(m._dim_lengths["mutable2"], TensorVariable)
766+
767+
739768
@pytest.mark.parametrize("jacobian", [True, False])
740769
def test_model_logp(jacobian):
741770
with pm.Model() as m:

0 commit comments

Comments
 (0)