Skip to content
1 change: 0 additions & 1 deletion databricks/koalas/missing/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class _MissingPandasLikeSeries(object):
combine_first = unsupported_function('combine_first')
cov = unsupported_function('cov')
divmod = unsupported_function('divmod')
dot = unsupported_function('dot')
droplevel = unsupported_function('droplevel')
duplicated = unsupported_function('duplicated')
ewm = unsupported_function('ewm')
Expand Down
56 changes: 56 additions & 0 deletions databricks/koalas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4194,6 +4194,62 @@ def pct_change(self, periods=1):

return self._with_new_scol((scol - prev_row) / prev_row)

def dot(self, other):
    """
    Compute the dot product between the Series and the columns of other.

    This method computes the dot product between the Series and another
    one. Passing a DataFrame is recognised but currently rejected (see
    Raises).

    It can also be called using `self @ other` in Python >= 3.5.

    Parameters
    ----------
    other : Series, DataFrame.
        The other object to compute the dot product with its columns.

    Returns
    -------
    scalar
        The sum of the element-wise products of the Series and other,
        when other is a Series.

    Raises
    ------
    TypeError
        If other is neither a Series nor a DataFrame.
    ValueError
        If the index of the Series and the index of other do not have the
        same representation ("matrices are not aligned"), or if other is
        a DataFrame, which is not supported yet.

    Notes
    -----
    The Series and other have to share the same index if other is a
    Series or a DataFrame.

    Examples
    --------
    >>> s = ks.Series([0, 1, 2, 3])

    >>> s.dot(s)
    14

    >>> s @ s
    14
    """
    # Reject unsupported operands up front. Without this guard an object
    # that is neither a Series nor a DataFrame either failed with an
    # unrelated AttributeError on `.index` or fell through both branches
    # below and hit an unbound `result` variable.
    if not isinstance(other, (Series, DataFrame)):
        raise TypeError(
            "Unsupported type {} for Series.dot(); expected a Series or a "
            "DataFrame".format(type(other).__name__))

    # The indexes are compared through their string representation.
    # NOTE(review): this is only as precise as `repr` of the index; if the
    # repr abbreviates long indexes, differing indexes could compare equal
    # -- confirm.
    if repr(self.index) != repr(other.index):
        raise ValueError("matrices are not aligned")

    if isinstance(other, DataFrame):
        raise ValueError(
            "Series.dot() is currently not supported with DataFrame since "
            "it will cause expansive calculation as many as the number "
            "of columns of DataFrame")

    # `other` is a Series here: multiply element-wise, then sum.
    return (self * other).sum()

def __matmul__(self, other):
    """
    Implement the binary `@` operator (Python>=3.5) by delegating to `dot`.
    """
    product = self.dot(other)
    return product

def _cum(self, func, skipna, part_cols=()):
# This is used to cummin, cummax, cumsum, etc.
index_columns = self._internal.index_columns
Expand Down
35 changes: 35 additions & 0 deletions databricks/koalas/tests/test_ops_on_diff_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,41 @@ def test_multi_index_column_assignment_frame(self):
pdf["c"] = 1
self.assert_eq(repr(kdf), repr(pdf))

def test_dot(self):
    """Check Series.dot against pandas and its error cases."""
    # Same values and same index order: result must match pandas.
    kser = ks.Series([90, 91, 85], index=[2, 4, 1])
    pser = kser.to_pandas()
    kser_other = ks.Series([90, 91, 85], index=[2, 4, 1])
    pser_other = kser_other.to_pandas()

    self.assert_eq(kser.dot(kser_other), pser.dot(pser_other))

    # Same length but a different index: must be rejected.
    kser_other = ks.Series([90, 91, 85], index=[1, 2, 4])
    with self.assertRaisesRegex(ValueError, "matrices are not aligned"):
        kser.dot(kser_other)

    # Different length (and therefore a different index): must be rejected.
    kser_other = ks.Series([90, 91, 85, 100], index=[2, 4, 1, 0])
    with self.assertRaisesRegex(ValueError, "matrices are not aligned"):
        kser.dot(kser_other)

    # Dot with a DataFrame is not supported for now because of a performance
    # issue, so a ValueError with a proper message is raised instead.
    kdf = ks.DataFrame([[0, 1], [-2, 3], [4, -5]], index=[2, 4, 1])

    # The pattern has no trailing `*`: that quantifier applied to the final
    # "d" and made the match looser than intended.
    with self.assertRaisesRegex(ValueError, r"Series\.dot\(\) is currently not supported"):
        kser.dot(kdf)

    # for MultiIndex
    midx = pd.MultiIndex([['lama', 'cow', 'falcon'],
                          ['speed', 'weight', 'length']],
                         [[0, 0, 0, 1, 1, 1, 2, 2, 2],
                          [0, 1, 2, 0, 1, 2, 0, 1, 2]])
    kser = ks.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx)
    pser = kser.to_pandas()
    kser_other = ks.Series([-450, 20, 12, -30, -250, 15, -320, 100, 3], index=midx)
    pser_other = kser_other.to_pandas()

    self.assert_eq(kser.dot(kser_other), pser.dot(pser_other))


class OpsOnDiffFramesDisabledTest(ReusedSQLTestCase, SQLTestUtils):

Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ Binary operator functions
Series.ge
Series.ne
Series.eq
Series.dot

Function application, GroupBy & Window
--------------------------------------
Expand Down