Skip to content

Commit 901a6f0

Browse files
authored
Support DataFrame parameter in Series.dot (#1931)
1 parent 347ce57 commit 901a6f0

File tree

3 files changed

+87
-23
lines changed

3 files changed

+87
-23
lines changed

databricks/koalas/series.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4721,9 +4721,9 @@ def combine_first(self, other) -> "Series":
47214721
*index_scols, cond.alias(self._internal.data_spark_column_names[0])
47224722
).distinct()
47234723
internal = self._internal.with_new_sdf(sdf)
4724-
return first_series(ks.DataFrame(internal))
4724+
return first_series(DataFrame(internal))
47254725

4726-
def dot(self, other) -> Union[Scalar, "Series"]:
4726+
def dot(self, other: Union["Series", DataFrame]) -> Union[Scalar, "Series"]:
47274727
"""
47284728
Compute the dot product between the Series and the columns of other.
47294729
@@ -4732,7 +4732,7 @@ def dot(self, other) -> Union[Scalar, "Series"]:
47324732
47334733
It can also be called using `self @ other` in Python >= 3.5.
47344734
4735-
.. note:: This API is slightly different from pandas when indexes from both
4735+
.. note:: This API is slightly different from pandas when indexes from both Series
47364736
are not aligned. To match pandas' behavior, it would require reading the whole data for,
47374737
for example, counting. pandas raises an exception; however, Koalas just proceeds
47384738
and performs the computation, permissively ignoring mismatches as NaN.
@@ -4774,20 +4774,48 @@ def dot(self, other) -> Union[Scalar, "Series"]:
47744774
47754775
>>> s @ s
47764776
14
4777+
4778+
>>> kdf = ks.DataFrame({'x': [0, 1, 2, 3], 'y': [0, -1, -2, -3]})
4779+
>>> kdf
4780+
x y
4781+
0 0 0
4782+
1 1 -1
4783+
2 2 -2
4784+
3 3 -3
4785+
4786+
>>> with ks.option_context("compute.ops_on_diff_frames", True):
4787+
... s.dot(kdf)
4788+
...
4789+
x 14
4790+
y -14
4791+
dtype: int64
47774792
"""
47784793
if isinstance(other, DataFrame):
4779-
raise ValueError(
4780-
"Series.dot() is currently not supported with DataFrame since "
4781-
"it will cause expansive calculation as many as the number "
4782-
"of columns of DataFrame"
4783-
)
4784-
if self._kdf is not other._kdf:
4785-
if len(self.index) != len(other.index):
4786-
raise ValueError("matrices are not aligned")
4787-
if isinstance(other, Series):
4788-
result = (self * other).sum()
4794+
if not same_anchor(self, other):
4795+
if not self.index.sort_values().equals(other.index.sort_values()):
4796+
raise ValueError("matrices are not aligned")
47894797

4790-
return result
4798+
other = other.copy()
4799+
column_labels = other._internal.column_labels
4800+
4801+
self_column_label = verify_temp_column_name(other, "__self_column__")
4802+
other[self_column_label] = self
4803+
self_kser = other._kser_for(self_column_label)
4804+
4805+
product_ksers = [other._kser_for(label) * self_kser for label in column_labels]
4806+
4807+
dot_product_kser = DataFrame(
4808+
other._internal.with_new_columns(product_ksers, column_labels)
4809+
).sum()
4810+
4811+
return cast(Series, dot_product_kser).rename(self.name)
4812+
4813+
else:
4814+
assert isinstance(other, Series)
4815+
if not same_anchor(self, other):
4816+
if len(self.index) != len(other.index):
4817+
raise ValueError("matrices are not aligned")
4818+
return (self * other).sum()
47914819

47924820
def __matmul__(self, other):
47934821
"""
@@ -4945,7 +4973,7 @@ def asof(self, where) -> Union[Scalar, "Series"]:
49454973
should_return_series = True
49464974
if isinstance(self.index, ks.MultiIndex):
49474975
raise ValueError("asof is not supported for a MultiIndex")
4948-
if isinstance(where, (ks.Index, ks.Series, ks.DataFrame)):
4976+
if isinstance(where, (ks.Index, ks.Series, DataFrame)):
49494977
raise ValueError("where cannot be an Index, Series or a DataFrame")
49504978
if not self.index.is_monotonic_increasing:
49514979
raise ValueError("asof requires a sorted index")

databricks/koalas/tests/test_ops_on_diff_frames.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -870,13 +870,6 @@ def test_dot(self):
870870
with self.assertRaisesRegex(ValueError, "matrices are not aligned"):
871871
kser.dot(kser_other)
872872

873-
# with DataFrame is not supported for now due to a performance issue,
874-
# now we raise ValueError with proper message instead.
875-
kdf = ks.DataFrame([[0, 1], [-2, 3], [4, -5]], index=[2, 4, 1])
876-
877-
with self.assertRaisesRegex(ValueError, r"Series\.dot\(\) is currently not supported*"):
878-
kser.dot(kdf)
879-
880873
# for MultiIndex
881874
midx = pd.MultiIndex(
882875
[["lama", "cow", "falcon"], ["speed", "weight", "length"]],
@@ -886,9 +879,44 @@ def test_dot(self):
886879
kser = ks.from_pandas(pser)
887880
pser_other = pd.Series([-450, 20, 12, -30, -250, 15, -320, 100, 3], index=midx)
888881
kser_other = ks.from_pandas(pser_other)
889-
890882
self.assert_eq(kser.dot(kser_other), pser.dot(pser_other))
891883

884+
pser = pd.Series([0, 1, 2, 3])
885+
kser = ks.from_pandas(pser)
886+
887+
# DataFrame "other" without Index/MultiIndex as columns
888+
pdf = pd.DataFrame([[0, 1], [-2, 3], [4, -5], [6, 7]])
889+
kdf = ks.from_pandas(pdf)
890+
self.assert_eq(kser.dot(kdf), pser.dot(pdf))
891+
892+
# DataFrame "other" with Index as columns
893+
pdf.columns = pd.Index(["x", "y"])
894+
kdf = ks.from_pandas(pdf)
895+
self.assert_eq(kser.dot(kdf), pser.dot(pdf))
896+
pdf.columns = pd.Index(["x", "y"], name="cols_name")
897+
kdf = ks.from_pandas(pdf)
898+
self.assert_eq(kser.dot(kdf), pser.dot(pdf))
899+
900+
pdf = pdf.reindex([1, 0, 2, 3])
901+
kdf = ks.from_pandas(pdf)
902+
self.assert_eq(kser.dot(kdf), pser.dot(pdf))
903+
904+
# DataFrame "other" with MultiIndex as columns
905+
pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y")])
906+
kdf = ks.from_pandas(pdf)
907+
self.assert_eq(kser.dot(kdf), pser.dot(pdf))
908+
pdf.columns = pd.MultiIndex.from_tuples(
909+
[("a", "x"), ("b", "y")], names=["cols_name1", "cols_name2"]
910+
)
911+
kdf = ks.from_pandas(pdf)
912+
self.assert_eq(kser.dot(kdf), pser.dot(pdf))
913+
914+
kser = ks.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).b
915+
pser = kser.to_pandas()
916+
kdf = ks.DataFrame({"c": [7, 8, 9]})
917+
pdf = kdf.to_pandas()
918+
self.assert_eq(kser.dot(kdf), pser.dot(pdf))
919+
892920
def test_to_series_comparison(self):
893921
kidx1 = ks.Index([1, 2, 3, 4, 5])
894922
kidx2 = ks.Index([1, 2, 3, 4, 5])

databricks/koalas/tests/test_series.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2123,6 +2123,14 @@ def test_droplevel(self):
21232123
pser.droplevel([("a", "1"), ("c", "3")]), kser.droplevel([("a", "1"), ("c", "3")])
21242124
)
21252125

2126+
def test_dot(self):
2127+
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
2128+
kdf = ks.from_pandas(pdf)
2129+
2130+
self.assert_eq((kdf["b"] * 10).dot(kdf["a"]), (pdf["b"] * 10).dot(pdf["a"]))
2131+
self.assert_eq((kdf["b"] * 10).dot(kdf), (pdf["b"] * 10).dot(pdf))
2132+
self.assert_eq((kdf["b"] * 10).dot(kdf + 1), (pdf["b"] * 10).dot(pdf + 1))
2133+
21262134
@unittest.skipIf(
21272135
LooseVersion(pyspark.__version__) < LooseVersion("3.0"),
21282136
"tail won't work properly with PySpark<3.0",

0 commit comments

Comments
 (0)