Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion databricks/koalas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

from functools import wraps
from typing import Union
from typing import Union, Callable, Any

import numpy as np
import pandas as pd
Expand All @@ -36,6 +36,71 @@
from databricks.koalas.frame import DataFrame


# Copied from pandas.
def maybe_dispatch_ufunc_to_dunder_op(
self, ufunc: Callable, method: str, *inputs, **kwargs: Any
):
special = {
"add",
"sub",
"mul",
"pow",
"mod",
"floordiv",
"truediv",
"divmod",
"eq",
"ne",
"lt",
"gt",
"le",
"ge",
"remainder",
"matmul",
}
aliases = {
"subtract": "sub",
"multiply": "mul",
"floor_divide": "floordiv",
"true_divide": "truediv",
"power": "pow",
"remainder": "mod",
"divide": "div",
"equal": "eq",
"not_equal": "ne",
"less": "lt",
"less_equal": "le",
"greater": "gt",
"greater_equal": "ge",
}

# For op(., Array) -> Array.__r{op}__
flipped = {
"lt": "__gt__",
"le": "__ge__",
"gt": "__lt__",
"ge": "__le__",
"eq": "__eq__",
"ne": "__ne__",
}

op_name = ufunc.__name__
op_name = aliases.get(op_name, op_name)

def not_implemented(*args, **kwargs):
return NotImplemented

if method == "__call__" and op_name in special and kwargs.get("out") is None:
if isinstance(inputs[0], type(self)):
name = "__{}__".format(op_name)
return getattr(self, name, not_implemented)(inputs[1])
else:
name = flipped.get(op_name, "__r{}__".format(op_name))
return getattr(self, name, not_implemented)(inputs[0])
else:
return NotImplemented
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will be able to add more functions that are supported natively in Spark.



def booleanize_null(left_scol, scol, f):
"""
Booleanize Null in Spark Column
Expand Down Expand Up @@ -224,6 +289,16 @@ def __rfloordiv__(self, other):
__rand__ = _column_op(spark.Column.__rand__)
__ror__ = _column_op(spark.Column.__ror__)

# NDArray Compat
def __array_ufunc__(self, ufunc: Callable, method: str, *inputs: Any, **kwargs: Any):
    """NumPy ufunc protocol hook (NEP 13).

    Routes a ufunc applied to this object to the matching dunder operator
    (so e.g. ``np.add(kser, 1)`` behaves like ``kser + 1``), and raises for
    ufuncs that have no operator equivalent.

    Raises
    ------
    NotImplementedError
        If the ufunc cannot be dispatched to a dunder operator.
    """
    result = maybe_dispatch_ufunc_to_dunder_op(
        self, ufunc, method, *inputs, **kwargs)
    if result is not NotImplemented:
        return result
    else:
        # TODO: support more APIs? E.g. delegating to a pandas UDF in a
        # general way would let Spark-native functions such as np.sqrt or
        # np.log be added over time (this only works when the output is
        # n-to-n with the input).
        raise NotImplementedError("Koalas objects currently do not support %s." % ufunc)

@property
def dtype(self):
"""Return the dtype object of the underlying data.
Expand Down
54 changes: 54 additions & 0 deletions databricks/koalas/tests/test_numpy_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#
# Copyright (C) 2019 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np
import pandas as pd

from databricks import koalas as ks
from databricks.koalas.testing.utils import ReusedSQLTestCase, SQLTestUtils


class NumPyCompatTest(ReusedSQLTestCase, SQLTestUtils):
    """Tests that NumPy ufuncs applied to Koalas objects match pandas."""

    @property
    def pdf(self):
        # Non-monotonic index with duplicates, to exercise alignment.
        data = {
            'a': [1, 2, 3, 4, 5, 6, 7, 8, 9],
            'b': [4, 5, 6, 3, 2, 1, 0, 0, 0],
        }
        return pd.DataFrame(data, index=[0, 1, 3, 5, 6, 8, 9, 9, 9])

    @property
    def kdf(self):
        return ks.from_pandas(self.pdf)

    def test_np_add_series(self):
        kdf, pdf = self.kdf, self.pdf
        # Series + Series.
        self.assert_eq(np.add(kdf.a, kdf.b), np.add(pdf.a, pdf.b))
        # Series + scalar.
        self.assert_eq(np.add(kdf.a, 1), np.add(pdf.a, 1))

    def test_np_add_index(self):
        kidx, pidx = self.kdf.index, self.pdf.index
        self.assert_eq(np.add(kidx, kidx), np.add(pidx, pidx))

    def test_np_unsupported(self):
        kdf = self.kdf
        with self.assertRaisesRegex(NotImplementedError, "Koalas.*not.*support.*sqrt.*"):
            np.sqrt(kdf.a, kdf.b)