Refactoring _print_name for certain RVs and specifying rv_types in their distributions #6219

Merged: 1 commit, Oct 24, 2022

Member

@larryshamalama larryshamalama commented Oct 14, 2022

What is this PR about?
This PR creates "new" PyMC RandomVariables for those whose _print_name is abbreviated, so that we don't have to rely on the RV's class name for LaTeX and Graphviz displays.

Checklist

Major / Breaking Changes

  • ...

Bugfixes / New features

  • ...

Docs / Maintenance

Subclass several random variables to have a more readable _print_name.

Closes #6201
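
A minimal sketch of the subclassing idea described above. The classes here are plain-Python stand-ins, not Aesara's real Ops, and the `_print_name` values are illustrative assumptions; the point is only that a subclass can override the display name without touching the parent.

```python
# Stand-in for an upstream random-variable Op class (hypothetical; the real
# class lives in Aesara and carries more attributes than shown here).
class NormalRV:
    name = "normal"
    # (plain-text name, LaTeX name); an abbreviated pair, assumed for illustration
    _print_name = ("N", r"\operatorname{N}")


# Downstream subclass that changes nothing except the display name.
class PyMCNormalRV(NormalRV):
    _print_name = ("Normal", r"\operatorname{Normal}")


pymc_normal = PyMCNormalRV()

print(NormalRV._print_name[0])     # abbreviated upstream name
print(pymc_normal._print_name[0])  # readable downstream name
```

Because `PyMCNormalRV` is still a `NormalRV`, anything that checks `isinstance(op, NormalRV)` keeps working.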

@larryshamalama larryshamalama force-pushed the new-rvs branch 5 times, most recently from dca5332 to e5fe176 Compare October 18, 2022 01:10
@codecov

codecov bot commented Oct 18, 2022

Codecov Report

Merging #6219 (1205bd8) into main (7bad057) will decrease coverage by 0.00%.
The diff coverage is 100.00%.

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #6219      +/-   ##
==========================================
- Coverage   93.78%   93.78%   -0.01%     
==========================================
  Files         101      101              
  Lines       22187    22235      +48     
==========================================
+ Hits        20809    20854      +45     
- Misses       1378     1381       +3     
| Impacted Files | Coverage Δ |
|---|---|
| pymc/distributions/continuous.py | 97.56% <100.00%> (+0.06%) ⬆️ |
| pymc/distributions/discrete.py | 99.25% <100.00%> (+0.03%) ⬆️ |
| pymc/distributions/distribution.py | 95.00% <100.00%> (ø) |
| pymc/distributions/multivariate.py | 92.32% <100.00%> (+0.06%) ⬆️ |
| pymc/tests/distributions/test_logprob.py | 100.00% <100.00%> (ø) |
| pymc/parallel_sampling.py | 85.52% <0.00%> (-0.99%) ⬇️ |

@twiecki
Member

twiecki commented Oct 18, 2022

Can't we just set _print_name on the imported aesara rvs?

@ricardoV94
Member

> Can't we just set _print_name on the imported aesara rvs?

That would have side-effects on the Aesara side (if it even works)
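
To make the side-effect concern concrete, here is a small sketch with a stand-in class (not Aesara's real Op; real Ops may not even allow this kind of attribute assignment, which is the "if it even works" caveat). It contrasts subclassing with patching the imported object in place.

```python
class NormalRV:
    """Stand-in for an upstream Op class (hypothetical, not Aesara's code)."""
    _print_name = ("N", r"\operatorname{N}")


# Stand-in for the module-level instance that upstream exports.
upstream_normal = NormalRV()


# Option A: subclass downstream; the upstream object is left untouched.
class PyMCNormalRV(NormalRV):
    _print_name = ("Normal", r"\operatorname{Normal}")


pymc_normal = PyMCNormalRV()
name_after_subclassing = upstream_normal._print_name[0]

# Option B: patch the imported object in place. An import only binds another
# name to the same object, so upstream's display name changes as well.
imported_normal = upstream_normal
imported_normal._print_name = ("Normal", r"\operatorname{Normal}")
name_after_patching = upstream_normal._print_name[0]
```

With option B, every other user of the upstream object sees the new name, which is the side effect mentioned above.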

@larryshamalama
Member Author

Letting tests run again. Hopefully nothing fails.

Are there any additional tests to add in light of these new Ops?

@larryshamalama larryshamalama force-pushed the new-rvs branch 2 times, most recently from c6c02c7 to cd92813 Compare October 18, 2022 16:48
@ricardoV94
Member

ricardoV94 commented Oct 20, 2022

I am not super happy with this (even though I suggested it). It's a bit worrisome that our dispatched methods will work for the subclasses but not the original classes, when all we are doing is changing a name.

We could do something more clever, where we still dispatch on the Aesara classes but return the PyMC ones from Distribution when those are specified. Then both objects will have logp/logcdf/moments/ etc...

That would have avoided the concerns about the tests where we create new classes.

We could use the rv_type property, and dispatch on that. So Normal would be

class Normal(Continuous):
  rv_op = pymc_normal
  rv_type = NormalRV  # Aesara
  ...

And then tweak the Metaclass of Distribution to use rv_type for dispatching when provided explicitly, and otherwise default to the old behavior of obtaining the rv_type from rv_op. WDYT?
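
The worry about dispatch can be illustrated with a small sketch. It assumes the dispatched methods behave like `functools.singledispatch` and uses stand-in classes and a hypothetical `logp` function, so it is not PyMC's actual registration code.

```python
from functools import singledispatch


class NormalRV:  # stand-in for the upstream class
    pass


class PyMCNormalRV(NormalRV):  # stand-in for the renamed subclass
    pass


@singledispatch
def logp(op):
    raise NotImplementedError(f"No logp registered for {type(op).__name__}")


# Registering on the subclass only covers the subclass.
@logp.register(PyMCNormalRV)
def _(op):
    return "normal logp"


subclass_result = logp(PyMCNormalRV())
try:
    logp(NormalRV())
    parent_supported = True
except NotImplementedError:
    parent_supported = False


# Registering on the parent class covers both the parent and the subclass.
@singledispatch
def logp_on_parent(op):
    raise NotImplementedError(f"No logp registered for {type(op).__name__}")


@logp_on_parent.register(NormalRV)
def _(op):
    return "normal logp"


both_results = (logp_on_parent(NormalRV()), logp_on_parent(PyMCNormalRV()))
```

This is why dispatching on the original class, while still returning the subclass instance from the distribution, gives both objects the same methods.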

@larryshamalama
Member Author

> I am not super happy with this (even though I suggested it). It's a bit worrisome that our dispatched methods will work for the subclasses but not the original classes, when all we are doing is changing a name.

I understand. I also had some doubts, but at least we have a clearer picture of how the changes in this PR would affect the codebase.

> We could do something more clever, where we still dispatch on the Aesara classes but return the PyMC ones from Distribution when those are specified. Then both objects will have logp/logcdf/moments/ etc...
>
> That would have avoided the concerns about the tests where we create new classes.
>
> We could use the rv_type property, and dispatch on that. So Normal would be
>
>     class Normal(Continuous):
>       rv_op = pymc_normal
>       rv_type = NormalRV  # Aesara
>       ...
>
> And then tweak the Metaclass of Distribution to use rv_type for dispatching when provided explicitly, and otherwise default to the old behavior of obtaining the rv_type from rv_op. WDYT?

Off the top of my head, both solutions would work, but I'm not sure how I would dispatch on the Aesara classes when rv_op is one of the newly created ones (like PyMCNormalRV).

        if isinstance(rv_op, RandomVariable):
            rv_type = type(rv_op)
            clsdict["rv_type"] = rv_type
        new_cls = super().__new__(cls, name, bases, clsdict)
        if rv_type is not None:
            # Create dispatch functions

I believe that the dispatch is being done in these lines and perhaps there is a more clever way than checking if __class__.__name__.startswith("PyMC")... Is this what you meant?

Again, without trying much, I feel like using rv_type would be an easier solution. However, would this have any spillover complications in other areas of the codebase?

@ricardoV94
Member

ricardoV94 commented Oct 21, 2022

@larryshamalama what I meant was something like:

        rv_type = clsdict.setdefault("rv_type", None)

        if rv_type is None and isinstance(rv_op, RandomVariable):
            rv_type = type(rv_op)
            clsdict["rv_type"] = rv_type

Then when we set:

class Normal(Continuous):
  rv_op = pymc_normal
  rv_type = NormalRV  # Aesara

pm.Normal will still return a PyMCNormalRV, but all the dispatching will be done on the original base class. Dispatch registered on a class applies to its subclasses but not to its parent classes (which makes sense).

@larryshamalama
Member Author

Okay, I can get to it

@larryshamalama larryshamalama changed the title from "Refactoring _print_name in random variables" to "Refactoring _print_name for certain RVs and specifying rv_types in their distributions" Oct 23, 2022
@ricardoV94
Member

@larryshamalama Can you include the default changelist from the PR template? We use those for the auto-generated release notes

@larryshamalama
Member Author

Done! Thanks for the suggestion, I will start doing this more often

@ricardoV94 ricardoV94 merged commit 652c9de into pymc-devs:main Oct 24, 2022
@ricardoV94
Member

Thanks @larryshamalama, now the _print_name for SymbolicRVs should be more useful

Labels
no releasenotes Skipped in automatic release notes generation
Development

Successfully merging this pull request may close these issues.

Create PyMC-specific RV Ops that overwrite _print_name
3 participants