@@ -128,21 +128,33 @@ def test_flatten(self):
128
128
# We don't flatten that case.
129
129
assert isinstance (CC .outputs [0 ].owner .op , Composite )
130
130
131
- def test_with_constants (self ):
131
+ @pytest .mark .parametrize ("literal_value" , (70.0 , - np .inf , np .float32 ("nan" )))
132
+ def test_with_constants (self , literal_value ):
132
133
x , y , z = floats ("xyz" )
133
- e = mul (add (70.0 , y ), true_div (x , y ))
134
+ e = mul (add (literal_value , y ), true_div (x , y ))
134
135
comp_op = Composite ([x , y ], [e ])
135
136
comp_node = comp_op .make_node (x , y )
136
137
137
138
c_code = comp_node .op .c_code (comp_node , "dummy" , ["x" , "y" ], ["z" ], dict (id = 0 ))
138
- assert "70.0" in c_code
139
+ assert constant ( literal_value ). type . c_literal ( literal_value ) in c_code
139
140
140
141
# Make sure caching of the c_code template works
141
142
assert hasattr (comp_node .op , "_c_code" )
142
143
143
144
g = FunctionGraph ([x , y ], [comp_node .out ])
144
- fn = make_function (DualLinker ().accept (g ))
145
- assert fn (1.0 , 2.0 ) == 36.0
145
+
146
+ # Default checker does not allow `nan`
147
+ def checker (x , y ):
148
+ np .testing .assert_equal (x , y )
149
+
150
+ fn = make_function (DualLinker (checker = checker ).accept (g ))
151
+
152
+ test_x = 1.0
153
+ test_y = 2.0
154
+ np .testing .assert_equal (
155
+ fn (test_x , test_y ),
156
+ (literal_value + test_y ) * (test_x / test_y ),
157
+ )
146
158
147
159
def test_many_outputs (self ):
148
160
x , y , z = floats ("xyz" )
0 commit comments