|
43 | 43 | IntegerType, LongType, NumericType, ShortType, StructType) |
44 | 44 | from pyspark.sql.utils import AnalysisException |
45 | 45 | from pyspark.sql.window import Window |
46 | | -from pyspark.sql.functions import pandas_udf |
| 46 | +from pyspark.sql.functions import pandas_udf, PandasUDFType |
47 | 47 |
|
48 | 48 | from databricks import koalas as ks # For running doctests and reference resolution in PyCharm. |
49 | 49 | from databricks.koalas.utils import validate_arguments_and_invoke_function, align_diff_frames |
@@ -1539,7 +1539,16 @@ def transform(self, func): |
1539 | 1539 | Call ``func`` on self producing a Series with transformed values |
1540 | 1540 | and that has the same length as its input. |
1541 | 1541 |
|
1542 | | - .. note:: unlike pandas, it is required for ``func`` to specify return type hint. |
 | 1542 | + .. note:: this API executes the function once to infer the type, which is |
| 1543 | + potentially expensive, for instance, when the dataset is created after |
| 1544 | + aggregations or sorting. |
| 1545 | +
|
 | 1546 | + To avoid this, specify the return type in ``func``, for instance, as below: |
| 1547 | +
|
| 1548 | + >>> def square(x) -> ks.Series[np.int32]: |
| 1549 | + ... return x ** 2 |
| 1550 | +
|
 | 1551 | + Koalas uses the return type hint and does not try to infer the type. |
1543 | 1552 |
|
1544 | 1553 | .. note:: the series within ``func`` is actually a pandas series, and |
1545 | 1554 | the length of each series is not guaranteed. |
@@ -1575,20 +1584,48 @@ def transform(self, func): |
1575 | 1584 | 0 0 1 |
1576 | 1585 | 1 1 4 |
1577 | 1586 | 2 4 9 |
| 1587 | +
|
| 1588 | + You can omit the type hint and let Koalas infer its type. |
| 1589 | +
|
| 1590 | + >>> df.transform(lambda x: x ** 2) |
| 1591 | + A B |
| 1592 | + 0 0 1 |
| 1593 | + 1 1 4 |
| 1594 | + 2 4 9 |
| 1595 | +
|
1578 | 1596 | """ |
1579 | 1597 | assert callable(func), "the first argument should be a callable function." |
1580 | 1598 | spec = inspect.getfullargspec(func) |
1581 | 1599 | return_sig = spec.annotations.get("return", None) |
1582 | | - if return_sig is None: |
1583 | | - raise ValueError("Given function must have return type hint; however, not found.") |
| 1600 | + should_infer_schema = return_sig is None |
1584 | 1601 |
|
1585 | | - wrapped = ks.pandas_wraps(func) |
1586 | | - applied = [] |
1587 | | - for column in self._internal.data_columns: |
1588 | | - applied.append(wrapped(self[column]).rename(column)) |
| 1602 | + if should_infer_schema: |
 | 1603 | + # Here we execute the function with the first 1000 records to infer the return type. |
 | 1604 | + # If there are 1000 records or fewer, the pandas result is returned directly as a shortcut. |
| 1605 | + limit = 1000 |
| 1606 | + pdf = self.head(limit + 1)._to_internal_pandas() |
| 1607 | + transformed = pdf.transform(func) |
| 1608 | + kdf = DataFrame(transformed) |
| 1609 | + return_schema = kdf._sdf.schema |
| 1610 | + if len(pdf) <= limit: |
| 1611 | + return kdf |
| 1612 | + |
| 1613 | + applied = [] |
| 1614 | + for input_column, output_column in zip( |
| 1615 | + self._internal.data_columns, kdf._internal.data_columns): |
| 1616 | + pandas_func = pandas_udf( |
| 1617 | + func, |
| 1618 | + returnType=return_schema[output_column].dataType, |
| 1619 | + functionType=PandasUDFType.SCALAR) |
| 1620 | + applied.append(pandas_func(self[input_column]._scol).alias(output_column)) |
| 1621 | + else: |
| 1622 | + wrapped = ks.pandas_wraps(func) |
| 1623 | + applied = [] |
| 1624 | + for column in self._internal.data_columns: |
| 1625 | + applied.append(wrapped(self[column]).rename(column)._scol) |
1589 | 1626 |
|
1590 | 1627 | sdf = self._sdf.select( |
1591 | | - self._internal.index_scols + [c._scol for c in applied]) |
| 1628 | + self._internal.index_scols + [c for c in applied]) |
1592 | 1629 | internal = self._internal.copy(sdf=sdf) |
1593 | 1630 |
|
1594 | 1631 | return DataFrame(internal) |
|
0 commit comments