
Commit 688a54a

Add schema inference support at DataFrame.transform
1 parent 5a6bbb9 commit 688a54a

2 files changed (+57, -9 lines)

databricks/koalas/frame.py

Lines changed: 46 additions & 9 deletions
@@ -43,7 +43,7 @@
     IntegerType, LongType, NumericType, ShortType, StructType)
 from pyspark.sql.utils import AnalysisException
 from pyspark.sql.window import Window
-from pyspark.sql.functions import pandas_udf
+from pyspark.sql.functions import pandas_udf, PandasUDFType
 
 from databricks import koalas as ks  # For running doctests and reference resolution in PyCharm.
 from databricks.koalas.utils import validate_arguments_and_invoke_function, align_diff_frames
@@ -1539,7 +1539,16 @@ def transform(self, func):
         Call ``func`` on self producing a Series with transformed values
         and that has the same length as its input.
 
-        .. note:: unlike pandas, it is required for ``func`` to specify return type hint.
+        .. note:: this API executes the function once to infer the type which is
+             potentially expensive, for instance, when the dataset is created after
+             aggregations or sorting.
+
+             To avoid this, specify return type in ``func``, for instance, as below:
+
+             >>> def square(x) -> ks.Series[np.int32]:
+             ...     return x ** 2
+
+             Koalas uses return type hint and does not try to infer the type.
 
         .. note:: the series within ``func`` is actually a pandas series, and
             the length of each series is not guaranteed.
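To make the note above concrete, here is a minimal usage sketch of the two calling styles it describes; the example frame is illustrative, while the ks.Series type hint and DataFrame.transform come straight from the diff:

    import numpy as np
    import databricks.koalas as ks

    df = ks.DataFrame({'A': [0, 1, 2], 'B': [1, 2, 3]})

    # With a return type hint: Koalas trusts the annotation and skips inference.
    def square(x) -> ks.Series[np.int32]:
        return x ** 2

    df.transform(square)

    # Without a hint: Koalas first runs the function on a sample to infer the schema.
    df.transform(lambda x: x ** 2)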
@@ -1575,20 +1584,48 @@ def transform(self, func):
         0  0  1
         1  1  4
         2  4  9
+
+        You can omit the type hint and let Koalas infer its type.
+
+        >>> df.transform(lambda x: x ** 2)
+           A  B
+        0  0  1
+        1  1  4
+        2  4  9
+
         """
         assert callable(func), "the first argument should be a callable function."
         spec = inspect.getfullargspec(func)
         return_sig = spec.annotations.get("return", None)
-        if return_sig is None:
-            raise ValueError("Given function must have return type hint; however, not found.")
+        should_infer_schema = return_sig is None
 
-        wrapped = ks.pandas_wraps(func)
-        applied = []
-        for column in self._internal.data_columns:
-            applied.append(wrapped(self[column]).rename(column))
+        if should_infer_schema:
+            # Here we execute with the first 1000 to get the return type.
+            # If the records were less than 1000, it uses pandas API directly for a shortcut.
+            limit = 1000
+            pdf = self.head(limit + 1)._to_internal_pandas()
+            transformed = pdf.transform(func)
+            kdf = DataFrame(transformed)
+            return_schema = kdf._sdf.schema
+            if len(pdf) <= limit:
+                return kdf
+
+            applied = []
+            for input_column, output_column in zip(
+                    self._internal.data_columns, kdf._internal.data_columns):
+                pandas_func = pandas_udf(
+                    func,
+                    returnType=return_schema[output_column].dataType,
+                    functionType=PandasUDFType.SCALAR)
+                applied.append(pandas_func(self[input_column]._scol).alias(output_column))
+        else:
+            wrapped = ks.pandas_wraps(func)
+            applied = []
+            for column in self._internal.data_columns:
+                applied.append(wrapped(self[column]).rename(column)._scol)
 
         sdf = self._sdf.select(
-            self._internal.index_scols + [c._scol for c in applied])
+            self._internal.index_scols + [c for c in applied])
         internal = self._internal.copy(sdf=sdf)
 
         return DataFrame(internal)
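For readers outside the Koalas internals, the following is a simplified, self-contained sketch of the "run on a sample, infer the type, then apply a scalar pandas UDF" pattern that the new should_infer_schema branch implements. The helper name transform_column_with_inferred_type and the use of a plain Spark DataFrame are assumptions for illustration only; the real code operates on Koalas' _internal frame and transforms all columns at once.

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    spark = SparkSession.builder.getOrCreate()

    def transform_column_with_inferred_type(sdf, column, func, limit=1000):
        # Execute func with plain pandas on the first `limit` + 1 rows to learn its return type.
        sample = sdf.select(column).limit(limit + 1).toPandas()
        transformed = sample[column].transform(func)
        inferred = spark.createDataFrame(transformed.to_frame()).schema[column].dataType

        # Shortcut: if the whole column fit into the sample, the pandas result is already exact.
        if len(sample) <= limit:
            return spark.createDataFrame(transformed.to_frame())

        # Otherwise re-run func over the full data as a scalar pandas UDF,
        # using the return type inferred from the sample.
        udf = pandas_udf(func, returnType=inferred, functionType=PandasUDFType.SCALAR)
        return sdf.select(udf(sdf[column]).alias(column))

    # No type hint is needed on the lambda; 1,500 rows exceeds the sampling limit,
    # so the pandas UDF branch is exercised.
    sdf = spark.range(1500).withColumnRenamed("id", "x")
    transform_column_with_inferred_type(sdf, "x", lambda s: s ** 2).show()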

databricks/koalas/tests/test_dataframe.py

Lines changed: 11 additions & 0 deletions
@@ -1554,3 +1554,14 @@ def test_pipe(self):
             "arg is both the pipe target and a keyword argument",
             lambda: kdf.pipe((lambda x: x, 'arg'), arg='1')
         )
+
+    def test_transform(self):
+        # Data is intentionally big to test when schema inference is on.
+        pdf = pd.DataFrame({'a': [1, 2, 3, 4, 5, 6] * 300,
+                            'b': [1., 1., 2., 3., 5., 8.] * 300,
+                            'c': [1, 4, 9, 16, 25, 36] * 300}, columns=['a', 'b', 'c'])
+        kdf = ks.DataFrame(pdf)
+        self.assert_eq(kdf.transform(lambda x: x + 1).sort_index(),
+                       pdf.transform(lambda x: x + 1).sort_index())
+        with self.assertRaisesRegex(AssertionError, "the first argument should be a callable"):
+            kdf.transform(1)
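A note on the test data: 6 distinct values repeated 300 times yields 1,800 rows, which exceeds the 1,000-row sampling limit in the new code, so kdf.transform(lambda x: x + 1) here exercises the pandas_udf path rather than the small-data shortcut. A smaller frame, for instance

    ks.DataFrame({'a': [1, 2, 3]}).transform(lambda x: x + 1)

would be handled entirely by the shortcut, since the sampled head already covers the whole frame.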
