Commit f33fc00

committed
Merge branch 'distsigs' of https://github.com/cluhmann/pymc into distsigs
2 parents b83e0bb + 705aff1 commit f33fc00

24 files changed: +490 -265 lines changed

.github/workflows/autoupdate-pre-commit-config.yml

+1 -1
@@ -14,7 +14,7 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v2
       - name: Cache multiple paths
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         with:
           path: |
             ~/.cache/pre-commit

.github/workflows/dispatched_pre-commit.yml

+1 -1
@@ -13,7 +13,7 @@ jobs:
           ref: ${{github.event.client_payload.pull_request.head.ref}}
           token: ${{ secrets.ACTION_TRIGGER_TOKEN }}
       - name: Cache multiple paths
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         with:
           path: |
             ~/.cache/pre-commit

.github/workflows/tests.yml

+10 -10
@@ -86,7 +86,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Cache conda
-        uses: actions/cache@v1
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if environment-test-py37.yml has not changed
           CACHE_NUMBER: 0
@@ -95,7 +95,7 @@ jobs:
           key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
             hashFiles('conda-envs/environment-test-py37.yml') }}
       - name: Cache multiple paths
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if requirements.txt has not changed
           CACHE_NUMBER: 0
@@ -154,7 +154,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Cache conda
-        uses: actions/cache@v1
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if conda-envs/environment-test-py38.yml has not changed
           CACHE_NUMBER: 0
@@ -163,7 +163,7 @@ jobs:
           key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
             hashFiles('conda-envs/windows-environment-test-py38.yml') }}
       - name: Cache multiple paths
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if requirements.txt has not changed
           CACHE_NUMBER: 0
@@ -230,7 +230,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Cache conda
-        uses: actions/cache@v1
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if environment-test-py39.yml has not changed
           CACHE_NUMBER: 0
@@ -239,7 +239,7 @@ jobs:
           key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
             hashFiles('conda-envs/environment-test-py39.yml') }}
       - name: Cache multiple paths
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if requirements.txt has not changed
           CACHE_NUMBER: 0
@@ -292,7 +292,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Cache conda
-        uses: actions/cache@v1
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if environment-test-py39.yml has not changed
           CACHE_NUMBER: 0
@@ -301,7 +301,7 @@ jobs:
           key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
             hashFiles('conda-envs/environment-test-py39.yml') }}
       - name: Cache multiple paths
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if requirements.txt has not changed
           CACHE_NUMBER: 0
@@ -359,7 +359,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Cache conda
-        uses: actions/cache@v1
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if conda-envs/environment-test-py38.yml has not changed
           CACHE_NUMBER: 0
@@ -368,7 +368,7 @@ jobs:
           key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
             hashFiles('conda-envs/windows-environment-test-py38.yml') }}
       - name: Cache multiple paths
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         env:
           # Increase this value to reset cache if requirements.txt has not changed
           CACHE_NUMBER: 0

docs/source/contributing/developer_guide_implementing_distribution.md

+15 -16
@@ -4,7 +4,7 @@ This guide provides an overview on how to implement a distribution for version 4
 It is designed for developers who wish to add a new distribution to the library.
 Users will not be aware of all this complexity and should instead make use of helper methods such as `~pymc.distributions.DensityDist`.

-PyMC {class}`~pymc.distributions.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements `logp`, `logcdf` and `get_moment` methods as well as other initialization and validation helpers.
+PyMC {class}`~pymc.distributions.Distribution` builds on top of Aesara's {class}`~aesara.tensor.random.op.RandomVariable`, and implements `logp`, `logcdf` and `moment` methods as well as other initialization and validation helpers.
 Most notably `shape/dims` kwargs, alternative parametrizations, and default `transforms`.

 Here is a summary check-list of the steps needed to implement a new distribution.
@@ -13,7 +13,7 @@ Each section will be expanded below:
 1. Creating a new `RandomVariable` `Op`
 1. Implementing the corresponding `Distribution` class
 1. Adding tests for the new `RandomVariable`
-1. Adding tests for `logp` / `logcdf` and `get_moment` methods
+1. Adding tests for `logp` / `logcdf` and `moment` methods
 1. Documenting the new `Distribution`.

 This guide does not attempt to explain the rationale behind the `Distributions` current implementation, and details are provided only insofar as they help to implement new "standard" distributions.
@@ -119,7 +119,7 @@ After implementing the new `RandomVariable` `Op`, it's time to make use of it in
 PyMC 4.x works in a very {term}`functional <Functional Programming>` way, and the `distribution` classes are there mostly to facilitate porting the `PyMC3` v3.x code to the new `PyMC` v4.x version, add PyMC API features and keep related methods organized together.
 In practice, they take care of:

-1. Linking ({term}`Dispatching`) a rv_op class with the corresponding `get_moment`, `logp` and `logcdf` methods.
+1. Linking ({term}`Dispatching`) a rv_op class with the corresponding `moment`, `logp` and `logcdf` methods.
 1. Defining a standard transformation (for continuous distributions) that converts a bounded variable domain (e.g., positive line) to an unbounded domain (i.e., the real line), which many samplers prefer.
 1. Validating the parametrization of a distribution and converting non-symbolic inputs (i.e., numeric literals or numpy arrays) to symbolic variables.
 1. Converting multiple alternative parametrizations to the standard parametrization that the `RandomVariable` is defined in terms of.
@@ -154,9 +154,9 @@ class Blah(PositiveContinuous):
         # the rv_op needs in order to be instantiated
         return super().dist([param1, param2], **kwargs)

-    # get_moment returns a symbolic expression for the stable moment from which to start sampling
+    # moment returns a symbolic expression for the stable moment from which to start sampling
     # the variable, given the implicit `rv`, `size` and `param1` ... `paramN`
-    def get_moment(rv, size, param1, param2):
+    def moment(rv, size, param1, param2):
         moment, _ = at.broadcast_arrays(param1, param2)
         if not rv_size_is_none(size):
             moment = at.full(size, moment)
@@ -193,30 +193,29 @@ class Blah(PositiveContinuous):

 Some notes:

-1. A distribution should at the very least inherit from {class}`~pymc.distributions.Discrete` or {class}`~pymc.distributions.Continuous`. For the latter, more specific subclasses exist: `PositiveContinuous`, `UnitContinuous`, `BoundedContinuous`, `CircularContinuous`, which specify default transformations for the variables. If you need to specify a one-time custom transform you can also override the `__new__` method, as is done for the {class}`~pymc.distributions.multivariate.Dirichlet`.
-1. If a distribution does not have a corresponding `random` implementation, a `RandomVariable` should still be created that raises a `NotImplementedError`. This is the case for the {class}`~pymc.distributions.continuous.Flat`. In this case it will be necessary to provide a standard `initval` by
-overriding `__new__`.
+1. A distribution should at the very least inherit from {class}`~pymc.distributions.Discrete` or {class}`~pymc.distributions.Continuous`. For the latter, more specific subclasses exist: `PositiveContinuous`, `UnitContinuous`, `BoundedContinuous`, `CircularContinuous`, `SimplexContinuous`, which specify default transformations for the variables. If you need to specify a one-time custom transform you can also create a `_default_transform` dispatch function as is done for the {class}`~pymc.distributions.multivariate.LKJCholeskyCov`.
+1. If a distribution does not have a corresponding `random` implementation, a `RandomVariable` should still be created that raises a `NotImplementedError`. This is the case for the {class}`~pymc.distributions.continuous.Flat`. In this case it will be necessary to provide a `moment` method.
 1. As mentioned above, `PyMC` v4.x works in a very {term}`functional <Functional Programming>` way, and all the information that is needed in the `logp` and `logcdf` methods is expected to be "carried" via the `RandomVariable` inputs. You may pass numerical arguments that are not strictly needed for the `rng_fn` method but are used in the `logp` and `logcdf` methods. Just keep in mind whether this affects the correct shape inference behavior of the `RandomVariable`. If specialized non-numeric information is needed you might need to define your custom`_logp` and `_logcdf` {term}`Dispatching` functions, but this should be done as a last resort.
 1. The `logcdf` method is not a requirement, but it's a nice plus!
-1. Currently only one moment is supported in the `get_moment` method, and probably the "higher-order" one is the most useful (that is `mean` > `median` > `mode`)... You might need to truncate the moment if you are dealing with a discrete distribution.
-1. When creating the `get_moment` method, we have to be careful with `size != None` and broadcast properly when some parameters that are not used in the moment may nevertheless inform about the shape of the distribution. E.g. `pm.Normal.dist(mu=0, sigma=np.arange(1, 6))` returns a moment of `[mu, mu, mu, mu, mu]`.
+1. Currently only one moment is supported in the `moment` method, and probably the "higher-order" one is the most useful (that is `mean` > `median` > `mode`)... You might need to truncate the moment if you are dealing with a discrete distribution.
+1. When creating the `moment` method, we have to be careful with `size != None` and broadcast properly when some parameters that are not used in the moment may nevertheless inform about the shape of the distribution. E.g. `pm.Normal.dist(mu=0, sigma=np.arange(1, 6))` returns a moment of `[mu, mu, mu, mu, mu]`.

 For a quick check that things are working you can try the following:

 ```python

 import pymc as pm
-from pymc.distributions.distribution import get_moment
+from pymc.distributions.distribution import moment

 # pm.blah = pm.Normal in this example
-blah = pm.blah.dist(mu = 0, sigma = 1)
+blah = pm.blah.dist(mu=0, sigma=1)

 # Test that the returned blah_op is still working fine
 blah.eval()
 # array(-1.01397228)

-# Test the get_moment method
-get_moment(blah).eval()
+# Test the moment method
+moment(blah).eval()
 # array(0.)

 # Test the logp method
@@ -367,9 +366,9 @@ def test_blah_logcdf(self):

 ```

-## 5. Adding tests for the `get_moment` method
+## 5. Adding tests for the `moment` method

-Tests for the `get_moment` method are contained in `pymc/tests/test_distributions_moments.py`, and make use of the function `assert_moment_is_expected`
+Tests for the `moment` method are contained in `pymc/tests/test_distributions_moments.py`, and make use of the function `assert_moment_is_expected`
 which checks if:
 1. Moments return the `expected` values
 1. Moments have the expected size and shape
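
Reviewer note (illustrative, not part of the commit): the broadcasting caveat in the guide's notes on the `moment` method can be tried outside of Aesara. The sketch below is a plain NumPy analogue, not PyMC code. `normal_moment_sketch` is a hypothetical helper, and `np.broadcast_arrays` / `np.full` merely stand in for the `at.broadcast_arrays` / `at.full` calls shown in the `Blah.moment` example.

```python
import numpy as np


def normal_moment_sketch(mu, sigma, size=None):
    """Hypothetical NumPy analogue of a `moment` method for a Normal-like distribution."""
    # sigma does not enter the moment (the mean), but it still informs the shape.
    moment, _ = np.broadcast_arrays(mu, sigma)
    if size is not None:
        # When a size is given, the moment is broadcast to fill that shape.
        moment = np.full(size, moment)
    # Copy so the caller does not receive a read-only broadcast view.
    return np.array(moment)


print(normal_moment_sketch(0, np.arange(1, 6)))
# [0 0 0 0 0]
print(normal_moment_sketch(0, np.arange(1, 6), size=(2, 5)).shape)
# (2, 5)
```

Dropping the `np.broadcast_arrays` step would return a scalar moment for a length-5 variable, which is the kind of shape mismatch the guide warns about.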

pymc/distributions/bound.py

+9 -3
@@ -20,11 +20,12 @@
 from aesara.tensor.var import TensorVariable

 from pymc.aesaraf import floatX, intX
-from pymc.distributions.continuous import BoundedContinuous
+from pymc.distributions.continuous import BoundedContinuous, bounded_cont_transform
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import Continuous, Discrete
 from pymc.distributions.logprob import logp
 from pymc.distributions.shape_utils import to_tuple
+from pymc.distributions.transforms import _default_transform
 from pymc.model import modelcontext
 from pymc.util import check_dist_not_registered

@@ -82,6 +83,11 @@ def logp(value, distribution, lower, upper):
     )


+@_default_transform.register(BoundRV)
+def bound_default_transform(op, rv):
+    return bounded_cont_transform(op, rv, _ContinuousBounded.bound_args_indices)
+
+
 class DiscreteBoundRV(BoundRV):
     name = "discrete_bound"
     dtype = "int64"
@@ -94,8 +100,8 @@ class _DiscreteBounded(Discrete):
     rv_op = discrete_boundrv

     def __new__(cls, *args, **kwargs):
-        transform = kwargs.get("transform", None)
-        if transform is not None:
+        kwargs.setdefault("transform", None)
+        if kwargs.get("transform") is not None:
             raise ValueError("Cannot transform discrete variable.")
         return super().__new__(cls, *args, **kwargs)
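
Reviewer note (illustrative, not part of the commit): the change in `_DiscreteBounded.__new__` from `kwargs.get` to `kwargs.setdefault` alters what is forwarded to the parent class. The sketch below uses two hypothetical stand-in functions, `forward_with_get` and `forward_with_setdefault`, which return the keyword arguments that would be passed on. The assumption that the parent needs to see an explicit `transform=None` is a guess at the motivation, not something stated in the diff.

```python
def forward_with_get(**kwargs):
    """Old pattern: only inspects `transform`; kwargs are forwarded unchanged."""
    transform = kwargs.get("transform", None)
    if transform is not None:
        raise ValueError("Cannot transform discrete variable.")
    return kwargs


def forward_with_setdefault(**kwargs):
    """New pattern: inserts `transform=None` when the caller did not pass it."""
    kwargs.setdefault("transform", None)
    if kwargs.get("transform") is not None:
        raise ValueError("Cannot transform discrete variable.")
    return kwargs


print(forward_with_get())
# {}
print(forward_with_setdefault())
# {'transform': None}
```

Both variants reject a non-`None` transform; they differ only in whether the `transform` key is present in what the parent receives.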

pymc/distributions/censored.py

+3 -3
@@ -18,7 +18,7 @@
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable

-from pymc.distributions.distribution import SymbolicDistribution, _get_moment
+from pymc.distributions.distribution import SymbolicDistribution, _moment
 from pymc.util import check_dist_not_registered


@@ -124,8 +124,8 @@ def graph_rvs(cls, rv):
         return (rv.tag.dist,)


-@_get_moment.register(Clip)
-def get_moment_censored(op, rv, dist, lower, upper):
+@_moment.register(Clip)
+def moment_censored(op, rv, dist, lower, upper):
     moment = at.switch(
         at.eq(lower, -np.inf),
         at.switch(
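
Reviewer note (illustrative, not part of the commit): the `@_moment.register(Clip)` and `@_default_transform.register(BoundRV)` decorators follow a type-based dispatch pattern. The sketch below shows that pattern with the standard library's `functools.singledispatch`, assuming the PyMC dispatchers behave similarly. All names here (`BaseOp`, `ClipLikeOp`, `moment_sketch`) are hypothetical, and the midpoint formula is a placeholder rather than the censored moment computed in the commit.

```python
from functools import singledispatch


class BaseOp:
    """Hypothetical stand-in for a generic Op class."""


class ClipLikeOp(BaseOp):
    """Hypothetical stand-in for an Op such as `Clip`."""


@singledispatch
def moment_sketch(op, *inputs):
    # Fallback used when no implementation is registered for type(op).
    raise NotImplementedError(f"No moment registered for {type(op).__name__}")


@moment_sketch.register(ClipLikeOp)
def _moment_clip_like(op, lower, upper):
    # Placeholder rule: the midpoint of the bounds.
    return (lower + upper) / 2


print(moment_sketch(ClipLikeOp(), 0.0, 2.0))
# 1.0
```

Because dispatch is keyed on the type of the first argument, renaming the registered function (as this commit does) leaves lookups unaffected; only the dispatcher's own name matters to callers.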
