Skip to content

Commit bb22748

Browse files
authored
Implemented intersection for Index & MultiIndex (#1747)
This PR proposes the new API `Index.intersection()` and `MultiIndex.intersection()`. ```python >>> idx1 = ks.Index([1, 2, 3, 4]) >>> idx2 = ks.Index([3, 4, 5, 6]) >>> idx1.intersection(idx2) Int64Index([3, 4], dtype='int64') ```
1 parent 62fb01c commit bb22748

File tree

4 files changed

+226
-2
lines changed

4 files changed

+226
-2
lines changed

databricks/koalas/indexes.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2077,6 +2077,59 @@ def holds_integer(self):
20772077
"""
20782078
return isinstance(self.spark.data_type, IntegralType)
20792079

2080+
def intersection(self, other) -> "Index":
2081+
"""
2082+
Form the intersection of two Index objects.
2083+
2084+
This returns a new Index with elements common to the index and `other`.
2085+
2086+
Parameters
2087+
----------
2088+
other : Index or array-like
2089+
2090+
Returns
2091+
-------
2092+
intersection : Index
2093+
2094+
Examples
2095+
--------
2096+
>>> idx1 = ks.Index([1, 2, 3, 4])
2097+
>>> idx2 = ks.Index([3, 4, 5, 6])
2098+
>>> idx1.intersection(idx2).sort_values()
2099+
Int64Index([3, 4], dtype='int64')
2100+
"""
2101+
keep_name = True
2102+
2103+
if isinstance(other, DataFrame):
2104+
raise ValueError("Index data must be 1-dimensional")
2105+
elif isinstance(other, MultiIndex):
2106+
# Always returns an empty MultiIndex if `other` is MultiIndex.
2107+
return other.to_frame().head(0).index
2108+
elif isinstance(other, Index):
2109+
spark_frame_other = other.to_frame().to_spark()
2110+
keep_name = self.name == other.name
2111+
elif isinstance(other, Series):
2112+
spark_frame_other = other.to_frame().to_spark()
2113+
keep_name = self.name == other.name
2114+
elif is_list_like(other):
2115+
other = Index(other)
2116+
if isinstance(other, MultiIndex):
2117+
return other.to_frame().head(0).index
2118+
spark_frame_other = other.to_frame().to_spark()
2119+
keep_name = False
2120+
else:
2121+
raise TypeError("Input must be Index or array-like")
2122+
2123+
spark_frame_self = self.to_frame(name=SPARK_DEFAULT_INDEX_NAME).to_spark()
2124+
spark_frame_intersected = spark_frame_self.intersect(spark_frame_other)
2125+
if keep_name:
2126+
index_map = self._internal.index_map
2127+
else:
2128+
index_map = OrderedDict([(SPARK_DEFAULT_INDEX_NAME, None)])
2129+
internal = InternalFrame(spark_frame=spark_frame_intersected, index_map=index_map)
2130+
2131+
return DataFrame(internal).index
2132+
20802133
def item(self):
20812134
"""
20822135
Return the first element of the underlying data as a python scalar.
@@ -3118,6 +3171,59 @@ def item(self):
31183171
"""
31193172
return self._kdf.head(2)._to_internal_pandas().index.item()
31203173

3174+
def intersection(self, other):
3175+
"""
3176+
Form the intersection of two Index objects.
3177+
3178+
This returns a new Index with elements common to the index and `other`.
3179+
3180+
Parameters
3181+
----------
3182+
other : Index or array-like
3183+
3184+
Returns
3185+
-------
3186+
intersection : Index
3187+
3188+
Examples
3189+
--------
3190+
>>> midx1 = ks.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
3191+
>>> midx2 = ks.MultiIndex.from_tuples([("c", "z"), ("d", "w")])
3192+
>>> midx1.intersection(midx2).sort_values() # doctest: +SKIP
3193+
MultiIndex([('c', 'z')],
3194+
)
3195+
"""
3196+
keep_name = True
3197+
3198+
if isinstance(other, Series) or not is_list_like(other):
3199+
raise TypeError("other must be a MultiIndex or a list of tuples")
3200+
elif isinstance(other, DataFrame):
3201+
raise ValueError("Index data must be 1-dimensional")
3202+
elif isinstance(other, MultiIndex):
3203+
spark_frame_other = other.to_frame().to_spark()
3204+
keep_name = self.names == other.names
3205+
elif isinstance(other, Index):
3206+
# Always returns an empty MultiIndex if `other` is Index.
3207+
return self.to_frame().head(0).index
3208+
elif not all(isinstance(item, tuple) for item in other):
3209+
raise TypeError("other must be a MultiIndex or a list of tuples")
3210+
else:
3211+
other = MultiIndex.from_tuples(list(other))
3212+
spark_frame_other = other.to_frame().to_spark()
3213+
keep_name = True
3214+
3215+
default_name = [SPARK_INDEX_NAME_FORMAT(i) for i in range(self.nlevels)]
3216+
spark_frame_self = self.to_frame(name=default_name).to_spark()
3217+
spark_frame_intersected = spark_frame_self.intersect(spark_frame_other)
3218+
if keep_name:
3219+
index_map = self._internal.index_map
3220+
else:
3221+
index_map = OrderedDict(
3222+
[(SPARK_INDEX_NAME_FORMAT(i), None) for i in range(self.nlevels)]
3223+
)
3224+
internal = InternalFrame(spark_frame=spark_frame_intersected, index_map=index_map)
3225+
return DataFrame(internal).index
3226+
31213227
@property
31223228
def inferred_type(self):
31233229
"""

databricks/koalas/missing/indexes.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ class MissingPandasLikeIndex(object):
4949
get_slice_bound = _unsupported_function("get_slice_bound")
5050
get_value = _unsupported_function("get_value")
5151
groupby = _unsupported_function("groupby")
52-
intersection = _unsupported_function("intersection")
5352
is_ = _unsupported_function("is_")
5453
is_lexsorted_for_tuple = _unsupported_function("is_lexsorted_for_tuple")
5554
join = _unsupported_function("join")
@@ -116,7 +115,6 @@ class MissingPandasLikeMultiIndex(object):
116115
get_slice_bound = _unsupported_function("get_slice_bound")
117116
get_value = _unsupported_function("get_value")
118117
groupby = _unsupported_function("groupby")
119-
intersection = _unsupported_function("intersection")
120118
is_ = _unsupported_function("is_")
121119
is_lexsorted = _unsupported_function("is_lexsorted")
122120
is_lexsorted_for_tuple = _unsupported_function("is_lexsorted_for_tuple")

databricks/koalas/tests/test_indexes.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1447,6 +1447,124 @@ def test_hasnans(self):
14471447
kser = ks.from_pandas(pser)
14481448
self.assert_eq(pser.hasnans, kser.hasnans)
14491449

1450+
def test_intersection(self):
1451+
pidx = pd.Index([1, 2, 3, 4], name="Koalas")
1452+
kidx = ks.from_pandas(pidx)
1453+
1454+
# other = Index
1455+
pidx_other = pd.Index([3, 4, 5, 6], name="Koalas")
1456+
kidx_other = ks.from_pandas(pidx_other)
1457+
self.assert_eq(pidx.intersection(pidx_other), kidx.intersection(kidx_other).sort_values())
1458+
self.assert_eq(
1459+
(pidx + 1).intersection(pidx_other), (kidx + 1).intersection(kidx_other).sort_values()
1460+
)
1461+
1462+
pidx_other_different_name = pd.Index([3, 4, 5, 6], name="Databricks")
1463+
kidx_other_different_name = ks.from_pandas(pidx_other_different_name)
1464+
self.assert_eq(
1465+
pidx.intersection(pidx_other_different_name),
1466+
kidx.intersection(kidx_other_different_name).sort_values(),
1467+
)
1468+
self.assert_eq(
1469+
(pidx + 1).intersection(pidx_other_different_name),
1470+
(kidx + 1).intersection(kidx_other_different_name).sort_values(),
1471+
)
1472+
1473+
pidx_other_from_frame = pd.DataFrame({"a": [3, 4, 5, 6]}).set_index("a").index
1474+
kidx_other_from_frame = ks.from_pandas(pidx_other_from_frame)
1475+
self.assert_eq(
1476+
pidx.intersection(pidx_other_from_frame),
1477+
kidx.intersection(kidx_other_from_frame).sort_values(),
1478+
)
1479+
self.assert_eq(
1480+
(pidx + 1).intersection(pidx_other_from_frame),
1481+
(kidx + 1).intersection(kidx_other_from_frame).sort_values(),
1482+
)
1483+
1484+
# other = MultiIndex
1485+
pmidx = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
1486+
kmidx = ks.from_pandas(pmidx)
1487+
self.assert_eq(
1488+
pidx.intersection(pmidx), kidx.intersection(kmidx).sort_values(), almost=True
1489+
)
1490+
self.assert_eq(
1491+
(pidx + 1).intersection(pmidx),
1492+
(kidx + 1).intersection(kmidx).sort_values(),
1493+
almost=True,
1494+
)
1495+
1496+
# other = Series
1497+
pser = pd.Series([3, 4, 5, 6])
1498+
kser = ks.from_pandas(pser)
1499+
self.assert_eq(pidx.intersection(pser), kidx.intersection(kser).sort_values())
1500+
self.assert_eq((pidx + 1).intersection(pser), (kidx + 1).intersection(kser).sort_values())
1501+
1502+
pser_different_name = pd.Series([3, 4, 5, 6], name="Databricks")
1503+
kser_different_name = ks.from_pandas(pser_different_name)
1504+
self.assert_eq(
1505+
pidx.intersection(pser_different_name),
1506+
kidx.intersection(kser_different_name).sort_values(),
1507+
)
1508+
self.assert_eq(
1509+
(pidx + 1).intersection(pser_different_name),
1510+
(kidx + 1).intersection(kser_different_name).sort_values(),
1511+
)
1512+
1513+
# other = list
1514+
other = [3, 4, 5, 6]
1515+
self.assert_eq(pidx.intersection(other), kidx.intersection(other).sort_values())
1516+
self.assert_eq((pidx + 1).intersection(other), (kidx + 1).intersection(other).sort_values())
1517+
1518+
# other = tuple
1519+
other = (3, 4, 5, 6)
1520+
self.assert_eq(pidx.intersection(other), kidx.intersection(other).sort_values())
1521+
self.assert_eq((pidx + 1).intersection(other), (kidx + 1).intersection(other).sort_values())
1522+
1523+
# other = dict
1524+
other = {3: None, 4: None, 5: None, 6: None}
1525+
self.assert_eq(pidx.intersection(other), kidx.intersection(other).sort_values())
1526+
self.assert_eq((pidx + 1).intersection(other), (kidx + 1).intersection(other).sort_values())
1527+
1528+
# MultiIndex / other = Index
1529+
self.assert_eq(
1530+
pmidx.intersection(pidx), kmidx.intersection(kidx).sort_values(), almost=True
1531+
)
1532+
self.assert_eq(
1533+
pmidx.intersection(pidx_other_from_frame),
1534+
kmidx.intersection(kidx_other_from_frame).sort_values(),
1535+
almost=True,
1536+
)
1537+
1538+
# MultiIndex / other = MultiIndex
1539+
pmidx_other = pd.MultiIndex.from_tuples([("c", "z"), ("d", "w")])
1540+
kmidx_other = ks.from_pandas(pmidx_other)
1541+
self.assert_eq(
1542+
pmidx.intersection(pmidx_other), kmidx.intersection(kmidx_other).sort_values()
1543+
)
1544+
1545+
# MultiIndex / other = list
1546+
other = [("c", "z"), ("d", "w")]
1547+
self.assert_eq(pmidx.intersection(other), kmidx.intersection(other).sort_values())
1548+
1549+
# MultiIndex / other = tuple
1550+
other = (("c", "z"), ("d", "w"))
1551+
self.assert_eq(pmidx.intersection(other), kmidx.intersection(other).sort_values())
1552+
1553+
# MultiIndex / other = dict
1554+
other = {("c", "z"): None, ("d", "w"): None}
1555+
self.assert_eq(pmidx.intersection(other), kmidx.intersection(other).sort_values())
1556+
1557+
with self.assertRaisesRegex(TypeError, "Input must be Index or array-like"):
1558+
kidx.intersection(4)
1559+
with self.assertRaisesRegex(TypeError, "other must be a MultiIndex or a list of tuples"):
1560+
kmidx.intersection(4)
1561+
with self.assertRaisesRegex(TypeError, "other must be a MultiIndex or a list of tuples"):
1562+
kmidx.intersection(ks.Series([3, 4, 5, 6]))
1563+
with self.assertRaisesRegex(ValueError, "Index data must be 1-dimensional"):
1564+
kidx.intersection(ks.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}))
1565+
with self.assertRaisesRegex(ValueError, "Index data must be 1-dimensional"):
1566+
kmidx.intersection(ks.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}))
1567+
14501568
def test_item(self):
14511569
pidx = pd.Index([10])
14521570
kidx = ks.from_pandas(pidx)

docs/source/reference/indexing.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ Combining / joining / set operations
139139
:toctree: api/
140140

141141
Index.append
142+
Index.intersection
142143
Index.union
143144
Index.difference
144145
Index.symmetric_difference
@@ -236,6 +237,7 @@ MultiIndex Combining / joining / set operations
236237
:toctree: api/
237238

238239
MultiIndex.append
240+
MultiIndex.intersection
239241
MultiIndex.union
240242
MultiIndex.difference
241243
MultiIndex.symmetric_difference

0 commit comments

Comments
 (0)