Add a basic NumPy ufunc compatibility #1096
```diff
@@ -19,7 +19,7 @@
 """
 from functools import wraps
-from typing import Union
+from typing import Union, Callable, Any

 import numpy as np
 import pandas as pd
@@ -36,6 +36,71 @@
 from databricks.koalas.frame import DataFrame


+# Copied from pandas.
+def maybe_dispatch_ufunc_to_dunder_op(
+    self, ufunc: Callable, method: str, *inputs, **kwargs: Any
+):
+    special = {
+        "add",
+        "sub",
+        "mul",
+        "pow",
+        "mod",
+        "floordiv",
+        "truediv",
+        "divmod",
+        "eq",
+        "ne",
+        "lt",
+        "gt",
+        "le",
+        "ge",
+        "remainder",
+        "matmul",
+    }
+    aliases = {
+        "subtract": "sub",
+        "multiply": "mul",
+        "floor_divide": "floordiv",
+        "true_divide": "truediv",
+        "power": "pow",
+        "remainder": "mod",
+        "divide": "div",
+        "equal": "eq",
+        "not_equal": "ne",
+        "less": "lt",
+        "less_equal": "le",
+        "greater": "gt",
+        "greater_equal": "ge",
+    }
+
+    # For op(., Array) -> Array.__r{op}__
+    flipped = {
+        "lt": "__gt__",
+        "le": "__ge__",
+        "gt": "__lt__",
+        "ge": "__le__",
+        "eq": "__eq__",
+        "ne": "__ne__",
+    }
+
+    op_name = ufunc.__name__
+    op_name = aliases.get(op_name, op_name)
+
+    def not_implemented(*args, **kwargs):
+        return NotImplemented
+
+    if method == "__call__" and op_name in special and kwargs.get("out") is None:
+        if isinstance(inputs[0], type(self)):
+            name = "__{}__".format(op_name)
+            return getattr(self, name, not_implemented)(inputs[1])
+        else:
+            name = flipped.get(op_name, "__r{}__".format(op_name))
+            return getattr(self, name, not_implemented)(inputs[0])
+    else:
+        return NotImplemented
+
+
 def booleanize_null(left_scol, scol, f):
     """
     Booleanize Null in Spark Column
@@ -224,6 +289,16 @@ def __rfloordiv__(self, other):
     __rand__ = _column_op(spark.Column.__rand__)
     __ror__ = _column_op(spark.Column.__ror__)

+    # NDArray Compat
+    def __array_ufunc__(self, ufunc: Callable, method: str, *inputs: Any, **kwargs: Any):
+        result = maybe_dispatch_ufunc_to_dunder_op(
+            self, ufunc, method, *inputs, **kwargs)
+        if result is not NotImplemented:
+            return result
+        else:
+            # TODO: support more APIs?
```
Collaborator:
Maybe we can delegate to pandas UDF?

Author (Member):
Oh, that's a nice suggestion. Let me investigate this a bit more. It will only work when the output is n-to-n, but I'm sure there will be such cases.

Collaborator:
Yeah, if we can delegate to pandas UDF in a general way, we can take time to add more Spark native functions like …
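As a rough, hypothetical sketch of the delegation idea discussed in this thread (the helper name and the hard-coded return type are assumptions for illustration, not code from this PR), an element-wise ufunc could be wrapped in a SCALAR pandas UDF:

```python
# Hypothetical sketch only -- not part of this PR. Assumes PySpark 2.4+ and an
# element-wise (n-to-n) ufunc, as noted in the thread above.
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType


def delegate_ufunc_to_pandas_udf(spark_column, ufunc, *args):
    """Run a NumPy ufunc per batch on pandas Series via a SCALAR pandas UDF."""

    # The return type is assumed to be double here; real code would infer it.
    @pandas_udf("double", PandasUDFType.SCALAR)
    def apply_ufunc(s):
        # `s` is a pandas Series holding one batch of the column's values.
        return pd.Series(ufunc(s, *args))

    return apply_ufunc(spark_column)
```

Because a SCALAR pandas UDF maps each input batch to an equally sized output batch, this only covers the n-to-n case mentioned in the thread; reductions such as np.sum would need a different mechanism.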
```diff
+            raise NotImplementedError("Koalas objects currently do not support %s." % ufunc)

     @property
     def dtype(self):
         """Return the dtype object of the underlying data.
```
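Taken together, the hunks above route NumPy ufunc calls on Koalas objects back through the usual arithmetic and comparison dunders. A small usage illustration (assuming a running Spark session and the standard Koalas import):

```python
import numpy as np
import databricks.koalas as ks

kser = ks.Series([1, 2, 3])

# NumPy finds __array_ufunc__ on the Koalas Series and hands control back to it.
# The dispatcher maps np.add to Series.__add__, so this behaves like (kser + 1)
# and returns a Koalas Series rather than a NumPy array.
np.add(kser, 1)

# With the Koalas object on the right-hand side, the "flipped" table applies:
# np.greater(1, kser) is answered by kser.__lt__(1).
np.greater(1, kser)

# Ufuncs with no matching dunder (e.g. np.sqrt) fall through to the
# NotImplementedError raised in the diff above.
```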
New test file:

```diff
@@ -0,0 +1,54 @@
+#
+# Copyright (C) 2019 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pandas as pd
+
+from databricks import koalas as ks
+from databricks.koalas.testing.utils import ReusedSQLTestCase, SQLTestUtils
+
+
+class NumPyCompatTest(ReusedSQLTestCase, SQLTestUtils):
+
+    @property
+    def pdf(self):
+        return pd.DataFrame({
+            'a': [1, 2, 3, 4, 5, 6, 7, 8, 9],
+            'b': [4, 5, 6, 3, 2, 1, 0, 0, 0],
+        }, index=[0, 1, 3, 5, 6, 8, 9, 9, 9])
+
+    @property
+    def kdf(self):
+        return ks.from_pandas(self.pdf)
+
+    def test_np_add_series(self):
+        kdf = self.kdf
+        pdf = self.pdf
+        self.assert_eq(np.add(kdf.a, kdf.b), np.add(pdf.a, pdf.b))
+
+        kdf = self.kdf
+        pdf = self.pdf
+        self.assert_eq(np.add(kdf.a, 1), np.add(pdf.a, 1))
+
+    def test_np_add_index(self):
+        k_index = self.kdf.index
+        p_index = self.pdf.index
+        self.assert_eq(np.add(k_index, k_index), np.add(p_index, p_index))
+
+    def test_np_unsupported(self):
+        kdf = self.kdf
+        with self.assertRaisesRegex(NotImplementedError, "Koalas.*not.*support.*sqrt.*"):
+            np.sqrt(kdf.a, kdf.b)
```
We will be able to add more functions supported in Spark natively.
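As a purely illustrative sketch of that follow-up direction (none of these names exist in this PR), selected ufuncs could be mapped directly onto Spark's built-in column functions instead of going through pandas:

```python
# Hypothetical follow-up sketch -- not part of this PR.
from pyspark.sql import functions as F

# Map ufunc names (as reported by ufunc.__name__) to Spark-native column
# functions where an exact counterpart exists.
UFUNC_TO_SPARK = {
    "sqrt": F.sqrt,
    "absolute": F.abs,
    "exp": F.exp,
    "log": F.log,
}


def try_spark_native(ufunc, spark_column):
    """Return a Spark Column for natively supported ufuncs, else NotImplemented."""
    func = UFUNC_TO_SPARK.get(ufunc.__name__)
    return func(spark_column) if func is not None else NotImplemented
```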