33
33
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34
34
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35
35
# SOFTWARE.
36
+ import re
36
37
37
38
import numpy as np
38
39
import pytensor
43
44
from scipy import stats
44
45
45
46
from pymc .distributions import Dirichlet
46
- from pymc .logprob .joint_logprob import factorized_joint_logprob
47
+ from pymc .logprob .basic import conditional_logp
47
48
from tests .distributions .test_multivariate import dirichlet_logpdf
48
49
49
50
@@ -58,7 +59,7 @@ def test_specify_shape_logprob():
58
59
59
60
# 2. Request logp
60
61
x_vv = x_rv .clone ()
61
- [x_logp ] = factorized_joint_logprob ({x_rv : x_vv }).values ()
62
+ [x_logp ] = conditional_logp ({x_rv : x_vv }).values ()
62
63
63
64
# 3. Test logp
64
65
x_logp_fn = pytensor .function ([last_dim , x_vv ], x_logp )
@@ -80,17 +81,19 @@ def test_assert_logprob():
80
81
rv = pt .random .normal ()
81
82
assert_op = Assert ("Test assert" )
82
83
# Example: Add assert that rv must be positive
83
- assert_rv = assert_op (rv > 0 , rv )
84
+ assert_rv = assert_op (rv , rv > 0 )
84
85
assert_rv .name = "assert_rv"
85
86
86
87
assert_vv = assert_rv .clone ()
87
- assert_logp = factorized_joint_logprob ({assert_rv : assert_vv })[assert_vv ]
88
+ assert_logp = conditional_logp ({assert_rv : assert_vv })[assert_vv ]
88
89
89
90
# Check valid value is correct and doesn't raise
90
91
# Since here the value to the rv satisfies the condition, no error is raised.
91
92
valid_value = 3.0
92
- with pytest .raises (AssertionError , match = "Test assert" ):
93
- assert_logp .eval ({assert_vv : valid_value })
93
+ np .testing .assert_allclose (
94
+ assert_logp .eval ({assert_vv : valid_value }),
95
+ stats .norm .logpdf (valid_value ),
96
+ )
94
97
95
98
# Check invalid value
96
99
# Since here the value to the rv is negative, an exception is raised as the condition is not met
0 commit comments