33
33
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34
34
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35
35
# SOFTWARE.
36
- import re
37
36
38
37
import numpy as np
39
38
import pytensor
44
43
from scipy import stats
45
44
46
45
from pymc .distributions import Dirichlet
47
- from pymc .logprob .basic import conditional_logp
46
+ from pymc .logprob .joint_logprob import factorized_joint_logprob
48
47
from tests .distributions .test_multivariate import dirichlet_logpdf
49
48
50
49
@@ -59,7 +58,7 @@ def test_specify_shape_logprob():
59
58
60
59
# 2. Request logp
61
60
x_vv = x_rv .clone ()
62
- [x_logp ] = conditional_logp ({x_rv : x_vv }).values ()
61
+ [x_logp ] = factorized_joint_logprob ({x_rv : x_vv }).values ()
63
62
64
63
# 3. Test logp
65
64
x_logp_fn = pytensor .function ([last_dim , x_vv ], x_logp )
@@ -81,19 +80,17 @@ def test_assert_logprob():
81
80
rv = pt .random .normal ()
82
81
assert_op = Assert ("Test assert" )
83
82
# Example: Add assert that rv must be positive
84
- assert_rv = assert_op (rv , rv > 0 )
83
+ assert_rv = assert_op (rv > 0 , rv )
85
84
assert_rv .name = "assert_rv"
86
85
87
86
assert_vv = assert_rv .clone ()
88
- assert_logp = conditional_logp ({assert_rv : assert_vv })[assert_vv ]
87
+ assert_logp = factorized_joint_logprob ({assert_rv : assert_vv })[assert_vv ]
89
88
90
89
# Check valid value is correct and doesn't raise
91
90
# Since here the value to the rv satisfies the condition, no error is raised.
92
91
valid_value = 3.0
93
- np .testing .assert_allclose (
94
- assert_logp .eval ({assert_vv : valid_value }),
95
- stats .norm .logpdf (valid_value ),
96
- )
92
+ with pytest .raises (AssertionError , match = "Test assert" ):
93
+ assert_logp .eval ({assert_vv : valid_value })
97
94
98
95
# Check invalid value
99
96
# Since here the value to the rv is negative, an exception is raised as the condition is not met
0 commit comments