Skip to content

Commit de42f8b

Browse files
committed
Allow inlined Infinity / Nan constants in Composite
1 parent 8933712 commit de42f8b

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

pytensor/scalar/basic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,12 @@ def c_literal(self, data):
432432
return None
433433
if self.dtype == "bool":
434434
return "1" if data else "0"
435+
if data == np.inf:
436+
return "INFINITY"
437+
if data == -np.inf:
438+
return "-INFINITY"
439+
if np.isnan(data):
440+
return "NAN"
435441
return str(data)
436442

437443
def c_declare(self, name, sub, check_input=True):

tests/scalar/test_basic.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,21 +128,33 @@ def test_flatten(self):
128128
# We don't flatten that case.
129129
assert isinstance(CC.outputs[0].owner.op, Composite)
130130

131-
def test_with_constants(self):
131+
@pytest.mark.parametrize("literal_value", (70.0, -np.inf, np.float32("nan")))
132+
def test_with_constants(self, literal_value):
132133
x, y, z = floats("xyz")
133-
e = mul(add(70.0, y), true_div(x, y))
134+
e = mul(add(literal_value, y), true_div(x, y))
134135
comp_op = Composite([x, y], [e])
135136
comp_node = comp_op.make_node(x, y)
136137

137138
c_code = comp_node.op.c_code(comp_node, "dummy", ["x", "y"], ["z"], dict(id=0))
138-
assert "70.0" in c_code
139+
assert constant(literal_value).type.c_literal(literal_value) in c_code
139140

140141
# Make sure caching of the c_code template works
141142
assert hasattr(comp_node.op, "_c_code")
142143

143144
g = FunctionGraph([x, y], [comp_node.out])
144-
fn = make_function(DualLinker().accept(g))
145-
assert fn(1.0, 2.0) == 36.0
145+
146+
# Default checker does not allow `nan`
147+
def checker(x, y):
148+
np.testing.assert_equal(x, y)
149+
150+
fn = make_function(DualLinker(checker=checker).accept(g))
151+
152+
test_x = 1.0
153+
test_y = 2.0
154+
np.testing.assert_equal(
155+
fn(test_x, test_y),
156+
(literal_value + test_y) * (test_x / test_y),
157+
)
146158

147159
def test_many_outputs(self):
148160
x, y, z = floats("xyz")

0 commit comments

Comments
 (0)