Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 38 additions & 38 deletions databricks/koalas/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#
from collections import OrderedDict
from functools import partial
from typing import Any
from typing import Any, Union

from pyspark.sql import Window
from pyspark.sql import functions as F
Expand Down Expand Up @@ -54,13 +54,13 @@ def _apply_as_series_or_frame(self, func):
"to handle the index and columns of output."
)

def count(self):
def count(self) -> Union["ks.Series", "ks.DataFrame"]:
    """
    Compute the windowed count for each column.

    A Spark ``F.count`` expression evaluated over ``self._window`` is built
    for every column and handed to ``self._apply_as_series_or_frame``; the
    result is then cast to ``float64``.

    Returns
    -------
    Series or DataFrame
        Whatever ``self._apply_as_series_or_frame`` produces, cast to
        ``float64``.
    """

    def count(scol):
        # Count of the given Spark column, scoped to this object's window.
        return F.count(scol).over(self._window)

    return self._apply_as_series_or_frame(count).astype("float64")

def sum(self):
def sum(self) -> Union["ks.Series", "ks.DataFrame"]:
def sum(scol):
return F.when(
F.row_number().over(self._unbounded_window) >= self._min_periods,
Expand All @@ -69,7 +69,7 @@ def sum(scol):

return self._apply_as_series_or_frame(sum)

def min(self):
def min(self) -> Union["ks.Series", "ks.DataFrame"]:
def min(scol):
return F.when(
F.row_number().over(self._unbounded_window) >= self._min_periods,
Expand All @@ -78,7 +78,7 @@ def min(scol):

return self._apply_as_series_or_frame(min)

def max(self):
def max(self) -> Union["ks.Series", "ks.DataFrame"]:
def max(scol):
return F.when(
F.row_number().over(self._unbounded_window) >= self._min_periods,
Expand All @@ -87,7 +87,7 @@ def max(scol):

return self._apply_as_series_or_frame(max)

def mean(self):
def mean(self) -> Union["ks.Series", "ks.DataFrame"]:
def mean(scol):
return F.when(
F.row_number().over(self._unbounded_window) >= self._min_periods,
Expand All @@ -96,7 +96,7 @@ def mean(scol):

return self._apply_as_series_or_frame(mean)

def std(self):
def std(self) -> Union["ks.Series", "ks.DataFrame"]:
def std(scol):
return F.when(
F.row_number().over(self._unbounded_window) >= self._min_periods,
Expand All @@ -105,7 +105,7 @@ def std(scol):

return self._apply_as_series_or_frame(std)

def var(self):
def var(self) -> Union["ks.Series", "ks.DataFrame"]:
def var(scol):
return F.when(
F.row_number().over(self._unbounded_window) >= self._min_periods,
Expand Down Expand Up @@ -153,7 +153,7 @@ def _apply_as_series_or_frame(self, func):
lambda kser: kser._with_new_scol(func(kser.spark.column)), should_resolve=True
)

def count(self):
def count(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
The rolling count of any non-NaN observations inside the window.

Expand Down Expand Up @@ -202,7 +202,7 @@ def count(self):
"""
return super().count()

def sum(self):
def sum(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate rolling summation of given DataFrame or Series.

Expand Down Expand Up @@ -280,7 +280,7 @@ def sum(self):
"""
return super().sum()

def min(self):
def min(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate the rolling minimum.

Expand Down Expand Up @@ -358,7 +358,7 @@ def min(self):
"""
return super().min()

def max(self):
def max(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate the rolling maximum.

Expand Down Expand Up @@ -435,7 +435,7 @@ def max(self):
"""
return super().max()

def mean(self):
def mean(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate the rolling mean of the values.

Expand Down Expand Up @@ -513,7 +513,7 @@ def mean(self):
"""
return super().mean()

def std(self):
def std(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate rolling standard deviation.

Expand Down Expand Up @@ -563,7 +563,7 @@ def std(self):
"""
return super().std()

def var(self):
def var(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate unbiased rolling variance.

Expand Down Expand Up @@ -713,7 +713,7 @@ def _apply_as_series_or_frame(self, func):
else:
return ret

def count(self):
def count(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
The rolling count of any non-NaN observations inside the window.

Expand Down Expand Up @@ -767,7 +767,7 @@ def count(self):
"""
return super().count()

def sum(self):
def sum(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
The rolling summation of any non-NaN observations inside the window.

Expand Down Expand Up @@ -821,7 +821,7 @@ def sum(self):
"""
return super().sum()

def min(self):
def min(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
The rolling minimum of any non-NaN observations inside the window.

Expand Down Expand Up @@ -875,7 +875,7 @@ def min(self):
"""
return super().min()

def max(self):
def max(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
The rolling maximum of any non-NaN observations inside the window.

Expand Down Expand Up @@ -929,7 +929,7 @@ def max(self):
"""
return super().max()

def mean(self):
def mean(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
The rolling mean of any non-NaN observations inside the window.

Expand Down Expand Up @@ -983,7 +983,7 @@ def mean(self):
"""
return super().mean()

def std(self):
def std(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate rolling standard deviation.

Expand All @@ -1002,7 +1002,7 @@ def std(self):
"""
return super().std()

def var(self):
def var(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate unbiased rolling variance.

Expand Down Expand Up @@ -1053,9 +1053,9 @@ def __getattr__(self, item: str) -> Any:
def __repr__(self):
    """Return a short description that shows the configured ``min_periods``."""
    return "Expanding [min_periods={}]".format(self._min_periods)

_apply_as_series_or_frame = Rolling._apply_as_series_or_frame # type: ignore
_apply_as_series_or_frame = Rolling._apply_as_series_or_frame

def count(self):
def count(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
The expanding count of any non-NaN observations inside the window.

Expand Down Expand Up @@ -1101,9 +1101,9 @@ def count(scol):
F.count(scol).over(self._window),
).otherwise(F.lit(None))

return self._apply_as_series_or_frame(count).astype("float64")
return self._apply_as_series_or_frame(count).astype("float64") # type: ignore

def sum(self):
def sum(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate expanding summation of given DataFrame or Series.

Expand Down Expand Up @@ -1165,7 +1165,7 @@ def sum(self):
"""
return super().sum()

def min(self):
def min(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate the expanding minimum.

Expand Down Expand Up @@ -1202,7 +1202,7 @@ def min(self):
"""
return super().min()

def max(self):
def max(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate the expanding maximum.

Expand Down Expand Up @@ -1238,7 +1238,7 @@ def max(self):
"""
return super().max()

def mean(self):
def mean(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate the expanding mean of the values.

Expand Down Expand Up @@ -1282,7 +1282,7 @@ def mean(self):
"""
return super().mean()

def std(self):
def std(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate expanding standard deviation.

Expand Down Expand Up @@ -1332,7 +1332,7 @@ def std(self):
"""
return super().std()

def var(self):
def var(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate unbiased expanding variance.

Expand Down Expand Up @@ -1417,7 +1417,7 @@ def __getattr__(self, item: str) -> Any:

_apply_as_series_or_frame = RollingGroupby._apply_as_series_or_frame # type: ignore

def count(self):
def count(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
The expanding count of any non-NaN observations inside the window.

Expand Down Expand Up @@ -1471,7 +1471,7 @@ def count(self):
"""
return super().count()

def sum(self):
def sum(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate expanding summation of given DataFrame or Series.

Expand Down Expand Up @@ -1525,7 +1525,7 @@ def sum(self):
"""
return super().sum()

def min(self):
def min(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate the expanding minimum.

Expand Down Expand Up @@ -1579,7 +1579,7 @@ def min(self):
"""
return super().min()

def max(self):
def max(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate the expanding maximum.

Expand Down Expand Up @@ -1632,7 +1632,7 @@ def max(self):
"""
return super().max()

def mean(self):
def mean(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate the expanding mean of the values.

Expand Down Expand Up @@ -1686,7 +1686,7 @@ def mean(self):
"""
return super().mean()

def std(self):
def std(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate expanding standard deviation.

Expand All @@ -1706,7 +1706,7 @@ def std(self):
"""
return super().std()

def var(self):
def var(self) -> Union["ks.Series", "ks.DataFrame"]:
"""
Calculate unbiased expanding variance.

Expand Down