Skip to content
12 changes: 10 additions & 2 deletions databricks/koalas/tests/test_expanding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from distutils.version import LooseVersion

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -60,7 +62,10 @@ def test_expanding_repr(self):
self.assertEqual(repr(ks.range(10).expanding(5)), "Expanding [min_periods=5]")

def test_expanding_count(self):
self._test_expanding_func("count")
# The behaviour of Expanding.count are different between pandas>=1.0.0 and lower,
# and we're following the behaviour of latest version of pandas.
if LooseVersion(pd.__version__) >= LooseVersion('1.0.0'):
self._test_expanding_func("count")

def test_expanding_min(self):
self._test_expanding_func("min")
Expand Down Expand Up @@ -115,7 +120,10 @@ def _test_groupby_expanding_func(self, f):
repr(getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index()))

def test_groupby_expanding_count(self):
self._test_groupby_expanding_func("count")
# The behaviour of ExpandingGroupby.count are different between pandas>=1.0.0 and lower,
# and we're following the behaviour of latest version of pandas.
if LooseVersion(pd.__version__) >= LooseVersion('1.0.0'):
self._test_groupby_expanding_func("count")

def test_groupby_expanding_min(self):
self._test_groupby_expanding_func("min")
Expand Down
16 changes: 6 additions & 10 deletions databricks/koalas/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,6 @@ def count(self):
def count(scol):
return F.count(scol).over(self._window)

if LooseVersion(pd.__version__) >= LooseVersion('1.0.0'):
if isinstance(self, (Expanding, ExpandingGroupby)):
def count_expanding(scol):
return F.when(
F.row_number().over(self._unbounded_window) >= self._min_periods,
F.count(scol).over(self._window)
).otherwise(F.lit(None))
return self._apply_as_series_or_frame(count_expanding).astype('float64')

return self._apply_as_series_or_frame(count).astype('float64')

def sum(self):
Expand Down Expand Up @@ -1104,7 +1095,12 @@ def count(self):
2 2.0
3 3.0
"""
return super(Expanding, self).count()
def count(scol):
return F.when(
F.row_number().over(self._unbounded_window) >= self._min_periods,
F.count(scol).over(self._window)
).otherwise(F.lit(None))
return self._apply_as_series_or_frame(count).astype('float64')

def sum(self):
"""
Expand Down