Support alternative forms of censoring logprob via set_subtensor #6354

Open

ricardoV94 opened this issue Nov 29, 2022 · 1 comment

Comments

@ricardoV94
Member

Description

As the TODO comment below describes, these types of graphs can be equivalent to clipping for the purpose of logprobs:

# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)
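
For concreteness, a minimal sketch of the equivalence, assuming PyMC's existing clip-based censoring logprob (the names ub, x_rv, and x_vec are illustrative):

import pymc as pm
import pytensor.tensor as pt

ub = 1.0
x_rv = pm.Normal.dist(0.0, 1.0)

# Already measurable today: one-sided censoring written as a clip
# (lower bound set to x itself, i.e. effectively no lower bound)
x_clip = pt.clip(x_rv, x_rv, ub)
print(pm.logp(x_clip, ub).eval())  # log P(x >= ub): the mass piled up at the bound

# The graph this issue proposes to canonicalize into the clip above
x_vec = pm.Normal.dist(0.0, 1.0, size=3)
x_sub = pt.set_subtensor(x_vec[x_vec > ub], ub)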

More generally, we could also allow an arbitrary censoring/binning encoding of the form x[(x > lower) & (x < upper)] = encoding, but that would require a different logprob definition.
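
As an illustrative sketch of that more general graph (lower, upper, and the encoding value 0.0 are hypothetical choices):

import pymc as pm
import pytensor.tensor as pt

x = pm.Normal.dist(0.0, 1.0, size=5)
lower, upper, encoding = -1.0, 1.0, 0.0

# Draws falling inside (lower, upper) are collapsed to a single code value;
# measuring this graph requires the different logprob definition noted above
x_encoded = pt.set_subtensor(x[(x > lower) & (x < upper)], encoding)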

@ricardoV94
Member Author

ricardoV94 commented Apr 4, 2023

Worth noting that such graphs can also be defined with Switch statements. It helps to see that there are (at least) two types of Switch graphs, which correspond to different logprob definitions:

x = pm.math.switch(cond(constant), var1, var2)  # "Mixture" graph: condition does not involve the measured variable
y = pm.math.switch(cond(var), constant1, constant2)  # "Encoding" graph: condition involves the measured variable

Here the constants need not be literal "constants": they can be other variables that are already conditioned on (i.e., they have a value).

The first is a "Mixture" graph, where the logp is that of var1 or var2 depending on the value of cond(constant); it can always be implemented as long as var1 and var2 are measurable.

In pseudocode: logp(x, value) = switch(cond(constant), logp(var1, value), logp(var2, value))
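
Written out as a runnable sketch (cond_val is a hypothetical stand-in for cond(constant), which is just data from the logp's point of view):

import pymc as pm
import pytensor.tensor as pt

cond_val = pt.scalar("cond_val", dtype="bool")
value = pt.scalar("value")
var1 = pm.Normal.dist(-1.0, 1.0)
var2 = pm.Normal.dist(1.0, 1.0)

# The mixture logp from the pseudocode above, assembled by hand
mixture_logp = pt.switch(cond_val, pm.logp(var1, value), pm.logp(var2, value))
print(mixture_logp.eval({cond_val: True, value: 0.0}))  # logp of Normal(-1, 1) at 0.0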

Right now we have a restriction that var1 and var2 must be pure RandomVariables (as we generate optimized code), but we should lift this restriction so it works in general.

The second type of graph is an "Encoding" graph, where the logp depends on being able to measure cond(var), which is possible whenever cond translates to a pmf or cdf (but not really if it's a pdf).

The logp of that graph is reversed: logp(y, value) = switch(value == constant1, logp(cond(var), True), switch(value == constant2, logp(cond(var), False), -np.inf))
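
As a concrete sketch, take cond(var) = var > 0 with var ~ Normal(0, 1), constant1 = 1.0 and constant2 = 0.0 (all hypothetical choices):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

var = pm.Normal.dist(0.0, 1.0)
value = pt.scalar("value")

logp_false = pm.logcdf(var, 0.0)     # log P(var <= 0), i.e. logp(cond(var), False)
logp_true = pt.log1mexp(logp_false)  # log P(var > 0), i.e. logp(cond(var), True)

encoding_logp = pt.switch(
    pt.eq(value, 1.0),
    logp_true,
    pt.switch(pt.eq(value, 0.0), logp_false, -np.inf),
)
print(encoding_logp.eval({value: 1.0}))  # log(0.5)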

This gets trickier when, instead of an encoding, a branch returns an expression that depends on var. In that case that branch has a logp of logp(var, value). To start easy, we can require such branches to be var exactly (and not, for example, exp(var)).


What still needs to be figured out is whether, for nested switches (or set_subtensor), we can generate logp expressions locally, or whether we need to reason about them together. I think the latter, especially for the encoding cases, as the conditions might not be mutually exclusive:

y = switch(
  x < 0.5,
  encoding1,
  switch(
    x < 0.7,
    encoding2,
    encoding3,
  ),
)

In that case the second condition is actually 0.5 <= x < 0.7, as the first branch already matches anything with x < 0.5. So the logp of encoding2 needs to take that into consideration.

This suggests we first need to implement #6633, so that we can chain cdf expressions, like logp(encoding2) == logp((x < 0.7) & ~(x < 0.5))
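
For the example above, the logp that encoding2 should end up with can be checked directly against the cdf of x (here x ~ Normal(0, 1), an assumed choice):

import pymc as pm
import pytensor.tensor as pt

x = pm.Normal.dist(0.0, 1.0)

# logp((x < 0.7) & ~(x < 0.5)) == log P(0.5 <= x < 0.7)
logp_encoding2 = pt.log(pt.exp(pm.logcdf(x, 0.7)) - pt.exp(pm.logcdf(x, 0.5)))
print(logp_encoding2.eval())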
