Skip to content

Commit aebe469

Browse files
Guidelines
1 parent 816fe4a commit aebe469

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

tests/logprob/test_checks.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
3434
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3535
# SOFTWARE.
36-
import re
3736

3837
import numpy as np
3938
import pytensor
@@ -44,7 +43,7 @@
4443
from scipy import stats
4544

4645
from pymc.distributions import Dirichlet
47-
from pymc.logprob.basic import conditional_logp
46+
from pymc.logprob.joint_logprob import factorized_joint_logprob
4847
from tests.distributions.test_multivariate import dirichlet_logpdf
4948

5049

@@ -59,7 +58,7 @@ def test_specify_shape_logprob():
5958

6059
# 2. Request logp
6160
x_vv = x_rv.clone()
62-
[x_logp] = conditional_logp({x_rv: x_vv}).values()
61+
[x_logp] = factorized_joint_logprob({x_rv: x_vv}).values()
6362

6463
# 3. Test logp
6564
x_logp_fn = pytensor.function([last_dim, x_vv], x_logp)
@@ -81,19 +80,17 @@ def test_assert_logprob():
8180
rv = pt.random.normal()
8281
assert_op = Assert("Test assert")
8382
# Example: Add assert that rv must be positive
84-
assert_rv = assert_op(rv, rv > 0)
83+
assert_rv = assert_op(rv > 0, rv)
8584
assert_rv.name = "assert_rv"
8685

8786
assert_vv = assert_rv.clone()
88-
assert_logp = conditional_logp({assert_rv: assert_vv})[assert_vv]
87+
assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv]
8988

9089
# Check valid value is correct and doesn't raise
9190
# Since here the value to the rv satisfies the condition, no error is raised.
9291
valid_value = 3.0
93-
np.testing.assert_allclose(
94-
assert_logp.eval({assert_vv: valid_value}),
95-
stats.norm.logpdf(valid_value),
96-
)
92+
with pytest.raises(AssertionError, match="Test assert"):
93+
assert_logp.eval({assert_vv: valid_value})
9794

9895
# Check invalid value
9996
# Since here the value to the rv is negative, an exception is raised as the condition is not met

0 commit comments

Comments
 (0)