Skip to content

Commit 5081464

Browse files
Add nan_to_num conversion
1 parent b7b309d commit 5081464

File tree

4 files changed

+178
-0
lines changed

4 files changed

+178
-0
lines changed

pytensor/scalar/basic.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,56 @@ def c_code_cache_version(self):
15331533
isinf = IsInf()
15341534

15351535

1536+
class IsPosInf(FixedLogicalComparison):
    """Elemwise test for positive infinity, mirroring ``np.isposinf``."""

    nfunc_spec = ("isposinf", 1, 1)

    def impl(self, x):
        # Python-side reference implementation.
        return np.isposinf(x)

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        input_type = node.inputs[0].type
        if input_type in complex_types:
            raise NotImplementedError()
        if input_type in discrete_types:
            # Discrete (integer/bool) values can never be +inf.
            return f"{z} = false;"
        # +inf <=> infinite magnitude with the sign bit clear.
        return f"{z} = isinf({x}) && !signbit({x});"

    def c_code_cache_version(self):
        # Extend the parent version tuple; bump the trailing number
        # whenever the generated C above changes.
        return (*super().c_code_cache_version(), 4)


isposinf = IsPosInf()
1559+
1560+
1561+
class IsNegInf(FixedLogicalComparison):
    """Elemwise test for negative infinity, mirroring ``np.isneginf``."""

    nfunc_spec = ("isneginf", 1, 1)

    def impl(self, x):
        # Python-side reference implementation.
        return np.isneginf(x)

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        input_type = node.inputs[0].type
        if input_type in complex_types:
            raise NotImplementedError()
        if input_type in discrete_types:
            # Discrete (integer/bool) values can never be -inf.
            return f"{z} = false;"
        # -inf <=> infinite magnitude with the sign bit set.
        return f"{z} = isinf({x}) && signbit({x});"

    def c_code_cache_version(self):
        # Extend the parent version tuple; bump the trailing number
        # whenever the generated C above changes.
        return (*super().c_code_cache_version(), 4)


isneginf = IsNegInf()
1584+
1585+
15361586
class InRange(LogicalComparison):
15371587
nin = 3
15381588

pytensor/tensor/math.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,46 @@ def isinf(a):
881881
return isinf_(a)
882882

883883

884+
@scalar_elemwise
def isposinf(a):
    """isposinf(a)"""


# Rename isposinf to isposinf_ to allow to bypass it when not needed.
# glibc 2.23 don't allow isinf on int, so we remove it from the graph.
isposinf_ = isposinf


def isposinf(a):
    """isposinf(a)

    Elemwise check for +inf. For discrete dtypes — which can never hold
    +inf — a False-filled boolean tensor of the same (symbolic) shape is
    returned instead of building an ``isposinf`` node.
    """
    a = as_tensor_variable(a)
    if a.dtype in discrete_dtypes:
        # Short-circuit: alloc a boolean False tensor matching a's shape.
        return alloc(
            np.asarray(False, dtype="bool"), *[a.shape[i] for i in range(a.ndim)]
        )
    return isposinf_(a)
902+
903+
904+
@scalar_elemwise
def isneginf(a):
    """isneginf(a)"""


# Rename isneginf to isneginf_ to allow to bypass it when not needed.
# glibc 2.23 don't allow isinf on int, so we remove it from the graph.
isneginf_ = isneginf


def isneginf(a):
    """isneginf(a)

    Elemwise check for -inf. For discrete dtypes — which can never hold
    -inf — a False-filled boolean tensor of the same (symbolic) shape is
    returned instead of building an ``isneginf`` node.
    """
    a = as_tensor_variable(a)
    if a.dtype in discrete_dtypes:
        # Short-circuit: alloc a boolean False tensor matching a's shape.
        return alloc(
            np.asarray(False, dtype="bool"), *[a.shape[i] for i in range(a.ndim)]
        )
    return isneginf_(a)
922+
923+
884924
def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
885925
"""
886926
Implement Numpy's ``allclose`` on tensors.
@@ -3043,6 +3083,65 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
30433083
return vectorize_node_fallback(op, node, batched_x, batched_y)
30443084

30453085

3086+
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
    """
    Replace NaN with zero and infinity with large finite numbers (default
    behaviour) or with the numbers defined by the user using the `nan`,
    `posinf` and/or `neginf` keywords.

    NaN is replaced by zero or by the user defined value in
    `nan` keyword, infinity is replaced by the largest finite floating point
    values representable by ``x.dtype`` or by the user defined value in
    `posinf` keyword and -infinity is replaced by the most negative finite
    floating point values representable by ``x.dtype`` or by the user defined
    value in `neginf` keyword.

    Parameters
    ----------
    x : symbolic tensor
        Input array.
    nan
        The value to replace NaN's with in the tensor (default = 0).
    posinf
        The value to replace +INF with in the tensor (default max
        in range representable by ``x.dtype``).
    neginf
        The value to replace -INF with in the tensor (default min
        in range representable by ``x.dtype``).

    Returns
    -------
    out
        The tensor with NaN's, +INF, and -INF replaced with the
        specified and/or default substitutions.
    """
    x = as_tensor_variable(x)

    is_nan = isnan(x)
    is_pos_inf = isposinf(x)
    is_neg_inf = isneginf(x)

    # NOTE: no value-based short-circuit here (e.g. ``any(is_nan)``):
    # these are *symbolic* variables whose contents are unknown until the
    # graph is compiled and run, and truth-testing a symbolic tensor
    # raises. The replacement graph is built unconditionally.

    # Replace NaN's with the `nan` keyword value.
    x = switch(is_nan, nan, x)

    # Defaults for +INF/-INF: the largest / most negative finite value
    # representable by x's (real) dtype.
    maxf = posinf if posinf is not None else np.finfo(x.real.dtype).max
    minf = neginf if neginf is not None else np.finfo(x.real.dtype).min

    # Replace +INF and -INF values.
    x = switch(is_pos_inf, maxf, x)
    x = switch(is_neg_inf, minf, x)

    return x
3143+
3144+
30463145
# NumPy logical aliases
30473146
square = sqr
30483147

@@ -3199,4 +3298,5 @@ def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
31993298
"logaddexp",
32003299
"logsumexp",
32013300
"hyp2f1",
3301+
"nan_to_num",
32023302
]

tests/tensor/random/test_basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@
5151
pareto,
5252
permutation,
5353
poisson,
(review) Unresolved git merge-conflict markers (`<<<<<<< HEAD` / `=======` / `randint,` / `>>>>>>> f7b9916e2`) were committed into this import list. Resolve the conflict instead of committing the markers — take the HEAD side (drop the `randint,` import, which the current branch no longer provides) and delete all three marker lines, leaving the import list running directly from `poisson,` to `rayleigh,`.
5458
rayleigh,
5559
standard_normal,
5660
t,

tests/tensor/test_math.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
minimum,
9696
mod,
9797
mul,
98+
nan_to_num,
9899
neg,
99100
neq,
100101
outer,
@@ -3641,3 +3642,26 @@ def test_grad_n_undefined(self):
36413642
n = scalar(dtype="int64")
36423643
with pytest.raises(NullTypeGradError):
36433644
grad(polygamma(n, 0.5), wrt=n)
3645+
3646+
3647+
@pytest.mark.parametrize(
    ["nan", "posinf", "neginf"],
    [(0, None, None), (0, 0, 0), (0, None, 1000), (3, 1, -1)],
)
def test_nan_to_num(nan, posinf, neginf):
    """nan_to_num matches np.nan_to_num for each replacement combination."""
    x = tensor(shape=(7,))

    # Build the graph once; x is always used, so no on_unused_input needed.
    f = function([x], nan_to_num(x, nan, posinf, neginf))

    y = np.array([1, 2, np.nan, np.inf, -np.inf, 3, 4])
    out = f(y)

    # Mirror nan_to_num's defaults when posinf/neginf are not given.
    posinf = np.finfo(x.real.dtype).max if posinf is None else posinf
    neginf = np.finfo(x.real.dtype).min if neginf is None else neginf

    np.testing.assert_allclose(
        out,
        np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
    )

0 commit comments

Comments
 (0)