-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Default moment
for CustomDist
provided with a dist
function
#6873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Default moment
for CustomDist
provided with a dist
function
#6873
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6873 +/- ##
===========================================
- Coverage 92.19% 57.13% -35.06%
===========================================
Files 101 101
Lines 16921 16964 +43
===========================================
- Hits 15600 9693 -5907
- Misses 1321 7271 +5950
|
@ricardoV94 I added implementation which use graph rewriting to replace distributions by corresponding moments. Could you please take a look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat approach!
Just need to think about Ops with InnerGraphs
@ricardoV94 Do you know som simple example of symbolic def test_custom_dist_custom_moment_inner_graph(self):
def dist(mu, size):
ys, _ = pytensor.scan(
fn=lambda x: pt.exp(pm.Normal.dist(x, 1, size=size)),
sequences=[mu],
outputs_info=[None],
name="ys",
)
return pt.sum(ys)
with Model() as model:
CustomDist("x", pt.ones(2), dist=dist)
assert_moment_is_expected(model, 2) I get |
You need to return random updates from the scan. Check the utility https://github.com/pymc-devs/pymc/blob/main/pymc/pytensorf.py#L1000 |
@ricardoV94 yeah, it work, thank you! |
@ricardoV94 I added replacements inside inner graph, could you please take a look? |
@ricardoV94 could you please to check my comment above when you will have some time? |
@ricardoV94 just friendly reminder about example for |
@aerubanov I couldn't come up with a working example of an OpFromGraph with a RandomVariable inside (should be fine outside), except other SymbolicRandomVariables. So I think we ignore for now . However, this is a good reminder. Do you have a test showing that the moment works for a nested Something like: import pymc as pm
def dist_fn(size):
return pm.Truncated.dist(pm.Normal.dist(), -1, 1, size=size) + 5
x = pm.CustomDist.dist(dist=dist_fn) |
@ricardoV94 Yeah, I need to add test case for |
29a8065
to
5e2c148
Compare
pymc/distributions/distribution.py
Outdated
new_node = rewrite_moment_scan_node(node) | ||
for out1, out2 in zip(node.outputs, new_node.outputs): | ||
fgraph.replace(out1, out2) | ||
elif isinstance(node.op, (RandomVariable, SymbolicRandomVariable)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this branch should come first
pymc/distributions/distribution.py
Outdated
@@ -687,6 +747,7 @@ def custom_dist_logcdf(op, value, size, *params, **kwargs): | |||
|
|||
@_moment.register(rv_type) | |||
def custom_dist_get_moment(op, rv, size, *params): | |||
params = [i for i in params if not isinstance(i, RandomGeneratorSharedVariable)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
params
consists RandomGeneratorSharedVariable
s, which are not match with dist
function signature. So I filter out it here, but may be there is a better way to do it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I think below we do something like params[:len(dist_params)]
? Does that work? Would be nice to add a comment to say we are excluding the shared RNGs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think params[:len(dist_params)]
should work here, but do we have access to dist_params
from this function? Looks like no.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case let's keep like you did, but perhaps use a helper function with a readable name? Could that helper be used below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added helper function to filter out shared RNGs
5e2c148
to
29b9c61
Compare
29b9c61
to
d54a858
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CustomDist
is getting more complex, we might want to move it into its own file later down the road.
pymc/distributions/distribution.py
Outdated
fgraph = get_rv_fgraph(dist, dist_params, size) | ||
replace_moments = MomentRewrite() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am still worried that this will rewrite dist_params
as well, and not just the CustomDist graph between dist
and dist_params
.
Would the following work? Clone the inner fgraph from rv.owner.op.fgraph
. This graph doesn't have dist_params
directly but dummy placeholders called NominalVariable
s. Apply the rewrite on this inner graph, and once you are done, replace the NominalVariable
s by the respective dist_params
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually what I am suggesting is pretty similar to what you did with Scan
, so you could perhaps reuse some of the same logic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ricardoV94 When I trying to create FunctionGraph
with NominalVariable
s I get error about missing input values - I think FunctionGraph
constructor do not support dummy placeholders as input. But I`m try another approach:
def dist_moment(rv, size, *dist_params, dist):
rv = dist(*dist_params, size=size)
inputs, outputs = dist_params, [rv.owner.out]
fgraph_topo = io_toposort(inputs, outputs)
replace_with_moment = []
to_replace_set = set()
for nd in fgraph_topo:
if nd not in to_replace_set and isinstance(
nd.op, (RandomVariable, SymbolicRandomVariable)
):
replace_with_moment.append(nd.out)
to_replace_set.add(nd)
givens = {}
for item in replace_with_moment:
givens[item] = moment(item)
[out] = clone_replace(outputs, replace=givens)
return out
Looks like it work but do not support Scan
for now (but I can add it). Do you think such approach will be better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't have to recreate the FunctionGraph
, you can just do fgraph.clone()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ricardoV94 I try this approach, but I need to create CustomDist
op first because just output of dist
function do not have fgraph
attribute.
def dist_moment(rv, size, *dist_params):
size = normalize_size_param(size)
dummy_size_param = size.type()
dummy_dist_params = [dist_param.type() for dist_param in dist_params]
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param] + dummy_dist_params
rv_type = type(
class_name,
(CustomSymbolicDistRV,),
# If logp is not provided, we try to infer it from the dist graph
dict(
inline_logprob=logp is None,
),
)
rv_op = rv_type(
inputs=dummy_params,
outputs=[dummy_rv],
ndim_supp=ndim_supp,
)
fgraph = rv_op.fgraph.clone()
replace_moments = MomentRewrite()
replace_moments.rewrite(fgraph)
for i, par in enumerate([size] + list(dist_params)):
fgraph.replace(fgraph.inputs[i], par)
[moment] = fgraph.outputs
return moment
Do you think this way better? It is also do not work with Scan
yet. I need to figure out why and I will move graph creation logic to helper function if we will decide to keep this way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For scan I got Assertion Error:
# allocate storage for intermediate computation
for node in order:
for r in node.inputs:
if r not in storage_map:
> assert isinstance(r, Constant)
E AssertionError
../../miniconda3/envs/pymc-dev/lib/python3.11/site-packages/pytensor/link/utils.py:135: AssertionError
It is hapens when I run TestCustomSymbolicDist::test_custom_dist_default_moment_inner_graph
test case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ricardoV94 If I create rv like this it`s do not have fgraph atribute:
rv = dist(*dist_params, size=size)
>>> rv
Exp.0
>>> rv.owner
Exp(normal_rv{0, (0, 0), floatX, False}.out)
>>> rv.owner.op
Elemwise(scalar_op=exp,inplace_pattern=<frozendict {}>)
>>> rv.owner.op.fgraph
*** AttributeError: 'Elemwise' object has no attribute 'fgraph'
I use lambda mu, sigma, size: pt.exp(pm.Normal.dist(mu, sigma, size=size))
for dist
function here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, use the rv
that is passed to the moment function directly (first argument), don't recreate it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohm yeah, my bad. Will try it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ricardoV94 No, I can not use rv
that is passes as parameter:
>>> rv
*** Not yet returned!
So I tried to re-create it
b1e81f5
to
9db17f6
Compare
@ricardoV94 Could you please take a look on recent changes? Just friendly reminder) |
pymc/distributions/distribution.py
Outdated
@@ -622,13 +683,28 @@ def dist( | |||
if logcdf is None: | |||
logcdf = default_not_implemented(class_name, "logcdf") | |||
|
|||
def dist_moment(rv, size, *dist_params): | |||
fgraph = rv.owner.op.fgraph.clone() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method can now be used for any OpFromGraph not just CustomDist ones, so I would move it to the MomentRewriter
Co-authored-by: Ricardo Vieira <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks really great. Just some minor cleaning up suggestions left
pymc/distributions/distribution.py
Outdated
for out1, out2 in zip(node.outputs, new_node.outputs): | ||
fgraph.replace(out1, out2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use replace_all?
pymc/distributions/distribution.py
Outdated
has_fallback=True, | ||
ndim_supp=ndim_supp, | ||
) | ||
moment = dist_moment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this if statement (see other comment about default)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ricardoV94 Hm, I think need this condition to avoid overriding moment provided by user, how can we avoid it without condition?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More specifically, If I try to remove if statement TestCustomSymbolicDist::test_custom_methods
fails
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are two levels. We dispatch the general moment to the parent base class. Whenever a new subclass is created here and the user provided a moment we register it on the subclass (so it has preference). If the user didn't provide anything, we don't register and the parent class one will be used.
pymc/distributions/distribution.py
Outdated
def filter_RNGs(params): | ||
return [p for p in params if not isinstance(p.type, (RandomType, RandomGeneratorType))] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this ended up not being needed let's just inline it in the custom moment function
Co-authored-by: Ricardo Vieira <[email protected]>
Great work @aerubanov! |
Add default implementation of
moment
forCustomDist
with adist
function and close #6804📚 Documentation preview 📚: https://pymc--6873.org.readthedocs.build/en/6873/