-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Prototype debugging helper method check_bounds
#4472
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Here is a list of examples: #%%
import numpy as np
import pymc3 as pm
from pymc3.model_debug import check_bounds
#%%
with pm.Model() as m1:
x = pm.Normal('x', -1, 2)
y = pm.HalfNormal('y', sd=x, observed=[-12, 21])
check_bounds()
# The following explicit bound(s) of y ~ HalfNormal were violated:
# -12.0 >= 0
# x ~ Normal = -1.0 > 0
#%%
with pm.Model() as m2:
x = pm.Normal('x', np.array([2, 1, -3]), -1, shape=(3,))
y = pm.HalfNormal('y', sd=x*-2, shape=(3,), transform=None)
check_bounds()
# The following explicit bound(s) of x ~ Normal were violated:
# -1.0 > 0
# The following explicit bound(s) of y ~ HalfNormal were violated:
# f(x ~ Normal = [1. 2.], -2, 1.0) > 0
#%%
with pm.Model() as m3:
a = pm.Normal('a', 0, 1, shape=3)
w = pm.Dirichlet('w', a=a, testval=np.ones(3))
check_bounds()
# The following explicit bound(s) of w_stickbreaking__ ~ TransformedDistribution were violated:
# a ~ Normal = 0.0 > 0
#%%
with pm.Model() as m4:
w = pm.Dirichlet('w', np.ones(3))
dists = [
pm.HalfNormal.dist(2),
pm.Exponential.dist(2),
pm.HalfCauchy.dist(2),
]
b = pm.Mixture('b', w, dists, observed=[-2, 0, 5])
check_bounds() # MESSY OUTPUT
# The following explicit bound(s) of b ~ Mixture were violated:
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
# -2.0 >= 0
#%%
# https://discourse.pymc.io/t/why-am-i-getting-bad-initial-energy-with-this-simple-model/6630/1
with pm.Model() as m5:
theta = pm.Uniform('theta', lower=0, upper=1, transform=None)
y = pm.Uniform('y', lower=0, upper=theta, observed=[0.4, 0.3, 0.7, 0.9])
check_bounds()
# The following explicit bound(s) of y ~ Uniform were violated:
# [0.7 0.9] <= theta ~ Uniform = 0.5
#%%
# https://discourse.pymc.io/t/fit-pareto-distribution-fails-bad-initial-energy/3494/3
np.random.seed(123)
a_true, m = 1.9, 3
test = np.round((np.random.pareto(a_true, 1000)+1)*m)
with pm.Model() as m6:
m = pm.Uniform('m', lower = 0, upper = 10, transform=None)
alpha = pm.Uniform('alpha', lower = 1, upper = 5, transform=None)
yhat = pm.Pareto('yhat', m = m, alpha = alpha, observed = test)
check_bounds()
# The following explicit bound(s) of yhat ~ Pareto were violated:
# [3. 4.] >= m ~ Uniform = 5.0
#%%
# https://discourse.pymc.io/t/why-am-i-getting-inf-or-nan-likelihood/6587/10
syn = np.array([0, 113, 42, 78, 125, 234, 393, 874, 407, 439, 1038])
with pm.Model() as m7:
alpha = pm.Exponential('alpha', lam=5)
beta = pm.Exponential('beta', lam=5)
g = pm.Gamma('g', alpha=alpha, beta=beta, observed=syn)
check_bounds() # Implicit -inf in the logp of `Gamma` introduced by `dist_math::logpow`
# No explicit bounds of g ~ Gamma were violated for the given inputs,
# An infinite logp could have arised from one of the following:
# 1. Undefined arithmetic operations (e.g., 1/0)
# 2. Numerical precision issues
# 3. Implicit bounds in the logp expression |
check_bounds
ricardoV94
commented
Feb 12, 2021
@@ -224,8 +224,9 @@ class Uniform(BoundedContinuous): | |||
""" | |||
|
|||
def __init__(self, lower=0, upper=1, *args, **kwargs): | |||
self.lower = lower = tt.as_tensor_variable(floatX(lower)) | |||
self.upper = upper = tt.as_tensor_variable(floatX(upper)) | |||
# TODO: This does not show up on logpt :( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the only thing I tried with the goal of having access to the name of the distribution parameters in the logpt graph
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This is a very early draft for a helper debugging method
check_bounds
that tries to parse out to the user any explicit variable bounds that seem to be violated at model.test_point. Here is a minimal example (more below):check_bounds
works by identifying the bounds introduced by thedist_math::bound
method and testing each of the nested logical conditions that may trigger the bound. One nested logical condition is tested at a time by disabling all the remaining ones in atheano.clone
of the switch bound variable.I would really appreciate tips in the following standing issues / limitations:
sd
ofHalfNormal
since it involvesx ~ Normal
, but when the input is not a variable, it can be less clear where it's coming from.Add custom logic to parse BinaryBitOps and more. Right now, the algorithm only attempts to test theano LogicalComparison operations and ignores any
BinaryBitOp
such astt.and
andtt.or
comparisons or aggregating operations such astt.all
andtt.any
.More robust masking of good values.
check_bounds
tries to mask values that are probably not responsible for the bound violations. The hack I implemented to do this looks very fragile to me.More clean way of disabling / enabling the logical operators? I am using
theano.clone
and disabling logical comparisons by replacing them withtt.eq(inputs[0], inputs[0])
so that it always evaluates to True and output shapes remain unaffected (or so I hope). I would love to hear about more idiomatic approaches to this.What kind of output would be most useful? Print statements, Pandas DataFrame or something Else?
Work with implicit bounds, which are sometimes used directly in the logp / logcdf expression or in helper functions such as
dist_math::logpow
. I first tried to implement a more general algorithm that looked for anytt.switch
leading to-inf
and not only those added explicitly by thebound
function. This proved much more challenging as often there are nested switches and the simplistic tests above could not handle this. Also aggregating functions such astt.all
make a general approach even more difficult. However, I see no principled reason why such information cannot be extracted from a theano graph, so if you have any ideas that would be really great!Test with more user examples. I tried to apply this to a few examples from the discourse and it seem to work fine. However, I have not really tried with more complex hierarchical / multivariate models. If you have a model that was particularly challenging to debug, it would be really great if you could share it/ test it in in my branch and see if the output is useful or correct.
Changes for V4.0.0. I think most of the logic will survive, or it may be even simplified by the transition to Aesara RandomVariables. If you see something that will definitely not work or may be easier to achieve later, that would be useful to know :)
Suggested by #4205