Skip to content

Refactor several distributions #4640

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 11, 2021
Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Apr 13, 2021

Couple more of distributions refactored to v4. This should cover all the ones that are to be implemented on Aesara, unless I missed something.

Depending on what your PR does, here are a few things you might want to address in the description:

@ricardoV94 ricardoV94 force-pushed the refactor_more_dists branch from fe913df to d370008 Compare April 13, 2021 18:47
@ricardoV94
Copy link
Member Author

ricardoV94 commented Apr 16, 2021

@brandonwillard What do you think of my last commit as a solution for the BoundContinuous issue?

I moved the transform_logp inside pm.Distribution.__new__() instead of being inside model.create_value_var(). This allows to pass class attributes as optional arguments to transform_logp without cluttering the rv_op.

BTW I see this was before in the metaclass: https://github.com/pymc-devs/pymc3/blob/1e4163339f963d9728d0160e74d594f685d76b92/pymc3/distributions/distribution.py#L112-L120

Should we move it there or remove those commented lines?

@ricardoV94 ricardoV94 force-pushed the refactor_more_dists branch 2 times, most recently from 1433432 to 2d6ef7c Compare April 16, 2021 13:38
Copy link
Contributor

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't want to start adding variables at the class level unnecessarily; instead, each class that wants to use an interval transform should just create its own in Distribution.__new__. That approach is class-level enough to accomplish the same thing, and it makes the transform-setup logic a little less spread out.

@ricardoV94 ricardoV94 force-pushed the refactor_more_dists branch 2 times, most recently from d75d17c to d0ee690 Compare April 26, 2021 08:30
@ricardoV94 ricardoV94 added the v4 label Apr 26, 2021
@ricardoV94 ricardoV94 force-pushed the refactor_more_dists branch from d0ee690 to c5db0e5 Compare April 26, 2021 09:27
twiecki
twiecki previously approved these changes May 3, 2021
@twiecki
Copy link
Member

twiecki commented May 3, 2021


=================================== FAILURES ===================================
_______________ TestMatchesScipy.test_beta_binomial_distribution _______________

self = <pymc3.tests.test_distributions.TestMatchesScipy object at 0x12affbeb0>

    @pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
    def test_beta_binomial_distribution(self):
>       self.checkd(
            BetaBinomial,
            Nat,
            {"alpha": Rplus, "beta": Rplus, "n": NatSmall},
        )

pymc3/tests/test_distributions.py:1501: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <pymc3.tests.test_distributions.TestMatchesScipy object at 0x12affbeb0>
distfam = <class 'pymc3.distributions.discrete.BetaBinomial'>
valuedomain = <pymc3.tests.test_distributions.Domain object at 0x12511d880>
vardomains = {'alpha': <pymc3.tests.test_distributions.Domain object at 0x12560c7f0>, 'beta': <pymc3.tests.test_distributions.Domain object at 0x12560c7f0>, 'n': <pymc3.tests.test_distributions.Domain object at 0x12511dfd0>}
checks = (<bound method TestMatchesScipy.check_int_to_1 of <pymc3.tests.test_distributions.TestMatchesScipy object at 0x12affbeb0>>,)
extra_args = {}

    def checkd(self, distfam, valuedomain, vardomains, checks=None, extra_args=None):
        if checks is None:
            checks = (self.check_int_to_1,)
    
        if extra_args is None:
            extra_args = {}
        m = build_model(distfam, valuedomain, vardomains, extra_args=extra_args)
        for check in checks:
>           check(m, m.named_vars["value"], valuedomain, vardomains)
E           AttributeError: 'tuple' object has no attribute 'named_vars'

pymc3/tests/test_distributions.py:863: AttributeError
=============================== warnings summary ===============================

@ricardoV94
Copy link
Member Author

ricardoV94 commented May 3, 2021

Yeah, it's using this checkd test that has not been refactored for V4. Just marked as xfail for now. Will fix later when refactoring the ZeroInflated distributions that make use of it as well.

@ricardoV94 ricardoV94 force-pushed the refactor_more_dists branch from 6940e74 to d52ae50 Compare May 5, 2021 07:40
@ricardoV94 ricardoV94 requested a review from twiecki May 5, 2021 10:05
@ricardoV94
Copy link
Member Author

I added a skipif for the betabinomial test in case we are using scipy < 1.4.0 (which happens in the pytest workflow). The tests should still be covered in the arviz-compat workflow so it should be fine for now.

Copy link
Contributor

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside from renaming that transform field, it looks like test_distributions_random could be introducing a significant amount of testing redundancy. Regardless, we can merge this once the transform field name is updated.

class BoundedContinuous(Continuous):
"""Base class for bounded continuous distributions"""

transform_args = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this field is very specific to upper and lower bounds, its name should reflect that. Currently, it seems like this field specifies which arguments are to be transformed, which is very misleading. Also, the values in this field are argument indices, so that should be made clearer somehow (e.g. upper_lower_bound_indices, trans_ul_bound_indices, etc.).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to bound_args_indices

@@ -2646,6 +2641,18 @@ def __init__(self, nu, *args, **kwargs):
super().__init__(alpha=nu / 2.0, beta=0.5, *args, **kwargs)


# TODO: Remove this once logpt for multiplication is working!
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's create an issue specifically for this before/after merging.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done: #4683

"""
Draw random values from Triangular distribution.
rv_op = triangular
bound_args_indices = [0, 2] # lower, upper
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 2? And shouldn't these be tuples instead of lists?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to tuples.

The arguments of the triangular RandomOp are in this order (lower, mode, upper), and so we need to specify that the first and last are the ones relevant for the transformation, hence (0, 2) (zero-based index)


def logp(self, value):
def logp(value, mu, s):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we need self here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because these are used only as a sort of static-method by the _logp dispatcher here: https://github.com/pymc-devs/pymc3/blob/6c247638c506e4b3d6eff86ff24c5e98d34ae055/pymc3/distributions/distribution.py#L96-L102

We are not really using distribution instances anymore IIUC.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What a curious pattern, thanks for explaining.

twiecki
twiecki previously approved these changes May 11, 2021
@ricardoV94 ricardoV94 force-pushed the refactor_more_dists branch from 104106b to d9f0775 Compare May 11, 2021 09:48
@ricardoV94 ricardoV94 requested a review from twiecki May 11, 2021 09:56
@twiecki twiecki merged commit 2c372ef into pymc-devs:v4 May 11, 2021
@twiecki
Copy link
Member

twiecki commented May 11, 2021

This gets us a big step closer to merging v4, thanks @ricardoV94!

@ricardoV94 ricardoV94 deleted the refactor_more_dists branch May 12, 2021 11:45
twiecki pushed a commit that referenced this pull request Jun 5, 2021
* Refactor several distributions

* Fix continuous bounded default transform

* Add 32bit xfail to Weibull logp

* Add TODO reminder for Weibull

* Refactor random tests

* Remove tests covered by Aesara

* Refactor BetaBinomial

* Add skipif for betabinom depending on scipy version

* Rename transform_args to bound_args_indices
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants