Skip to content

Commit 8bf0154

Browse files
logprob for maximum derived
1 parent aebe469 commit 8bf0154

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

tests/logprob/test_checks.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
3434
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3535
# SOFTWARE.
36+
import re
3637

3738
import numpy as np
3839
import pytensor
@@ -43,7 +44,7 @@
4344
from scipy import stats
4445

4546
from pymc.distributions import Dirichlet
46-
from pymc.logprob.joint_logprob import factorized_joint_logprob
47+
from pymc.logprob.basic import conditional_logp
4748
from tests.distributions.test_multivariate import dirichlet_logpdf
4849

4950

@@ -58,7 +59,7 @@ def test_specify_shape_logprob():
5859

5960
# 2. Request logp
6061
x_vv = x_rv.clone()
61-
[x_logp] = factorized_joint_logprob({x_rv: x_vv}).values()
62+
[x_logp] = conditional_logp({x_rv: x_vv}).values()
6263

6364
# 3. Test logp
6465
x_logp_fn = pytensor.function([last_dim, x_vv], x_logp)
@@ -80,17 +81,19 @@ def test_assert_logprob():
8081
rv = pt.random.normal()
8182
assert_op = Assert("Test assert")
8283
# Example: Add assert that rv must be positive
83-
assert_rv = assert_op(rv > 0, rv)
84+
assert_rv = assert_op(rv, rv > 0)
8485
assert_rv.name = "assert_rv"
8586

8687
assert_vv = assert_rv.clone()
87-
assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv]
88+
assert_logp = conditional_logp({assert_rv: assert_vv})[assert_vv]
8889

8990
# Check valid value is correct and doesn't raise
9091
# Since here the value to the rv satisfies the condition, no error is raised.
9192
valid_value = 3.0
92-
with pytest.raises(AssertionError, match="Test assert"):
93-
assert_logp.eval({assert_vv: valid_value})
93+
np.testing.assert_allclose(
94+
assert_logp.eval({assert_vv: valid_value}),
95+
stats.norm.logpdf(valid_value),
96+
)
9497

9598
# Check invalid value
9699
# Since here the value to the rv is negative, an exception is raised as the condition is not met

0 commit comments

Comments
 (0)