Skip to content

Commit a65f1b6

Browse files
committed
Implement dot() for Series
1 parent 27de39e commit a65f1b6

File tree

4 files changed

+103
-1
lines changed

4 files changed

+103
-1
lines changed

databricks/koalas/missing/series.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ class _MissingPandasLikeSeries(object):
6060
combine_first = unsupported_function('combine_first')
6161
cov = unsupported_function('cov')
6262
divmod = unsupported_function('divmod')
63-
dot = unsupported_function('dot')
6463
droplevel = unsupported_function('droplevel')
6564
duplicated = unsupported_function('duplicated')
6665
ewm = unsupported_function('ewm')

databricks/koalas/series.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4194,6 +4194,70 @@ def pct_change(self, periods=1):
41944194

41954195
return self._with_new_scol((scol - prev_row) / prev_row)
41964196

4197+
def dot(self, other):
    """
    Compute the dot product between the Series and the columns of other.

    This method computes the dot product between the Series and another
    one, or the Series and each column of a DataFrame.

    It can also be called using `self @ other` in Python >= 3.5.

    Parameters
    ----------
    other : Series, DataFrame.
        The other object to compute the dot product with its columns.

    Returns
    -------
    scalar, Series
        Return the dot product of the Series and other if other is a
        Series, the Series of the dot product of Series and each column of
        other if other is a DataFrame.

    Raises
    ------
    TypeError
        If other is neither a Series nor a DataFrame.
    ValueError
        If the index of the Series and the index of other do not match.

    Notes
    -----
    The Series and other have to share the same index if other is a Series
    or a DataFrame.

    Examples
    --------
    >>> from databricks.koalas.config import set_option, reset_option
    >>> set_option("compute.ops_on_diff_frames", True)
    >>> s = ks.Series([0, 1, 2, 3])
    >>> other = ks.Series([-1, 2, -3, 4])

    >>> s.dot(other)
    8

    >>> s @ other
    8

    >>> df = ks.DataFrame([[0, 1], [-2, 3], [4, -5], [6, 7]])
    >>> s.dot(df)
    0    24
    1    14
    Name: 0, dtype: int64

    >>> reset_option("compute.ops_on_diff_frames")
    """
    # Reject unsupported operands up front. Previously such input either
    # failed on the `.index` access below or fell through both branches and
    # hit `return result` with `result` never assigned.
    if not isinstance(other, (DataFrame, Series)):
        raise TypeError("unsupported type: %s" % type(other).__name__)

    # Alignment is decided by comparing the textual representation of both
    # indexes.
    # NOTE(review): presumably this can misjudge indexes whose repr is
    # truncated or differs only in metadata -- confirm and consider a
    # value-based comparison.
    if repr(self.index) != repr(other.index):
        raise ValueError("matrices are not aligned")

    if isinstance(other, DataFrame):
        # One dot product per column of `other`, keyed by the column name.
        return Series({col_name: (self * other[col_name]).sum() for col_name in other})

    # `other` is a Series here: element-wise product, then the total.
    return (self * other).sum()
4254+
4255+
def __matmul__(self, other):
    """
    Support the binary `@` operator (Python>=3.5) by delegating to `dot`.
    """
    product = self.dot(other)
    return product
4260+
41974261
def _cum(self, func, skipna, part_cols=()):
41984262
# This is used to cummin, cummax, cumsum, etc.
41994263
index_columns = self._internal.index_columns

databricks/koalas/tests/test_ops_on_diff_frames.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,44 @@ def test_multi_index_column_assignment_frame(self):
574574
pdf["c"] = 1
575575
self.assert_eq(repr(kdf), repr(pdf))
576576

577+
def test_dot(self):
    """Compare `Series.dot` with pandas and check that misaligned input is rejected."""
    left = ks.Series([90, 91, 85], index=[2, 4, 1])
    left_pd = left.to_pandas()

    # Series operand sharing the same index: result must match pandas.
    right = ks.Series([90, 91, 85], index=[2, 4, 1])
    right_pd = right.to_pandas()
    self.assert_eq(left.dot(right), left_pd.dot(right_pd))

    # Series operands with a reordered index, then with an extra row.
    misaligned_cases = (
        ([90, 91, 85], [1, 2, 4]),
        ([90, 91, 85, 100], [2, 4, 1, 0]),
    )
    for data, index in misaligned_cases:
        right = ks.Series(data, index=index)
        with self.assertRaisesRegex(ValueError, "matrices are not aligned"):
            left.dot(right)

    # DataFrame operand sharing the same index: result must match pandas.
    frame = ks.DataFrame([[0, 1], [-2, 3], [4, -5]], index=[2, 4, 1])
    frame_pd = frame.to_pandas()
    self.assert_eq(left.dot(frame), left_pd.dot(frame_pd))

    # DataFrame operand with a reordered index.
    frame = ks.DataFrame([[0, 1], [-2, 3], [4, -5]], index=[1, 2, 4])
    with self.assertRaisesRegex(ValueError, "matrices are not aligned"):
        left.dot(frame)

    # Both operands built on the same MultiIndex.
    levels = [['lama', 'cow', 'falcon'], ['speed', 'weight', 'length']]
    codes = [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]]
    midx = pd.MultiIndex(levels, codes)
    left = ks.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx)
    left_pd = left.to_pandas()
    right = ks.Series([-450, 20, 12, -30, -250, 15, -320, 100, 3], index=midx)
    right_pd = right.to_pandas()
    self.assert_eq(left.dot(right), left_pd.dot(right_pd))
614+
577615

578616
class OpsOnDiffFramesDisabledTest(ReusedSQLTestCase, SQLTestUtils):
579617

docs/source/reference/series.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ Binary operator functions
8282
Series.ge
8383
Series.ne
8484
Series.eq
85+
Series.dot
8586

8687
Function application, GroupBy & Window
8788
--------------------------------------

0 commit comments

Comments
 (0)