Skip to content

Commit 97a7718

Browse files
committed
Add column axis in ks.concat
1 parent da3740d commit 97a7718

File tree

3 files changed

+218
-43
lines changed

3 files changed

+218
-43
lines changed

databricks/koalas/frame.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9655,34 +9655,8 @@ def __setitem__(self, key, value):
96559655
isinstance(value, DataFrame) and value is not self
96569656
):
96579657
# Different Series or DataFrames
9658-
if isinstance(value, Series):
9659-
value = value.to_frame()
9660-
else:
9661-
assert isinstance(value, DataFrame), type(value)
9662-
value = value.copy()
9663-
level = self._internal.column_labels_level
9664-
9665-
value.columns = pd.MultiIndex.from_tuples(
9666-
[
9667-
tuple([name_like_string(label)] + ([""] * (level - 1)))
9668-
for label in value._internal.column_labels
9669-
]
9670-
)
9671-
9672-
if isinstance(key, str):
9673-
key = [(key,)]
9674-
elif isinstance(key, tuple):
9675-
key = [key]
9676-
else:
9677-
key = [k if isinstance(k, tuple) else (k,) for k in key]
9678-
9679-
if any(len(label) > level for label in key):
9680-
raise KeyError(
9681-
"Key length ({}) exceeds index depth ({})".format(
9682-
max(len(label) for label in key), level
9683-
)
9684-
)
9685-
key = [tuple(list(label) + ([""] * (level - len(label)))) for label in key]
9658+
key = self._index_normalized_label(key)
9659+
value = self._index_normalized_frame(value)
96869660

96879661
def assign_columns(kdf, this_column_labels, that_column_labels):
96889662
assert len(key) == len(that_column_labels)
@@ -9707,6 +9681,55 @@ def assign_columns(kdf, this_column_labels, that_column_labels):
97079681

97089682
self._internal = kdf._internal
97099683

9684+
def _index_normalized_label(self, labels):
9685+
"""
9686+
Returns a label that is normalized against the current column index level.
9687+
For example, the key "abc" can be ("abc", "", "") if the current Frame has
9688+
a multi-index for its column
9689+
"""
9690+
level = self._internal.column_labels_level
9691+
9692+
if isinstance(labels, str):
9693+
labels = [(labels,)]
9694+
elif isinstance(labels, tuple):
9695+
labels = [labels]
9696+
else:
9697+
labels = [k if isinstance(k, tuple) else (k,) for k in labels]
9698+
9699+
if any(len(label) > level for label in labels):
9700+
raise KeyError(
9701+
"Key length ({}) exceeds index depth ({})".format(
9702+
max(len(label) for label in labels), level
9703+
)
9704+
)
9705+
return [tuple(list(label) + ([""] * (level - len(label)))) for label in labels]
9706+
9707+
def _index_normalized_frame(self, kser_or_kdf):
    """
    Return a frame built from the given Series or DataFrame whose column index
    is normalized against the current column index level.

    A Series is converted with `to_frame()`; a DataFrame is copied. When the
    resulting frame's column level differs from this frame's, each column label
    is stringified and padded with empty strings. For example, the name in
    `pd.Series([...], name="abc")` can become ("abc", "", "") if the current
    DataFrame has a multi-index for its columns.
    """
    # Imported locally, as in the original code (presumably to avoid a
    # circular import between the frame and series modules).
    from databricks.koalas.series import Series

    depth = self._internal.column_labels_level

    if isinstance(kser_or_kdf, Series):
        normalized = kser_or_kdf.to_frame()
    else:
        assert isinstance(kser_or_kdf, DataFrame), type(kser_or_kdf)
        normalized = kser_or_kdf.copy()

    # Nothing to do when both frames already share the same column depth.
    if depth == normalized._internal.column_labels_level:
        return normalized

    padding = [""] * (depth - 1)
    normalized.columns = pd.MultiIndex.from_tuples(
        [
            tuple([name_like_string(label)] + padding)
            for label in normalized._internal.column_labels
        ]
    )
    return normalized
9732+
97109733
def __getattr__(self, key: str) -> Any:
97119734
if key.startswith("__"):
97129735
raise AttributeError(key)

databricks/koalas/namespace.py

Lines changed: 103 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,13 @@
4545

4646
from databricks import koalas as ks # For running doctests and reference resolution in PyCharm.
4747
from databricks.koalas.base import IndexOpsMixin
48-
from databricks.koalas.utils import default_session, name_like_string, scol_for, validate_axis
48+
from databricks.koalas.utils import (
49+
default_session,
50+
name_like_string,
51+
scol_for,
52+
validate_axis,
53+
align_diff_frames,
54+
)
4955
from databricks.koalas.frame import DataFrame, _reduce_spark_multi
5056
from databricks.koalas.internal import _InternalFrame
5157
from databricks.koalas.typedef import pandas_wraps
@@ -107,6 +113,9 @@ def from_pandas(pobj: Union["pd.DataFrame", "pd.Series"]) -> Union["Series", "Da
107113
raise ValueError("Unknown data type: {}".format(type(pobj)))
108114

109115

116+
_range = range # built-in range
117+
118+
110119
def range(
111120
start: int, end: Optional[int] = None, step: int = 1, num_partitions: Optional[int] = None
112121
) -> DataFrame:
@@ -1539,11 +1548,11 @@ def concat(objs, axis=0, join="outer", ignore_index=False):
15391548
objs : a sequence of Series or DataFrame
15401549
Any None objects will be dropped silently unless
15411550
they are all None in which case a ValueError will be raised
1542-
axis : {0/'index'}, default 0
1551+
axis : {0/'index', 1/'columns'}, default 0
15431552
The axis to concatenate along.
15441553
join : {'inner', 'outer'}, default 'outer'
1545-
How to handle indexes on other axis(es)
1546-
ignore_index : boolean, default False
1554+
How to handle indexes on other axis (or axes).
1555+
ignore_index : bool, default False
15471556
If True, do not use the index values along the concatenation axis. The
15481557
resulting axis will be labeled 0, ..., n - 1. This is useful if you are
15491558
concatenating objects where the concatenation axis does not have
@@ -1552,14 +1561,17 @@ def concat(objs, axis=0, join="outer", ignore_index=False):
15521561
15531562
Returns
15541563
-------
1555-
concatenated : object, type of objs
1564+
object, type of objs
15561565
When concatenating all ``Series`` along the index (axis=0), a
15571566
``Series`` is returned. When ``objs`` contains at least one
1558-
``DataFrame``, a ``DataFrame`` is returned.
1567+
``DataFrame``, a ``DataFrame`` is returned. When concatenating along
1568+
the columns (axis=1), a ``DataFrame`` is returned.
15591569
15601570
See Also
15611571
--------
1562-
DataFrame.merge
1572+
Series.append : Concatenate Series.
1573+
DataFrame.join : Join DataFrames using indexes.
1574+
DataFrame.merge : Merge DataFrames by indexes or columns.
15631575
15641576
Examples
15651577
--------
@@ -1645,6 +1657,17 @@ def concat(objs, axis=0, join="outer", ignore_index=False):
16451657
1 b 2
16461658
0 c 3
16471659
1 d 4
1660+
1661+
>>> df4 = ks.DataFrame([['bird', 'polly'], ['monkey', 'george']],
1662+
... columns=['animal', 'name'])
1663+
1664+
Combine with column axis.
1665+
1666+
>>> ks.concat([df1, df4], axis=1)
1667+
letter number animal name
1668+
0 a 1 bird polly
1669+
1 b 2 monkey george
1670+
16481671
"""
16491672
if isinstance(objs, (DataFrame, IndexOpsMixin)) or not isinstance(
16501673
objs, Iterable
@@ -1655,10 +1678,6 @@ def concat(objs, axis=0, join="outer", ignore_index=False):
16551678
'"{name}"'.format(name=type(objs).__name__)
16561679
)
16571680

1658-
axis = validate_axis(axis)
1659-
if axis != 0:
1660-
raise NotImplementedError('axis should be either 0 or "index" currently.')
1661-
16621681
if len(objs) == 0:
16631682
raise ValueError("No objects to concatenate")
16641683
objs = list(filter(lambda obj: obj is not None, objs))
@@ -1674,6 +1693,79 @@ def concat(objs, axis=0, join="outer", ignore_index=False):
16741693
"and ks.DataFrame are valid".format(name=type(objs).__name__)
16751694
)
16761695

1696+
axis = validate_axis(axis)
1697+
if axis == 1:
1698+
if isinstance(objs[0], ks.Series):
1699+
concat_kdf = objs[0].to_frame()
1700+
else:
1701+
concat_kdf = objs[0]
1702+
1703+
with ks.option_context("compute.ops_on_diff_frames", True):
1704+
1705+
def assign_columns(kdf, this_column_labels, that_column_labels):
1706+
# Note that this intentionally uses `zip_longest` to combine
1707+
# all columns.
1708+
for this_label, that_label in itertools.zip_longest(
1709+
this_column_labels, that_column_labels
1710+
):
1711+
yield (kdf._kser_for(this_label), this_label)
1712+
yield (kdf._kser_for(that_label), that_label)
1713+
1714+
for kser_or_kdf in objs[1:]:
1715+
if isinstance(kser_or_kdf, Series):
1716+
# TODO: there is a corner case to optimize - when the series are from
1717+
# the same DataFrame.
1718+
kser = kser_or_kdf
1719+
# Series in different frames.
1720+
if join == "inner":
1721+
concat_kdf = align_diff_frames(
1722+
assign_columns,
1723+
concat_kdf,
1724+
concat_kdf._index_normalized_frame(kser),
1725+
fillna=False,
1726+
how="inner",
1727+
)
1728+
elif join == "outer":
1729+
concat_kdf = align_diff_frames(
1730+
assign_columns,
1731+
concat_kdf,
1732+
concat_kdf._index_normalized_frame(kser),
1733+
fillna=False,
1734+
how="full",
1735+
)
1736+
else:
1737+
raise ValueError(
1738+
"Only can inner (intersect) or outer (union) join the other axis."
1739+
)
1740+
else:
1741+
kdf = kser_or_kdf
1742+
1743+
if join == "inner":
1744+
concat_kdf = align_diff_frames(
1745+
assign_columns,
1746+
concat_kdf,
1747+
concat_kdf._index_normalized_frame(kdf),
1748+
fillna=False,
1749+
how="inner",
1750+
)
1751+
elif join == "outer":
1752+
concat_kdf = align_diff_frames(
1753+
assign_columns,
1754+
concat_kdf,
1755+
concat_kdf._index_normalized_frame(kdf),
1756+
fillna=False,
1757+
how="full",
1758+
)
1759+
else:
1760+
raise ValueError(
1761+
"Only can inner (intersect) or outer (union) join the other axis."
1762+
)
1763+
1764+
if ignore_index:
1765+
concat_kdf.columns = list(map(str, _range(len(concat_kdf.columns))))
1766+
1767+
return concat_kdf
1768+
16771769
# Series, Series ...
16781770
# We should return Series if objects are all Series.
16791771
should_return_series = all(map(lambda obj: isinstance(obj, Series), objs))

databricks/koalas/tests/test_namespace.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
import itertools
1617

1718
import pandas as pd
1819

@@ -92,10 +93,6 @@ def test_concat(self):
9293

9394
self.assertRaisesRegex(ValueError, "All objects passed", lambda: ks.concat([None, None]))
9495

95-
self.assertRaisesRegex(
96-
NotImplementedError, "axis should be either 0 or", lambda: ks.concat([kdf, kdf], axis=1)
97-
)
98-
9996
pdf3 = pdf.copy()
10097
kdf3 = kdf.copy()
10198

@@ -128,3 +125,66 @@ def test_concat(self):
128125
r"Only can inner \(intersect\) or outer \(union\) join the other axis.",
129126
lambda: ks.concat([kdf, kdf4], join=""),
130127
)
128+
129+
self.assertRaisesRegex(
130+
ValueError,
131+
r"Only can inner \(intersect\) or outer \(union\) join the other axis.",
132+
lambda: ks.concat([kdf, kdf4], join="", axis=1),
133+
)
134+
135+
self.assertRaisesRegex(
136+
ValueError,
137+
r"Only can inner \(intersect\) or outer \(union\) join the other axis.",
138+
lambda: ks.concat([kdf.A, kdf4.B], join="", axis=1),
139+
)
140+
141+
def test_concat_column_axis(self):
    """
    Check that ``ks.concat(..., axis=1)`` gives the same result as
    ``pd.concat(..., axis=1)`` for mixes of Series and DataFrames, with both
    single-level and multi-level column labels and for each supported join.
    """
    pdf1 = pd.DataFrame({"A": [0, 2, 4], "B": [1, 3, 5]}, index=[1, 2, 3])
    pdf2 = pd.DataFrame({"C": [1, 2, 3], "D": [4, 5, 6]}, index=[1, 3, 5])
    kdf1 = ks.from_pandas(pdf1)
    kdf2 = ks.from_pandas(pdf2)

    # Copies of the same data that carry two-level column labels.
    kdf3 = kdf1.copy()
    kdf4 = kdf2.copy()
    pdf3 = pdf1.copy()
    pdf4 = pdf2.copy()

    columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")])
    pdf3.columns = columns
    kdf3.columns = columns

    columns = pd.MultiIndex.from_tuples([("X", "B"), ("X", "C")])
    pdf4.columns = columns
    kdf4.columns = columns

    # NOTE(review): only ignore_index=False is exercised, so the
    # `if ignore_index:` branch below never runs - confirm whether True
    # should be covered as well.
    ignore_indexes = [False]
    joins = ["inner", "outer"]

    # Each entry pairs the Koalas inputs with the equivalent pandas inputs.
    objs = [
        ([kdf1.A, kdf2.C], [pdf1.A, pdf2.C]),
        ([kdf1, kdf2.C], [pdf1, pdf2.C]),
        ([kdf1.A, kdf2], [pdf1.A, pdf2]),
        ([kdf1.A, kdf1.A.rename("B")], [pdf1.A, pdf1.A.rename("B")]),
        ([kdf3[("X", "A")], kdf4[("X", "B")]], [pdf3[("X", "A")], pdf4[("X", "B")]]),
        ([kdf3, kdf4[("X", "B")]], [pdf3, pdf4[("X", "B")]]),
        ([kdf3[("X", "A")], kdf4], [pdf3[("X", "A")], pdf4]),
        ([kdf3, kdf4], [pdf3, pdf4]),
        (
            [kdf3[("X", "A")], kdf3[("X", "B")].rename("B")],
            [pdf3[("X", "A")], pdf3[("X", "B")].rename("B")],
        ),
    ]

    for ignore_index, join in itertools.product(ignore_indexes, joins):
        for obj in objs:
            kdfs, pdfs = obj
            with self.subTest(ignore_index=ignore_index, join=join, objs=obj):
                actual = ks.concat(kdfs, axis=1, ignore_index=ignore_index, join=join)
                expected = pd.concat(pdfs, axis=1, ignore_index=ignore_index, join=join)
                if ignore_index:
                    expected.columns = list(map(str, actual.columns))
                # Sort by all columns and drop the index so the comparison
                # does not depend on row order.
                self.assert_eq(
                    actual.sort_values(list(actual.columns)).reset_index(drop=True),
                    expected.sort_values(list(expected.columns)).reset_index(drop=True),
                )

0 commit comments

Comments
 (0)