Skip to content

Commit 693dd34

Browse files
committed
Allow IntervalTransform to handle dynamic infinite bounds
Fixes bug in partially observed truncated variable
1 parent 6b051f9 commit 693dd34

File tree

5 files changed

+150
-46
lines changed

5 files changed

+150
-46
lines changed

pymc/distributions/continuous.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ class TruncatedNormal(BoundedContinuous):
664664
@classmethod
665665
def dist(
666666
cls,
667-
mu: Optional[DIST_PARAMETER_TYPES] = None,
667+
mu: Optional[DIST_PARAMETER_TYPES] = 0,
668668
sigma: Optional[DIST_PARAMETER_TYPES] = None,
669669
*,
670670
tau: Optional[DIST_PARAMETER_TYPES] = None,
@@ -674,7 +674,6 @@ def dist(
674674
) -> RandomVariable:
675675
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
676676
sigma = pt.as_tensor_variable(sigma)
677-
tau = pt.as_tensor_variable(tau)
678677
mu = pt.as_tensor_variable(floatX(mu))
679678

680679
lower = pt.as_tensor_variable(floatX(lower)) if lower is not None else pt.constant(-np.inf)

pymc/logprob/transforms.py

Lines changed: 80 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -816,41 +816,97 @@ def __init__(self, args_fn: Callable[..., Tuple[Optional[Variable], Optional[Var
816816
"""
817817
self.args_fn = args_fn
818818

819-
def forward(self, value, *inputs):
819+
def get_a_and_b(self, inputs):
820+
"""Return interval bound values.
821+
822+
Also returns two boolean variables indicating whether the transform is known to be statically bounded.
823+
This is used to generate smaller graphs in the transform methods.
824+
"""
820825
a, b = self.args_fn(*inputs)
826+
lower_bounded, upper_bounded = True, True
827+
if a is None:
828+
a = -pt.inf
829+
lower_bounded = False
830+
if b is None:
831+
b = pt.inf
832+
upper_bounded = False
833+
return a, b, lower_bounded, upper_bounded
821834

822-
if a is not None and b is not None:
823-
return pt.log(value - a) - pt.log(b - value)
824-
elif a is not None:
825-
return pt.log(value - a)
826-
elif b is not None:
827-
return pt.log(b - value)
835+
def forward(self, value, *inputs):
836+
a, b, lower_bounded, upper_bounded = self.get_a_and_b(inputs)
837+
838+
log_lower_distance = pt.log(value - a)
839+
log_upper_distance = pt.log(b - value)
840+
841+
if lower_bounded and upper_bounded:
842+
return pt.where(
843+
pt.and_(pt.neq(a, -pt.inf), pt.neq(b, pt.inf)),
844+
log_lower_distance - log_upper_distance,
845+
pt.where(
846+
pt.neq(a, -pt.inf),
847+
log_lower_distance,
848+
pt.where(
849+
pt.neq(b, pt.inf),
850+
log_upper_distance,
851+
value,
852+
),
853+
),
854+
)
855+
elif lower_bounded:
856+
return log_lower_distance
857+
elif upper_bounded:
858+
return log_upper_distance
828859
else:
829-
raise ValueError("Both edges of IntervalTransform cannot be None")
860+
return value
830861

831862
def backward(self, value, *inputs):
832-
a, b = self.args_fn(*inputs)
833-
834-
if a is not None and b is not None:
835-
sigmoid_x = pt.sigmoid(value)
836-
return sigmoid_x * b + (1 - sigmoid_x) * a
837-
elif a is not None:
838-
return pt.exp(value) + a
839-
elif b is not None:
840-
return b - pt.exp(value)
863+
a, b, lower_bounded, upper_bounded = self.get_a_and_b(inputs)
864+
865+
exp_value = pt.exp(value)
866+
sigmoid_x = pt.sigmoid(value)
867+
lower_distance = exp_value + a
868+
upper_distance = b - exp_value
869+
870+
if lower_bounded and upper_bounded:
871+
return pt.where(
872+
pt.and_(pt.neq(a, -pt.inf), pt.neq(b, pt.inf)),
873+
sigmoid_x * b + (1 - sigmoid_x) * a,
874+
pt.where(
875+
pt.neq(a, -pt.inf),
876+
lower_distance,
877+
pt.where(
878+
pt.neq(b, pt.inf),
879+
upper_distance,
880+
value,
881+
),
882+
),
883+
)
884+
elif lower_bounded:
885+
return lower_distance
886+
elif upper_bounded:
887+
return upper_distance
841888
else:
842-
raise ValueError("Both edges of IntervalTransform cannot be None")
889+
return value
843890

844891
def log_jac_det(self, value, *inputs):
845-
a, b = self.args_fn(*inputs)
892+
a, b, lower_bounded, upper_bounded = self.get_a_and_b(inputs)
846893

847-
if a is not None and b is not None:
894+
if lower_bounded and upper_bounded:
848895
s = pt.softplus(-value)
849-
return pt.log(b - a) - 2 * s - value
850-
elif a is None and b is None:
851-
raise ValueError("Both edges of IntervalTransform cannot be None")
852-
else:
896+
897+
return pt.where(
898+
pt.and_(pt.neq(a, -pt.inf), pt.neq(b, pt.inf)),
899+
pt.log(b - a) - 2 * s - value,
900+
pt.where(
901+
pt.or_(pt.neq(a, -pt.inf), pt.neq(b, pt.inf)),
902+
value,
903+
pt.zeros_like(value),
904+
),
905+
)
906+
elif lower_bounded or upper_bounded:
853907
return value
908+
else:
909+
return pt.zeros_like(value)
854910

855911

856912
class LogOddsTransform(RVTransform):

tests/distributions/test_truncated.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
from pytensor.tensor.random.basic import GeometricRV, NormalRV
2121

2222
from pymc import Censored, Model, draw, find_MAP
23-
from pymc.distributions.continuous import Exponential, Gamma, TruncatedNormalRV
23+
from pymc.distributions.continuous import (
24+
Exponential,
25+
Gamma,
26+
TruncatedNormal,
27+
TruncatedNormalRV,
28+
)
2429
from pymc.distributions.shape_utils import change_dist_size
2530
from pymc.distributions.transforms import _default_transform
2631
from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated
@@ -423,3 +428,53 @@ def test_truncated_gamma():
423428
logp_resized_pymc,
424429
logp_scipy,
425430
)
431+
432+
433+
def test_vectorized_bounds():
434+
with Model() as m:
435+
x1 = TruncatedNormal("x1", lower=None, upper=0, initval=-1)
436+
x2 = TruncatedNormal("x2", lower=0, upper=None, initval=1)
437+
x3 = TruncatedNormal("x3", lower=-np.pi, upper=np.e, initval=-1)
438+
x4 = TruncatedNormal("x4", lower=None, upper=None, initval=1)
439+
440+
xs = TruncatedNormal(
441+
"xs",
442+
lower=[-np.inf, 0, -np.pi, -np.inf],
443+
upper=[0, np.inf, np.e, np.inf],
444+
initval=[-1, 1, -1, 1],
445+
)
446+
xs_sym = Truncated(
447+
"xs_sym",
448+
dist=rejection_normal(),
449+
lower=[-np.inf, 0, -np.pi, -np.inf],
450+
upper=[0, np.inf, np.e, np.inf],
451+
initval=[-1, 1, -1, 1],
452+
)
453+
454+
ip = m.initial_point()
455+
np.testing.assert_allclose(
456+
np.stack([ip[f"x{i + 1}_interval__"] for i in range(4)]),
457+
ip["xs_interval__"],
458+
)
459+
np.testing.assert_allclose(
460+
ip["xs_interval__"],
461+
ip["xs_sym_interval__"],
462+
)
463+
np.testing.assert_allclose(
464+
m.rvs_to_transforms[xs].backward(ip["xs_interval__"], *xs.owner.inputs).eval(),
465+
[-1, 1, -1, 1],
466+
)
467+
np.testing.assert_allclose(
468+
m.rvs_to_transforms[xs_sym].backward(ip["xs_sym_interval__"], *xs_sym.owner.inputs).eval(),
469+
[-1, 1, -1, 1],
470+
)
471+
*x_logp, xs_logp, xs_sym_logp = m.compile_logp(sum=False)(ip)
472+
assert np.all(np.isfinite(xs_logp))
473+
np.testing.assert_allclose(
474+
np.stack(x_logp),
475+
xs_logp,
476+
)
477+
np.testing.assert_allclose(
478+
xs_logp,
479+
xs_sym_logp,
480+
)

tests/logprob/test_transforms.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
ErfcxTransform,
5656
ErfTransform,
5757
ExpTransform,
58-
IntervalTransform,
5958
LocTransform,
6059
LogTransform,
6160
RVTransform,
@@ -139,23 +138,6 @@ def backward(self, value, *inputs):
139138
):
140139
SquareTransform().log_jac_det(0)
141140

142-
def test_invalid_interval_transform(self):
143-
x_rv = pt.random.normal(0, 1)
144-
x_vv = x_rv.clone()
145-
146-
msg = "Both edges of IntervalTransform cannot be None"
147-
tr = IntervalTransform(lambda *inputs: (None, None))
148-
with pytest.raises(ValueError, match=msg):
149-
tr.forward(x_vv, *x_rv.owner.inputs)
150-
151-
tr = IntervalTransform(lambda *inputs: (None, None))
152-
with pytest.raises(ValueError, match=msg):
153-
tr.backward(x_vv, *x_rv.owner.inputs)
154-
155-
tr = IntervalTransform(lambda *inputs: (None, None))
156-
with pytest.raises(ValueError, match=msg):
157-
tr.log_jac_det(x_vv, *x_rv.owner.inputs)
158-
159141
def test_chained_transform(self):
160142
loc = 5
161143
scale = 0.1

tests/model/test_core.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
import pymc as pm
4242

43-
from pymc import Deterministic, Potential
43+
from pymc import Deterministic, Model, Potential
4444
from pymc.blocking import DictToArrayBijection, RaveledVars
4545
from pymc.distributions import Normal, transforms
4646
from pymc.distributions.distribution import PartialObservedRV
@@ -1524,6 +1524,18 @@ def test_symbolic_random_variable(self):
15241524
st.norm.logcdf(0) * 10,
15251525
)
15261526

1527+
def test_truncated_normal(self):
1528+
"""Test transform of unobserved TruncatedNormal leads to finite logp.
1529+
1530+
Regression test for #6999
1531+
"""
1532+
with Model() as m:
1533+
mu = pm.TruncatedNormal("mu", mu=1, sigma=2, lower=0)
1534+
x = pm.TruncatedNormal(
1535+
"x", mu=mu, sigma=0.5, lower=0, observed=np.array([0.1, 0.2, 0.5, np.nan, np.nan])
1536+
)
1537+
m.check_start_vals(m.initial_point())
1538+
15271539

15281540
class TestShared:
15291541
def test_deterministic(self):

0 commit comments

Comments
 (0)