33
33
((pt .gt , pt .ge ), "logcdf" , "logsf" , (0.5 , pt .random .normal (0 , 1 ))),
34
34
],
35
35
)
36
- def test_continuous_rv_comparison (comparison_op , exp_logp_true , exp_logp_false , inputs ):
36
+ def test_continuous_rv_comparison_bitwise (comparison_op , exp_logp_true , exp_logp_false , inputs ):
37
37
for op in comparison_op :
38
38
comp_x_rv = op (* inputs )
39
39
@@ -48,6 +48,17 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false,
48
48
assert np .isclose (logp_fn (0 ), getattr (ref_scipy , exp_logp_false )(0.5 ))
49
49
assert np .isclose (logp_fn (1 ), getattr (ref_scipy , exp_logp_true )(0.5 ))
50
50
51
+ bitwise_rv = pt .bitwise_not (op (* inputs ))
52
+ bitwise_vv = bitwise_rv .clone ()
53
+
54
+ logprob_not = logp (bitwise_rv , bitwise_vv )
55
+ assert_no_rvs (logprob_not )
56
+
57
+ logp_fn_not = pytensor .function ([bitwise_vv ], logprob_not )
58
+
59
+ assert np .isclose (logp_fn_not (0 ), getattr (ref_scipy , exp_logp_true )(0.5 ))
60
+ assert np .isclose (logp_fn_not (1 ), getattr (ref_scipy , exp_logp_false )(0.5 ))
61
+
51
62
52
63
@pytest .mark .parametrize (
53
64
"comparison_op, exp_logp_true, exp_logp_false, inputs" ,
@@ -87,7 +98,7 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false,
87
98
),
88
99
],
89
100
)
90
- def test_discrete_rv_comparison (inputs , comparison_op , exp_logp_true , exp_logp_false ):
101
+ def test_discrete_rv_comparison_bitwise (inputs , comparison_op , exp_logp_true , exp_logp_false ):
91
102
cens_x_rv = comparison_op (* inputs )
92
103
93
104
cens_x_vv = cens_x_rv .clone ()
@@ -100,6 +111,17 @@ def test_discrete_rv_comparison(inputs, comparison_op, exp_logp_true, exp_logp_f
100
111
assert np .isclose (logp_fn (1 ), exp_logp_true (3 ))
101
112
assert np .isclose (logp_fn (0 ), exp_logp_false (3 ))
102
113
114
+ bitwise_rv = pt .bitwise_not (comparison_op (* inputs ))
115
+ bitwise_vv = bitwise_rv .clone ()
116
+
117
+ logprob_not = logp (bitwise_rv , bitwise_vv )
118
+ assert_no_rvs (logprob_not )
119
+
120
+ logp_fn_not = pytensor .function ([bitwise_vv ], logprob_not )
121
+
122
+ assert np .isclose (logp_fn_not (1 ), exp_logp_false (3 ))
123
+ assert np .isclose (logp_fn_not (0 ), exp_logp_true (3 ))
124
+
103
125
104
126
def test_potentially_measurable_operand ():
105
127
x_rv = pt .random .normal (2 )
0 commit comments