Merged
13 changes: 11 additions & 2 deletions databricks/koalas/groupby.py
@@ -1182,6 +1182,7 @@ def pandas_apply(pdf, *a, **k):
 
         if isinstance(return_type, DataFrameType):
             return_schema = cast(DataFrameType, return_type).spark_type
+            data_dtypes = cast(DataFrameType, return_type).dtypes
         else:
             should_return_series = True
             return_schema = cast(Union[SeriesType, ScalarType], return_type).spark_type
@@ -1191,6 +1192,7 @@ def pandas_apply(pdf, *a, **k):
                 return_schema = StructType(
                     [StructField(SPARK_DEFAULT_SERIES_NAME, return_schema)]
                 )
+            data_dtypes = [cast(Union[SeriesType, ScalarType], return_type).dtype]
 
         def pandas_groupby_apply(pdf):
@@ -1237,7 +1239,9 @@ def wrapped_func(df, *a, **k):
             internal = kdf_from_pandas._internal.with_new_sdf(sdf)
         else:
             # Otherwise, it loses index.
-            internal = InternalFrame(spark_frame=sdf, index_spark_columns=None)
+            internal = InternalFrame(
+                spark_frame=sdf, index_spark_columns=None, data_dtypes=data_dtypes
+            )
 
         if should_return_series:
             kser = first_series(DataFrame(internal))
@@ -2153,6 +2157,9 @@ def pandas_transform(pdf):
             return_schema = StructType(
                 [StructField(c, return_schema) for c in data_columns if c not in groupkey_names]
             )
+            data_dtypes = [
+                cast(SeriesType, return_type).dtype for c in data_columns if c not in groupkey_names
+            ]
 
             sdf = GroupBy._spark_group_map_apply(
                 kdf,
@@ -2162,7 +2169,9 @@ def pandas_transform(pdf):
                 retain_index=False,
             )
             # Otherwise, it loses index.

Contributor: I may not understand "# Otherwise," in this context. I was thinking transform would always lose the index.

-            internal = InternalFrame(spark_frame=sdf, index_spark_columns=None)
+            internal = InternalFrame(
+                spark_frame=sdf, index_spark_columns=None, data_dtypes=data_dtypes
+            )
 
         return DataFrame(internal)
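
The thread running through these groupby.py changes: for an extension dtype such as pandas' CategoricalDtype, the Spark return schema alone cannot reconstruct the column, because Spark only carries the physical representation (the category codes). Passing data_dtypes into InternalFrame keeps the pandas-level dtype alongside the schema. A minimal pandas-only sketch of that idea — illustrative only, not the koalas internals:

    import pandas as pd
    from pandas.api.types import CategoricalDtype

    # Illustrative sketch: a categorical column round-tripped through its codes.
    dtype = CategoricalDtype(categories=["c", "b", "d", "a"])
    ser = pd.Series(["b", "a", "c"], dtype=dtype)

    codes = ser.cat.codes  # roughly what the physical (Spark-side) column carries
    restored = pd.Series(pd.Categorical.from_codes(codes, dtype=dtype))
    assert restored.equals(ser)  # only possible because the dtype was kept around
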
89 changes: 75 additions & 14 deletions databricks/koalas/tests/test_categorical.py
@@ -23,15 +23,27 @@
 
 
 class CategoricalTest(ReusedSQLTestCase, TestUtils):
-    def test_categorical_frame(self):
-        pdf = pd.DataFrame(
+    @property

Contributor: Nice refactoring!

+    def pdf(self):
+        return pd.DataFrame(
             {
                 "a": pd.Categorical([1, 2, 3, 1, 2, 3]),
-                "b": pd.Categorical(["a", "b", "c", "a", "b", "c"], categories=["c", "b", "a"]),
+                "b": pd.Categorical(
+                    ["b", "a", "c", "c", "b", "a"], categories=["c", "b", "d", "a"]
+                ),
             },
             index=pd.Categorical([10, 20, 30, 20, 30, 10], categories=[30, 10, 20], ordered=True),
         )
-        kdf = ks.from_pandas(pdf)
+
+    @property
+    def kdf(self):
+        return ks.from_pandas(self.pdf)
+
+    @property
+    def df_pair(self):
+        return (self.pdf, self.kdf)
+
+    def test_categorical_frame(self):
+        pdf, kdf = self.df_pair
 
         self.assert_eq(kdf, pdf)
         self.assert_eq(kdf.a, pdf.a)
@@ -95,15 +107,7 @@ def test_factorize(self):
         self.assert_eq(kuniques, puniques)
 
     def test_groupby_apply(self):
-        pdf = pd.DataFrame(
-            {
-                "a": pd.Categorical([1, 2, 3, 1, 2, 3]),
-                "b": pd.Categorical(
-                    ["b", "a", "c", "c", "b", "a"], categories=["c", "b", "d", "a"]
-                ),
-            },
-        )
-        kdf = ks.from_pandas(pdf)
+        pdf, kdf = self.df_pair
 
         self.assert_eq(
             kdf.groupby("a").apply(lambda df: df).sort_index(),
@@ -134,3 +138,60 @@ def test_groupby_apply(self):
     def test_groupby_apply_without_shortcut(self):
         with ks.option_context("compute.shortcut_limit", 0):
             self.test_groupby_apply()
+
+        pdf, kdf = self.df_pair
+
+        def identity(df) -> ks.DataFrame[zip(kdf.columns, kdf.dtypes)]:
+            return df
+
+        self.assert_eq(
+            kdf.groupby("a").apply(identity).sort_values(["a", "b"]).reset_index(drop=True),
+            pdf.groupby("a").apply(identity).sort_values(["a", "b"]).reset_index(drop=True),
+        )
+
+    def test_groupby_transform(self):
+        pdf, kdf = self.df_pair
+
+        self.assert_eq(
+            kdf.groupby("a").transform(lambda x: x).sort_index(),
+            pdf.groupby("a").transform(lambda x: x).sort_index(),
+        )
+
+        dtype = CategoricalDtype(categories=["a", "b", "c", "d"])
+
+        self.assert_eq(
+            kdf.groupby("a").transform(lambda x: x.astype(dtype)).sort_index(),
+            pdf.groupby("a").transform(lambda x: x.astype(dtype)).sort_index(),
+        )
+
+    def test_groupby_transform_without_shortcut(self):
+        with ks.option_context("compute.shortcut_limit", 0):
+            self.test_groupby_transform()
+
+        pdf, kdf = self.df_pair
+
+        def identity(x) -> ks.Series[kdf.b.dtype]:  # type: ignore
+            return x
+
+        self.assert_eq(
+            kdf.groupby("a").transform(identity).sort_values("b").reset_index(drop=True),
+            pdf.groupby("a").transform(identity).sort_values("b").reset_index(drop=True),
+        )
+
+        dtype = CategoricalDtype(categories=["a", "b", "c", "d"])
+
+        def astype(x) -> ks.Series[dtype]:
+            return x.astype(dtype)
+
+        if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+            self.assert_eq(
+                kdf.groupby("a").transform(astype).sort_values("b").reset_index(drop=True),
+                pdf.groupby("a").transform(astype).sort_values("b").reset_index(drop=True),
+            )
+        else:
+            expected = pdf.groupby("a").transform(astype)
+            expected["b"] = dtype.categories.take(expected["b"].cat.codes).astype(dtype)
+            self.assert_eq(
+                kdf.groupby("a").transform(astype).sort_values("b").reset_index(drop=True),
+                expected.sort_values("b").reset_index(drop=True),
+            )
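
Taken together with the tests above, the user-visible effect is that groupby apply/transform results keep their categorical dtype instead of decaying to the physical values. A hypothetical REPL sketch against this branch (the data mirrors the test fixture; the printed dtype is what the tests assert, not output I have reproduced here):

    import pandas as pd
    import databricks.koalas as ks

    # Illustrative data, mirroring the pdf fixture in test_categorical.py.
    kdf = ks.DataFrame(
        {
            "a": [1, 2, 3, 1, 2, 3],
            "b": pd.Categorical(["b", "a", "c", "c", "b", "a"], categories=["c", "b", "d", "a"]),
        }
    )

    # transform() operates on the non-grouping columns; with data_dtypes
    # propagated through InternalFrame, the result column should keep its
    # CategoricalDtype rather than coming back as plain strings.
    transformed = kdf.groupby("a").transform(lambda x: x)
    print(transformed["b"].dtype)  # expected: category

Note how the tests spell return types when schema inference is off: ks.Series[kdf.b.dtype] and ks.DataFrame[zip(kdf.columns, kdf.dtypes)] pass concrete dtypes through the annotation, which is exactly where the new data_dtypes values in groupby.py come from.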