Skip to content

Commit 6ecc6d7

Browse files
Automatically resize independently mutable dims via set_data
1 parent 102522e commit 6ecc6d7

File tree

2 files changed

+98
-32
lines changed

2 files changed

+98
-32
lines changed

pymc/model.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,28 +1199,29 @@ def set_data(
11991199
# definitely lead to shape problems.
12001200
raise ShapeError(
12011201
f"Resizing dimension '{dname}' is impossible, because "
1202-
"a 'TensorConstant' stores its length. To be able "
1202+
"a `TensorConstant` stores its length. To be able "
12031203
"to change the dimension length, pass `mutable=True` when "
12041204
"registering the dimension via `model.add_coord`, "
12051205
"or define it via a `pm.MutableData` variable."
12061206
)
1207-
elif isinstance(length_tensor, ScalarSharedVariable):
1208-
# The dimension is mutable, but was defined without being linked
1209-
# to a shared variable. This is allowed, but slightly dangerous.
1210-
warnings.warn(
1211-
f"You are resizing a variable with dimension '{dname}' which was initialized"
1212-
" as a mutable dimension and is not linked to the `MutableData` variable."
1213-
" Remember to update the dimension length by calling "
1214-
f"`Model.set_dim({dname}, new_length={new_length})` manually,"
1215-
" preferably _before_ updating `MutableData` variables that use this dimension.",
1216-
ShapeWarning,
1217-
stacklevel=2,
1218-
)
1219-
else:
1220-
# The dimension was created from another model variable.
1221-
# If that was a non-mutable variable, there will definitely be shape problems.
1207+
elif length_tensor.owner is not None:
1208+
# The dimension was created from a model variable.
12221209
length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
1223-
if not isinstance(length_belongs_to, SharedVariable):
1210+
if length_belongs_to is shared_object:
1211+
# No surprise it's changing.
1212+
pass
1213+
elif isinstance(length_belongs_to, SharedVariable):
1214+
# The dimension is mutable through a SharedVariable other than the one being modified.
1215+
# But the other variable was not yet re-sized! Warn the user to do that!
1216+
warnings.warn(
1217+
f"You are resizing a variable with dimension '{dname}' which was initialized "
1218+
f"as a mutable dimension by another variable ('{length_belongs_to}')."
1219+
" Remember to update that variable with the correct shape to avoid shape issues.",
1220+
ShapeWarning,
1221+
stacklevel=2,
1222+
)
1223+
else:
1224+
# The dimension is immutable.
12241225
raise ShapeError(
12251226
f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
12261227
f"because the dimension was initialized from '{length_belongs_to}' which is not a shared variable. "
@@ -1230,8 +1231,9 @@ def set_data(
12301231
expected=old_length,
12311232
)
12321233
if isinstance(length_tensor, ScalarSharedVariable):
1233-
# Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
1234-
length_tensor.set_value(new_length)
1234+
# The dimension is mutable, but was defined without being linked
1235+
# to a shared variable. This is allowed, but a little less robust.
1236+
self.set_dim(dname, new_length, coord_values=new_coords)
12351237

12361238
if new_coords is not None:
12371239
# Update the registered coord values (also if they were None)

pymc/tests/test_model.py

Lines changed: 77 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import unittest
15+
import warnings
1516

1617
from functools import reduce
1718

@@ -704,25 +705,45 @@ def test_nested_model_coords():
704705
assert set(m2.RV_dims) < set(m1.RV_dims)
705706

706707

707-
def test_shapeerror_from_resize_immutable_dims():
708+
def test_shapeerror_from_set_data_dimensionality():
709+
with pm.Model() as pmodel:
710+
pm.MutableData("m", np.ones((3,)), dims="one")
711+
with pytest.raises(ValueError, match="must have 1 dimensions"):
712+
pmodel.set_data("m", np.ones((3, 4)))
713+
714+
715+
def test_shapeerror_from_resize_immutable_dim_from_RV():
708716
"""
709717
Trying to resize an immutable dimension should raise a ShapeError.
710718
Even if the variable being updated is a SharedVariable and has other
711719
dimensions that are mutable.
712720
"""
713721
with pm.Model() as pmodel:
714-
a = pm.Normal("a", mu=[1, 2, 3], dims="fixed")
722+
pm.Normal("a", mu=[1, 2, 3], dims="fixed")
723+
assert isinstance(pmodel.dim_lengths["fixed"], TensorVariable)
715724

716-
m = pm.MutableData("m", [[1, 2, 3]], dims=("one", "fixed"))
725+
pm.MutableData("m", [[1, 2, 3]], dims=("one", "fixed"))
717726

718727
# This is fine because the "fixed" dim is not resized
719-
pm.set_data({"m": [[1, 2, 3], [3, 4, 5]]})
728+
pmodel.set_data("m", [[1, 2, 3], [3, 4, 5]])
720729

721730
with pytest.raises(ShapeError, match="was initialized from 'a'"):
722-
# Can't work because the "fixed" dimension is linked to a constant shape:
731+
# Can't work because the "fixed" dimension is linked to a
732+
# TensorVariable with constant shape.
723733
# Note that the new data tries to change both dimensions
724-
with pmodel:
725-
pm.set_data({"m": [[1, 2], [3, 4]]})
734+
pmodel.set_data("m", [[1, 2], [3, 4]])
735+
736+
737+
def test_shapeerror_from_resize_immutable_dim_from_coords():
738+
with pm.Model(coords={"immutable": [1, 2]}) as pmodel:
739+
assert isinstance(pmodel.dim_lengths["immutable"], TensorConstant)
740+
pm.MutableData("m", [1, 2], dims="immutable")
741+
# Data can be changed
742+
pmodel.set_data("m", [3, 4])
743+
744+
with pytest.raises(ShapeError, match="`TensorConstant` stores its length"):
745+
# But the length is linked to a TensorConstant
746+
pmodel.set_data("m", [1, 2, 3], coords=dict(immutable=[1, 2, 3]))
726747

727748

728749
def test_valueerror_from_resize_without_coords_update():
@@ -798,22 +819,65 @@ def test_set_dim_with_coords():
798819
assert pmodel.coords["mdim"] == ("A", "B", "C")
799820

800821

801-
def test_set_data_warns_resize_mutable_dim():
822+
def test_set_data_indirect_resize():
802823
with pm.Model() as pmodel:
803824
pmodel.add_coord("mdim", mutable=True, length=2)
804825
pm.MutableData("mdata", [1, 2], dims="mdim")
805826

806827
# First resize the dimension.
807828
pmodel.dim_lengths["mdim"].set_value(3)
808829
# Then change the data.
809-
pmodel.set_data("mdata", [1, 2, 3])
830+
with warnings.catch_warnings():
831+
warnings.simplefilter("error")
832+
pmodel.set_data("mdata", [1, 2, 3])
810833

811834
# Now the other way around.
812-
# Because the dimension doesn't depend on the data variable,
813-
# a warning should be emitted.
814-
with pytest.warns(ShapeWarning, match="update the dimension length"):
835+
with warnings.catch_warnings():
836+
warnings.simplefilter("error")
815837
pmodel.set_data("mdata", [1, 2, 3, 4])
816-
pass
838+
839+
840+
def test_set_data_warns_on_resize_of_dims_defined_by_other_mutabledata():
841+
with pm.Model() as pmodel:
842+
pm.MutableData("m1", [1, 2], dims="mutable")
843+
pm.MutableData("m2", [3, 4], dims="mutable")
844+
845+
# Resizing the non-defining variable first gives a warning
846+
with pytest.warns(ShapeWarning, match="by another variable"):
847+
pmodel.set_data("m2", [4, 5, 6])
848+
pmodel.set_data("m1", [1, 2, 3])
849+
850+
# Resizing the defining variable first is silent
851+
with warnings.catch_warnings():
852+
warnings.simplefilter("error")
853+
pmodel.set_data("m1", [1, 2])
854+
pmodel.set_data("m2", [3, 4])
855+
856+
857+
def test_set_data_indirect_resize_with_coords():
858+
with pm.Model() as pmodel:
859+
pmodel.add_coord("mdim", ["A", "B"], mutable=True, length=2)
860+
pm.MutableData("mdata", [1, 2], dims="mdim")
861+
862+
assert pmodel.coords["mdim"] == ("A", "B")
863+
864+
# First resize the dimension.
865+
pmodel.set_dim("mdim", 3, ["A", "B", "C"])
866+
assert pmodel.coords["mdim"] == ("A", "B", "C")
867+
# Then change the data.
868+
with warnings.catch_warnings():
869+
warnings.simplefilter("error")
870+
pmodel.set_data("mdata", [1, 2, 3])
871+
872+
# Now the other way around.
873+
with warnings.catch_warnings():
874+
warnings.simplefilter("error")
875+
pmodel.set_data("mdata", [1, 2, 3, 4], coords=dict(mdim=["A", "B", "C", "D"]))
876+
assert pmodel.coords["mdim"] == ("A", "B", "C", "D")
877+
878+
# This time with incorrectly sized coord values
879+
with pytest.raises(ShapeError, match="new coordinate values"):
880+
pmodel.set_data("mdata", [1, 2], coords=dict(mdim=[1, 2, 3]))
817881

818882

819883
@pytest.mark.parametrize("jacobian", [True, False])

0 commit comments

Comments
 (0)