More safety nets for resizing mutable dimensions #5817

Merged
merged 4 commits into from
Jun 3, 2022

Conversation

michaelosthege
Member

  • Adds a Model.set_dim method for resizing dimensions that were created by add_coord(..., mutable=True).
  • Changes Model.set_data to anticipate that data-induced resizing can target dimensions that were created through add_coord(..., mutable=True) which are not symbolically linked to the data variables.

Closes #5812.
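
A minimal, library-free sketch of the resizing rule described above, assuming a dimension may only be resized if it was registered as mutable. `ToyModel` and its attributes are hypothetical stand-ins and not PyMC's implementation; only the method names `add_coord` and `set_dim` come from this PR.

```python
class ToyModel:
    """Hypothetical stand-in for a model that tracks named dimensions."""

    def __init__(self):
        self.coords = {}       # dimension name -> tuple of coordinate values (or None)
        self.dim_lengths = {}  # dimension name -> current length
        self.mutable = set()   # names of dimensions that may be resized

    def add_coord(self, name, values=None, mutable=False, *, length=None):
        if name in self.dim_lengths:
            raise ValueError(f"Dimension {name!r} already exists.")
        if values is None and length is None:
            raise ValueError("Provide either coordinate values or a length.")
        if values is not None:
            values = tuple(values)
            length = len(values)
        self.coords[name] = values
        self.dim_lengths[name] = length
        if mutable:
            self.mutable.add(name)

    def set_dim(self, name, new_length, coord_values=None):
        # Safety net 1: only dimensions registered as mutable may change size.
        if name not in self.mutable:
            raise ValueError(f"Dimension {name!r} is not mutable and cannot be resized.")
        if coord_values is not None:
            coord_values = tuple(coord_values)
            # Safety net 2: new coordinate values must match the new length.
            if len(coord_values) != new_length:
                raise ValueError("Length of coord_values does not match new_length.")
        elif self.coords[name] is not None:
            # Safety net 3: stale coordinate values must not survive a resize.
            raise ValueError(f"Dimension {name!r} has coordinate values; pass new ones.")
        self.dim_lengths[name] = new_length
        self.coords[name] = coord_values


model = ToyModel()
model.add_coord("city", values=["Atown", "Bvil"], mutable=True)
model.set_dim("city", 3, coord_values=["Atown", "Bvil", "Cburg"])
print(model.dim_lengths["city"], model.coords["city"])
```

In this sketch a dimension created without `mutable=True` raises on any `set_dim` call, which mirrors the intent of restricting resizes to dimensions that were explicitly declared resizable.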

@codecov

codecov bot commented May 28, 2022

Codecov Report

Merging #5817 (576e306) into main (57654dc) will increase coverage by 0.58%.
The diff coverage is 96.42%.

@@            Coverage Diff             @@
##             main    #5817      +/-   ##
==========================================
+ Coverage   89.40%   89.98%   +0.58%     
==========================================
  Files          74       73       -1     
  Lines       13769    13221     -548     
==========================================
- Hits        12310    11897     -413     
+ Misses       1459     1324     -135     
Impacted Files                     Coverage Δ
pymc/model.py                      87.26% <96.42%> (+0.73%) ⬆️
pymc/sampling.py                   82.52% <0.00%> (-6.19%) ⬇️
pymc/variational/inference.py      85.78% <0.00%> (-1.46%) ⬇️
pymc/distributions/mixture.py      95.72% <0.00%> (-1.04%) ⬇️
pymc/variational/__init__.py       100.00% <0.00%> (ø)
pymc/distributions/logprob.py      97.65% <0.00%> (ø)
pymc/variational/flows.py
pymc/aesaraf.py                    91.95% <0.00%> (+0.06%) ⬆️
pymc/smc/smc.py                    96.45% <0.00%> (+0.07%) ⬆️
... and 4 more
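
As a quick sanity check on the figures in the coverage diff, the percentages can be recomputed from the hit and line counts. The sketch assumes coverage is hits divided by lines, truncated to two decimals; that rounding rule is inferred from the numbers and is not stated in the report.

```python
def coverage_hundredths(hits, lines):
    """Coverage in hundredths of a percent, truncated (assumed rounding rule)."""
    return hits * 10000 // lines


main = coverage_hundredths(12310, 13769)  # base branch: hits, lines
pr = coverage_hundredths(11897, 13221)    # this pull request: hits, lines

print(f"main: {main / 100:.2f}%")
print(f"PR:   {pr / 100:.2f}%")
print(f"diff: {(pr - main) / 100:+.2f}%")
```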

@michaelosthege michaelosthege marked this pull request as ready for review May 28, 2022 19:54
@michaelosthege michaelosthege requested a review from lucianopaz May 28, 2022 19:54
Contributor

@lucianopaz lucianopaz left a comment

Thanks, @michaelosthege. I think that we need to use coords values that are supplied to set_data to call set_dim internally, instead of expecting the users to set_dim before they set_data.

pymc/model.py Outdated
    elif isinstance(length_tensor, ScalarSharedVariable):
        # The dimension is mutable, but was defined without being linked
        # to a shared variable. This is allowed, but slightly dangerous.
        warnings.warn(
Contributor

I think that the key thing that is missing here is that users can also pass coords when they set_data. Ideally, we should use the supplied coords internally to call set_dim and then update the values of the MutableData instance. This warning should only be raised if the users haven't passed coords along to set_data.
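
One way to read this suggestion, sketched without PyMC: `set_data` resizes the dimension itself when `coords` are supplied and only warns when they are not. The `coords` keyword on `set_data` is the proposal under discussion; `ToyModel`, `add_data`, the local `ShapeWarning` class, and the warning text are hypothetical.

```python
import warnings


class ShapeWarning(UserWarning):
    """Local stand-in for a shape-related warning category."""


class ToyModel:
    def __init__(self):
        self.coords = {}     # dim name -> tuple of coordinate values
        self.data = {}       # variable name -> list of values
        self.data_dims = {}  # variable name -> dim name

    def add_data(self, name, values, dim, coord_values):
        self.data[name] = list(values)
        self.data_dims[name] = dim
        self.coords[dim] = tuple(coord_values)

    def set_dim(self, dim, new_length, coord_values):
        coord_values = tuple(coord_values)
        if len(coord_values) != new_length:
            raise ValueError("coord_values must have new_length entries.")
        self.coords[dim] = coord_values

    def set_data(self, name, values, coords=None):
        values = list(values)
        dim = self.data_dims[name]
        if coords is not None and dim in coords:
            # Coordinates were supplied: resize the dimension on the caller's behalf.
            self.set_dim(dim, len(values), coords[dim])
        elif len(values) != len(self.coords[dim]):
            # No coordinates supplied, so the stored ones no longer match the data.
            warnings.warn(
                f"Length of {name!r} changed; update the dimension length of {dim!r}.",
                ShapeWarning,
                stacklevel=2,
            )
        self.data[name] = values


model = ToyModel()
model.add_data("mdata", [1, 2], dim="mdim", coord_values=range(2))
# Passing coords keeps data and dimension in sync, so no warning is expected here.
model.set_data("mdata", [1, 2, 3, 4], coords={"mdim": range(4)})
print(model.coords["mdim"])
```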

Member Author

Hmm.. You're welcome to try, but I think this could become more complicated than the current way with the warning.
The set_data method doesn't actually care if the ScalarSharedVariable corresponds to a dimension in Model.dim_lengths or not!

Most important will be to add test cases that showcase the intended call order

Contributor

What would be the expected flow to change data and make predictions? Before we had to do:

with pm.Model(coords={"A": range(10)}) as m:
    x = pm.MutableData("x", x_values, dims="A")
    y = pm.MutableData("y", y_values, dims="A")
    a = pm.Normal("a", 0, 1)
    b = pm.Normal("b", 0, 1)
    c = pm.HalfNormal("c", 1)
    obs = pm.Normal("obs", mu=a + b * x, sigma=c, observed=y)
    idata = pm.sample()

with m:
    pm.set_data({"x": np.linspace(-2, 3, 100), "y": np.full(100, np.nan)})
    ppc = pm.sample_posterior_predictive(idata)

I would like us to be able to do this:

with pm.Model(coords={"A": range(10)}) as m:
    x = pm.Data("x", x_values, dims="A")
    y = pm.Data("y", y_values, dims="A")
    a = pm.Normal("a", 0, 1)
    b = pm.Normal("b", 0, 1)
    c = pm.HalfNormal("c", 1)
    obs = pm.Normal("obs", mu=a + b * x, sigma=c, observed=y)
    idata = pm.sample()

with m:
    pm.set_data({"x": np.linspace(-2, 3, 100), "y": np.full(100, np.nan)}, coords={"A": range(10, 110)})
    ppc = pm.sample_posterior_predictive(idata)

But due to the choice of default Mutable/Constant coords, we need to:

  1. Manually add_coord for "A" and say it is mutable
  2. Call set_dim before calling set_data

Or we have to create the coordinate values using MutableData, which I don't really know how to do.

# A warning should be emitted.
with pytest.warns(ShapeWarning, match="update the dimension length"):
    pmodel.set_data("mdata", [1, 2, 3, 4])
pass
Contributor

We should add the test where pmodel.set_data("mdata", [1, 2, 3, 4], coords={"mdim": range(4)}).
Ideally, we want to catch the coords fed into set_data and call set_dim with them
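
A sketch of how both paths could be asserted using only the standard library's `warnings` module. `resize_stub` is a hypothetical placeholder for the call under test and does not reproduce `Model.set_data`.

```python
import warnings


def resize_stub(values, coords=None):
    """Hypothetical stand-in: warns when data is resized without coords."""
    if coords is None:
        warnings.warn("update the dimension length", UserWarning, stacklevel=2)
    return len(values)


def test_warns_without_coords():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        resize_stub([1, 2, 3, 4])
    assert any("update the dimension length" in str(w.message) for w in caught)


def test_silent_with_coords():
    with warnings.catch_warnings():
        # Turn any warning into an exception so an unexpected one fails the test.
        warnings.simplefilter("error")
        resize_stub([1, 2, 3, 4], coords={"mdim": range(4)})


test_warns_without_coords()
test_silent_with_coords()
```

Escalating warnings to errors in the second test makes the "no warning" expectation explicit instead of relying on the absence of output.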

@AlexAndorra
Contributor

Thanks for the proposed fix @michaelosthege ! I agree with @lucianopaz that it'd be great if the user didn't have to set_dim before they set_data

@michaelosthege
Member Author

> Thanks for the proposed fix @michaelosthege ! I agree with @lucianopaz that it'd be great if the user didn't have to set_dim before they set_data

Since #5763 (7ec106c to be exact) one can pass coords to pm.Data containers.

So with this, you'd create the "A" dimension as mutable.

with pm.Model() as m:
    x = pm.MutableData("x", x_values, dims="A", coords={"A": range(10)})
    y = pm.MutableData("y", y_values, dims="A")

Then, if you add the coords to set_data, we could also call set_dim internally.

@michaelosthege
Member Author

I'll add the test suggested above and look into adding the coords kwarg to set_data.

@michaelosthege michaelosthege force-pushed the issue-5812 branch 2 times, most recently from 331aab2 to 6ecc6d7 Compare May 30, 2022 19:23
@michaelosthege
Member Author

@lucianopaz please review the new test cases - they should govern which usage patterns are now supported.

The only thing I'm hesitant about is introducing mutable dims through the pm.Model constructor.
For that we could introduce a new kwarg, but I didn't want to make that decision already.
Options:

  • pm.Model(coords_mutable=..., coords=...)
  • pm.Model(coords_mutable=..., coords_immutable=...)
  • more?
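
To make the first option in the list above concrete, here is a library-free sketch of a constructor that accepts both keyword arguments and records which dimensions are resizable. `ToyModel` is hypothetical, `coords_mutable` is only the name proposed above, and rejecting a dimension that appears in both mappings is an assumption of this sketch.

```python
class ToyModel:
    def __init__(self, coords=None, coords_mutable=None):
        coords = dict(coords or {})
        coords_mutable = dict(coords_mutable or {})
        # Assumed rule: a dimension cannot be both fixed and mutable.
        overlap = sorted(set(coords) & set(coords_mutable))
        if overlap:
            raise ValueError(f"Dimensions defined as both fixed and mutable: {overlap}")
        merged = {**coords, **coords_mutable}
        self.coords = {name: tuple(values) for name, values in merged.items()}
        self.mutable_dims = frozenset(coords_mutable)

    def is_mutable(self, name):
        return name in self.mutable_dims


m = ToyModel(coords={"feature": ["a", "b"]}, coords_mutable={"obs": range(10)})
print(m.is_mutable("obs"), m.is_mutable("feature"))
```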

Please continue on this branch as you see fit. Let's get it over the finish line today

@ricardoV94
Member

This is a vague, hand-waving comment, so apologies for that. I feel the internal code to manage dims is becoming a bit too convoluted for its own good. Suggestion:

  1. Have a coords and mutable_coords kwargs in pm.Model.
  2. Remove the ability to define new dims / coords through variables
  3. Downsides?

@michaelosthege
Member Author

The idea of defining coords through pm.Data came from the export_index_as_coords kwarg.

Thinking this further, with #5796 we could do this:

population = xarray.DataArray(
    [10, 200],  # the data must be the first positional argument
    name="population",
    dims="city",
    coords=dict(city=["Atown", "Bvil"]),
)

with pm.Model() as m:
    pop = pm.MutableData("population", population)
    assert "city" in m.coords # True

Note that this would work for multi-dimensional xarray.DataArray objects but not for pd.DataFrame with multi-index.
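
The idea can be sketched without xarray or PyMC: a labeled container carries its own dimension names and coordinate values, and registering it copies them into the model. `LabeledArray`, `ToyModel`, and `add_labeled_data` are hypothetical names used only for illustration, and this does not show what #5796 implements.

```python
from dataclasses import dataclass, field


@dataclass
class LabeledArray:
    """Hypothetical stand-in for a labeled array such as xarray.DataArray."""
    name: str
    values: list
    dims: tuple
    coords: dict


@dataclass
class ToyModel:
    coords: dict = field(default_factory=dict)
    data: dict = field(default_factory=dict)

    def add_labeled_data(self, array):
        for dim in array.dims:
            new_values = tuple(array.coords[dim])
            # Refuse to silently overwrite coordinates that disagree.
            if dim in self.coords and self.coords[dim] != new_values:
                raise ValueError(f"Conflicting coordinates for dimension {dim!r}.")
            self.coords[dim] = new_values
        self.data[array.name] = array.values


population = LabeledArray(
    name="population",
    values=[10, 200],
    dims=("city",),
    coords={"city": ["Atown", "Bvil"]},
)
m = ToyModel()
m.add_labeled_data(population)
print("city" in m.coords)
```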

For working with GPs this would be really useful, because one has to carry data and grid coordinates around quite a lot.


I agree that the code is a little convoluted - it would probably benefit from a little extraction refactoring.

But the downsides of removing coords from pm.Data would be that we'd have to break the export_index_as_coords feature and we couldn't do what the above example does.

If we add a kwarg to the Model signature, I would prefer coords_mutable because of how things are sorted by code completion tools.

@michaelosthege
Member Author

Any decision regarding the default (im)mutability of pm.Model(coords) will build on top of the commits made in this branch.

Since the commits here don't affect the thing we're still in discussion about, can we go ahead and merge this in the interest of closing #5812 and increasing overall safety?

@twiecki twiecki requested review from ricardoV94 and lucianopaz June 1, 2022 13:16
Member

@ricardoV94 ricardoV94 left a comment

Suggestion

Member

@ricardoV94 ricardoV94 left a comment

Some suggestions and questions :)

Member Author

@michaelosthege michaelosthege left a comment

@ricardoV94 @twiecki I won't have time to do any commits here until Monday.
As in "I'm out, please take over the rest here and get this merged".

@twiecki
Member

twiecki commented Jun 3, 2022

What's missing here, can we merge?

Member

@ricardoV94 ricardoV94 left a comment

Waiting for tests to pass to make sure I didn't break anything accidentally

@ricardoV94 ricardoV94 merged commit 5da32ed into main Jun 3, 2022
@twiecki twiecki deleted the issue-5812 branch June 3, 2022 14:31
Successfully merging this pull request may close these issues.

set_data cannot deal with nodes without owners
5 participants