Commit b59e233

Complete NumPy ufunc compatibility

1 parent ad946fb · commit b59e233

File tree

3 files changed: +259 −67 lines changed

databricks/koalas/base.py
databricks/koalas/numpy_compat.py
databricks/koalas/tests/test_numpy_compat.py


databricks/koalas/base.py

Lines changed: 9 additions & 66 deletions
@@ -30,77 +30,13 @@
 from pyspark.sql.functions import monotonically_increasing_id

 from databricks import koalas as ks  # For running doctests and reference resolution in PyCharm.
+from databricks.koalas import numpy_compat
 from databricks.koalas.internal import _InternalFrame, SPARK_INDEX_NAME_FORMAT
 from databricks.koalas.typedef import pandas_wraps, spark_type_to_pandas_dtype
 from databricks.koalas.utils import align_diff_series, scol_for, validate_axis
 from databricks.koalas.frame import DataFrame


-# Copied from pandas.
-def maybe_dispatch_ufunc_to_dunder_op(
-        self, ufunc: Callable, method: str, *inputs, **kwargs: Any
-):
-    special = {
-        "add",
-        "sub",
-        "mul",
-        "pow",
-        "mod",
-        "floordiv",
-        "truediv",
-        "divmod",
-        "eq",
-        "ne",
-        "lt",
-        "gt",
-        "le",
-        "ge",
-        "remainder",
-        "matmul",
-    }
-    aliases = {
-        "subtract": "sub",
-        "multiply": "mul",
-        "floor_divide": "floordiv",
-        "true_divide": "truediv",
-        "power": "pow",
-        "remainder": "mod",
-        "divide": "div",
-        "equal": "eq",
-        "not_equal": "ne",
-        "less": "lt",
-        "less_equal": "le",
-        "greater": "gt",
-        "greater_equal": "ge",
-    }
-
-    # For op(., Array) -> Array.__r{op}__
-    flipped = {
-        "lt": "__gt__",
-        "le": "__ge__",
-        "gt": "__lt__",
-        "ge": "__le__",
-        "eq": "__eq__",
-        "ne": "__ne__",
-    }
-
-    op_name = ufunc.__name__
-    op_name = aliases.get(op_name, op_name)
-
-    def not_implemented(*args, **kwargs):
-        return NotImplemented
-
-    if method == "__call__" and op_name in special and kwargs.get("out") is None:
-        if isinstance(inputs[0], type(self)):
-            name = "__{}__".format(op_name)
-            return getattr(self, name, not_implemented)(inputs[1])
-        else:
-            name = flipped.get(op_name, "__r{}__".format(op_name))
-            return getattr(self, name, not_implemented)(inputs[0])
-    else:
-        return NotImplemented
-
-
 def booleanize_null(left_scol, scol, f):
     """
     Booleanize Null in Spark Column

@@ -291,8 +227,15 @@ def __rfloordiv__(self, other):

     # NDArray Compat
     def __array_ufunc__(self, ufunc: Callable, method: str, *inputs: Any, **kwargs: Any):
-        result = maybe_dispatch_ufunc_to_dunder_op(
+        # Try dunder methods first.
+        result = numpy_compat.maybe_dispatch_ufunc_to_dunder_op(
             self, ufunc, method, *inputs, **kwargs)
+
+        # After that, we try with PySpark APIs.
+        if result is NotImplemented:
+            result = numpy_compat.maybe_dispatch_ufunc_to_spark_func(
+                self, ufunc, method, *inputs, **kwargs)
+
         if result is not NotImplemented:
             return result
         else:
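With this change, a ufunc applied to a Koalas Series is first routed to the Series' own arithmetic/comparison dunder, and only falls back to the Spark Column mappings in `numpy_compat` when no dunder matches. A minimal sketch of both dispatch paths, assuming a running Spark session (the column values are illustrative):

import numpy as np
from databricks import koalas as ks

kdf = ks.DataFrame({'a': [1.0, 4.0, 9.0], 'b': [2.0, 3.0, 4.0]})

# 'add' is in the dunder dispatch's `special` set, so this resolves to
# Series.__add__ via maybe_dispatch_ufunc_to_dunder_op.
np.add(kdf.a, kdf.b)

# 'sqrt' has no dunder counterpart, so the first dispatch returns
# NotImplemented and maybe_dispatch_ufunc_to_spark_func runs it
# as F.sqrt on the underlying Spark Column.
np.sqrt(kdf.a)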

databricks/koalas/numpy_compat.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@
+#
+# Copyright (C) 2019 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+from typing import Callable, Any
+
+import numpy as np
+from pyspark.sql import functions as F
+from pyspark.sql.types import DoubleType, LongType, BooleanType
+
+
+unary_np_spark_mappings = OrderedDict({
+    'abs': F.abs,
+    'absolute': F.abs,
+    'arccos': F.acos,
+    'arccosh': F.pandas_udf(lambda s: np.arccosh(s), DoubleType()),
+    'arcsin': F.asin,
+    'arcsinh': F.pandas_udf(lambda s: np.arcsinh(s), DoubleType()),
+    'arctan': F.atan,
+    'arctanh': F.pandas_udf(lambda s: np.arctanh(s), DoubleType()),
+    'bitwise_not': F.bitwiseNOT,
+    'cbrt': F.cbrt,
+    'ceil': F.ceil,
+    'conj': lambda _: NotImplemented,  # It requires a complex type, which Koalas does not support yet.
+    'conjugate': lambda _: NotImplemented,  # It requires a complex type.
+    'cos': F.cos,
+    'cosh': F.pandas_udf(lambda s: np.cosh(s), DoubleType()),
+    'deg2rad': F.pandas_udf(lambda s: np.deg2rad(s), DoubleType()),
+    'degrees': F.degrees,
+    'exp': F.exp,
+    'exp2': F.pandas_udf(lambda s: np.exp2(s), DoubleType()),
+    'expm1': F.expm1,
+    'fabs': F.pandas_udf(lambda s: np.fabs(s), DoubleType()),
+    'floor': F.floor,
+    'frexp': lambda _: NotImplemented,  # 'frexp' outputs have different lengths,
+                                        # so it cannot be supported via pandas UDF.
+    'invert': F.pandas_udf(lambda s: np.invert(s), DoubleType()),
+    'isfinite': lambda c: c != float("inf"),
+    'isinf': lambda c: c == float("inf"),
+    'isnan': F.isnan,
+    'isnat': lambda c: NotImplemented,  # Koalas and PySpark do not have a NaT concept.
+    'log': F.log,
+    'log10': F.log10,
+    'log1p': F.log1p,
+    'log2': F.pandas_udf(lambda s: np.log2(s), DoubleType()),
+    'logical_not': lambda c: ~(c.cast(BooleanType())),
+    'matmul': lambda _: NotImplemented,  # Can return a NumPy array in pandas.
+    'negative': lambda c: c * -1,
+    'positive': lambda c: c,
+    'rad2deg': F.pandas_udf(lambda s: np.rad2deg(s), DoubleType()),
+    'radians': F.radians,
+    'reciprocal': F.pandas_udf(lambda s: np.reciprocal(s), DoubleType()),
+    'rint': F.pandas_udf(lambda s: np.rint(s), DoubleType()),
+    'sign': lambda c: F.when(c == 0, 0).when(c < 0, -1).otherwise(1),
+    'signbit': lambda c: F.when(c < 0, True).otherwise(False),
+    'sin': F.sin,
+    'sinh': F.pandas_udf(lambda s: np.sinh(s), DoubleType()),
+    'spacing': F.pandas_udf(lambda s: np.spacing(s), DoubleType()),
+    'sqrt': F.sqrt,
+    'square': F.pandas_udf(lambda s: np.square(s), DoubleType()),
+    'tan': F.tan,
+    'tanh': F.pandas_udf(lambda s: np.tanh(s), DoubleType()),
+    'trunc': F.pandas_udf(lambda s: np.trunc(s), DoubleType()),
+})
+
+binary_np_spark_mappings = OrderedDict({
+    'arctan2': F.atan2,
+    'bitwise_and': lambda c1, c2: c1.bitwiseAND(c2),
+    'bitwise_or': lambda c1, c2: c1.bitwiseOR(c2),
+    'bitwise_xor': lambda c1, c2: c1.bitwiseXOR(c2),
+    'copysign': F.pandas_udf(lambda s1, s2: np.copysign(s1, s2), DoubleType()),
+    'float_power': F.pandas_udf(lambda s1, s2: np.float_power(s1, s2), DoubleType()),
+    'floor_divide': F.pandas_udf(lambda s1, s2: np.floor_divide(s1, s2), DoubleType()),
+    'fmax': F.pandas_udf(lambda s1, s2: np.fmax(s1, s2), DoubleType()),
+    'fmin': F.pandas_udf(lambda s1, s2: np.fmin(s1, s2), DoubleType()),
+    'fmod': F.pandas_udf(lambda s1, s2: np.fmod(s1, s2), DoubleType()),
+    'gcd': F.pandas_udf(lambda s1, s2: np.gcd(s1, s2), DoubleType()),
+    'heaviside': F.pandas_udf(lambda s1, s2: np.heaviside(s1, s2), DoubleType()),
+    'hypot': F.hypot,
+    'lcm': F.pandas_udf(lambda s1, s2: np.lcm(s1, s2), DoubleType()),
+    'ldexp': F.pandas_udf(lambda s1, s2: np.ldexp(s1, s2), DoubleType()),
+    'left_shift': F.pandas_udf(lambda s1, s2: np.left_shift(s1, s2), LongType()),
+    'logaddexp': F.pandas_udf(lambda s1, s2: np.logaddexp(s1, s2), DoubleType()),
+    'logaddexp2': F.pandas_udf(lambda s1, s2: np.logaddexp2(s1, s2), DoubleType()),
+    'logical_and': lambda c1, c2: c1.cast(BooleanType()) & c2.cast(BooleanType()),
+    'logical_or': lambda c1, c2: c1.cast(BooleanType()) | c2.cast(BooleanType()),
+    'logical_xor': lambda c1, c2: (
+        # Mimics xor by logical operators.
+        (c1.cast(BooleanType()) | c2.cast(BooleanType()))
+        & (~(c1.cast(BooleanType())) | ~(c2.cast(BooleanType())))
+    ),
+    'maximum': F.greatest,
+    'minimum': F.least,
+    'modf': F.pandas_udf(lambda s1, s2: np.modf(s1, s2), DoubleType()),
+    'nextafter': F.pandas_udf(lambda s1, s2: np.nextafter(s1, s2), DoubleType()),
+    'right_shift': F.pandas_udf(lambda s1, s2: np.right_shift(s1, s2), LongType()),
+})
+
+
+# Copied from pandas.
+# See also https://docs.scipy.org/doc/numpy/reference/arrays.classes.html#standard-array-subclasses
+def maybe_dispatch_ufunc_to_dunder_op(
+        ser_or_index, ufunc: Callable, method: str, *inputs, **kwargs: Any
+):
+    special = {
+        "add",
+        "sub",
+        "mul",
+        "pow",
+        "mod",
+        "floordiv",
+        "truediv",
+        "divmod",
+        "eq",
+        "ne",
+        "lt",
+        "gt",
+        "le",
+        "ge",
+        "remainder",
+        "matmul",
+    }
+    aliases = {
+        "absolute": "abs",  # TODO: Koalas Series and Index should implement __abs__.
+        "multiply": "mul",
+        "floor_divide": "floordiv",
+        "true_divide": "truediv",
+        "power": "pow",
+        "remainder": "mod",
+        "divide": "div",
+        "equal": "eq",
+        "not_equal": "ne",
+        "less": "lt",
+        "less_equal": "le",
+        "greater": "gt",
+        "greater_equal": "ge",
+    }
+
+    # For op(., Array) -> Array.__r{op}__
+    flipped = {
+        "lt": "__gt__",
+        "le": "__ge__",
+        "gt": "__lt__",
+        "ge": "__le__",
+        "eq": "__eq__",
+        "ne": "__ne__",
+    }
+
+    op_name = ufunc.__name__
+    op_name = aliases.get(op_name, op_name)
+
+    def not_implemented(*args, **kwargs):
+        return NotImplemented
+
+    if method == "__call__" and op_name in special and kwargs.get("out") is None:
+        if isinstance(inputs[0], type(ser_or_index)):
+            name = "__{}__".format(op_name)
+            return getattr(ser_or_index, name, not_implemented)(inputs[1])
+        else:
+            name = flipped.get(op_name, "__r{}__".format(op_name))
+            return getattr(ser_or_index, name, not_implemented)(inputs[0])
+    else:
+        return NotImplemented
+
+
+# See also https://docs.scipy.org/doc/numpy/reference/arrays.classes.html#standard-array-subclasses
+def maybe_dispatch_ufunc_to_spark_func(
+        ser_or_index, ufunc: Callable, method: str, *inputs, **kwargs: Any
+):
+    from databricks.koalas import Series
+
+    op_name = ufunc.__name__
+
+    if (method == "__call__"
+            and (op_name in unary_np_spark_mappings or op_name in binary_np_spark_mappings)
+            and kwargs.get("out") is None):
+        inputs = [  # type: ignore
+            inp._scol if isinstance(inp, Series) else F.lit(inp) for inp in inputs]  # type: ignore
+
+        np_spark_map_func = (
+            unary_np_spark_mappings.get(op_name)
+            or binary_np_spark_mappings.get(op_name))
+
+        return ser_or_index._with_new_scol(np_spark_map_func(*inputs))  # type: ignore
+    else:
+        return NotImplemented
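Each entry in these tables is either a native Spark SQL function, which evaluates entirely in the JVM, or an `F.pandas_udf` fallback that simply re-applies the NumPy ufunc to each pandas batch. A minimal standalone sketch of the two patterns in plain PySpark (the session setup and the column name `v` are illustrative, not part of this commit):

import numpy as np
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.getOrCreate()
df = spark.range(5).select(F.col("id").cast("double").alias("v"))

# Native Spark function: runs as a JVM expression, no Python round trip.
df.select(F.sqrt("v")).show()

# pandas UDF fallback: each batch arrives as a pandas Series, so the
# NumPy ufunc applies element-wise; this is how e.g. 'exp2' is mapped above.
exp2 = F.pandas_udf(lambda s: np.exp2(s), DoubleType())
df.select(exp2("v")).show()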

databricks/koalas/tests/test_numpy_compat.py

Lines changed: 52 additions & 1 deletion
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 import numpy as np
 import pandas as pd

 from databricks import koalas as ks
+from databricks.koalas.numpy_compat import unary_np_spark_mappings, binary_np_spark_mappings
 from databricks.koalas.testing.utils import ReusedSQLTestCase, SQLTestUtils


@@ -52,3 +52,54 @@ def test_np_unsupported(self):
         kdf = self.kdf
         with self.assertRaisesRegex(NotImplementedError, "Koalas.*not.*support.*sqrt.*"):
             np.sqrt(kdf.a, kdf.b)
+
+    def test_np_spark_compat(self):
+        # Use a randomly generated DataFrame.
+        pdf = pd.DataFrame(
+            np.random.randint(-100, 100, size=(np.random.randint(100), 2)), columns=['a', 'b'])
+        kdf = ks.from_pandas(pdf)
+
+        blacklist = [
+            # Koalas does not currently support these.
+            "conj",
+            "conjugate",
+            "isnat",
+            "matmul",
+            "frexp",
+
+            # Values are close enough, but the tests still fail.
+            "arccos",
+            "exp",
+            "expm1",
+            "log",  # flaky
+            "log10",  # flaky
+            "log1p",  # flaky
+            "modf",
+            "floor_divide",  # flaky
+
+            # Results seem inconsistent across versions of, I (Hyukjin) suspect, PyArrow.
+            # From PyArrow 0.15, it seems to return the correct results via PySpark. We can
+            # probably enable it later, once Koalas switches to PyArrow 0.15 completely.
+            "left_shift",
+        ]
+
+        for np_name, spark_func in unary_np_spark_mappings.items():
+            np_func = getattr(np, np_name)
+            if np_name not in blacklist:
+                try:
+                    # unary ufunc
+                    self.assert_eq(np_func(pdf.a), np_func(kdf.a), almost=True)
+                except Exception as e:
+                    raise AssertionError("Test for '%s' function failed." % np_name) from e
+
+        for np_name, spark_func in binary_np_spark_mappings.items():
+            np_func = getattr(np, np_name)
+            if np_name not in blacklist:
+                try:
+                    # binary ufunc
+                    self.assert_eq(
+                        np_func(pdf.a, pdf.b), np_func(kdf.a, kdf.b), almost=True)
+                    self.assert_eq(
+                        np_func(pdf.a, 1), np_func(kdf.a, 1), almost=True)
+                except Exception as e:
+                    raise AssertionError("Test for '%s' function failed." % np_name) from e
