Skip to content

Commit 04928d9

Browse files
committed
API: better warnings for df.set_index
1 parent 3745576 commit 04928d9

File tree

4 files changed

+674
-423
lines changed

4 files changed

+674
-423
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,7 @@ Other API Changes
545545
- :class:`pandas.io.formats.style.Styler` supports a ``number-format`` property when using :meth:`~pandas.io.formats.style.Styler.to_excel` (:issue:`22015`)
546546
- :meth:`DataFrame.corr` and :meth:`Series.corr` now raise a ``ValueError`` along with a helpful error message instead of a ``KeyError`` when supplied with an invalid method (:issue:`22298`)
547547
- :meth:`shift` will now always return a copy, instead of the previous behaviour of returning self when shifting by 0 (:issue:`22397`)
548+
- :meth:`DataFrame.set_index` now raises a ``TypeError`` for incorrect types, has an improved ``KeyError`` message, and will not fail when the same column name is passed more than once in ``keys`` with ``drop=True``. (:issue:`22484`)
548549

549550
.. _whatsnew_0240.deprecations:
550551

pandas/core/frame.py

+31-16
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
is_sequence,
6161
is_named_tuple)
6262
from pandas.core.dtypes.concat import _get_sliced_frame_result_type
63+
from pandas.core.dtypes.generic import ABCSeries, ABCIndexClass, ABCMultiIndex
6364
from pandas.core.dtypes.missing import isna, notna
6465

6566

@@ -3892,6 +3893,22 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
38923893
if not isinstance(keys, list):
38933894
keys = [keys]
38943895

3896+
missing = []
3897+
for x in keys:
3898+
if not (is_scalar(x) or isinstance(x, tuple)):
3899+
if not isinstance(x, (ABCSeries, ABCIndexClass, ABCMultiIndex,
3900+
list, np.ndarray)):
3901+
raise TypeError('keys may only contain a combination of '
3902+
'the following: valid column keys, '
3903+
'Series, Index, MultiIndex, list or '
3904+
'np.ndarray')
3905+
else:
3906+
if x not in self:
3907+
missing.append(x)
3908+
3909+
if missing:
3910+
raise KeyError('{}'.format(missing))
3911+
38953912
if inplace:
38963913
frame = self
38973914
else:
@@ -3901,37 +3918,34 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
39013918
names = []
39023919
if append:
39033920
names = [x for x in self.index.names]
3904-
if isinstance(self.index, MultiIndex):
3921+
if isinstance(self.index, ABCMultiIndex):
39053922
for i in range(self.index.nlevels):
39063923
arrays.append(self.index._get_level_values(i))
39073924
else:
39083925
arrays.append(self.index)
39093926

39103927
to_remove = []
39113928
for col in keys:
3912-
if isinstance(col, MultiIndex):
3913-
# append all but the last column so we don't have to modify
3914-
# the end of this loop
3915-
for n in range(col.nlevels - 1):
3929+
if isinstance(col, ABCMultiIndex):
3930+
for n in range(col.nlevels):
39163931
arrays.append(col._get_level_values(n))
3917-
3918-
level = col._get_level_values(col.nlevels - 1)
39193932
names.extend(col.names)
3920-
elif isinstance(col, Series):
3921-
level = col._values
3933+
elif isinstance(col, ABCIndexClass):
3934+
# Index but not MultiIndex (treated above)
3935+
arrays.append(col)
39223936
names.append(col.name)
3923-
elif isinstance(col, Index):
3924-
level = col
3937+
elif isinstance(col, ABCSeries):
3938+
arrays.append(col._values)
39253939
names.append(col.name)
3926-
elif isinstance(col, (list, np.ndarray, Index)):
3927-
level = col
3940+
elif isinstance(col, (list, np.ndarray)):
3941+
arrays.append(col)
39283942
names.append(None)
3943+
# from here, col can only be a column label
39293944
else:
3930-
level = frame[col]._values
3945+
arrays.append(frame[col]._values)
39313946
names.append(col)
39323947
if drop:
39333948
to_remove.append(col)
3934-
arrays.append(level)
39353949

39363950
index = ensure_index_from_sequences(arrays, names)
39373951

@@ -3940,7 +3954,8 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
39403954
raise ValueError('Index has duplicate keys: {dup}'.format(
39413955
dup=duplicates))
39423956

3943-
for c in to_remove:
3957+
# use set so that a column label repeated in keys is deleted only once when drop=True
3958+
for c in set(to_remove):
39443959
del frame[c]
39453960

39463961
# clear up memory usage

pandas/tests/frame/conftest.py

+191
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import pytest
2+
3+
import numpy as np
4+
5+
from pandas import compat
6+
import pandas.util.testing as tm
7+
from pandas import DataFrame, date_range, NaT
8+
9+
10+
@pytest.fixture
11+
def float_frame():
12+
"""
13+
Fixture for DataFrame of floats with index of unique strings
14+
15+
Columns are ['A', 'B', 'C', 'D'].
16+
"""
17+
return DataFrame(tm.getSeriesData())
18+
19+
20+
@pytest.fixture
21+
def float_frame2():
22+
"""
23+
Fixture for DataFrame of floats with index of unique strings
24+
25+
Columns are ['D', 'C', 'B', 'A'].
26+
"""
27+
return DataFrame(tm.getSeriesData(), columns=['D', 'C', 'B', 'A'])
28+
29+
30+
@pytest.fixture
31+
def int_frame():
32+
"""
33+
Fixture for DataFrame of ints with index of unique strings
34+
35+
Columns are ['A', 'B', 'C', 'D'].
36+
"""
37+
df = DataFrame({k: v.astype(int)
38+
for k, v in compat.iteritems(tm.getSeriesData())})
39+
# force these all to int64 to avoid platform testing issues
40+
return DataFrame({c: s for c, s in compat.iteritems(df)}, dtype=np.int64)
41+
42+
43+
@pytest.fixture
44+
def datetime_frame():
45+
"""
46+
Fixture for DataFrame of floats with DatetimeIndex
47+
48+
Columns are ['A', 'B', 'C', 'D'].
49+
"""
50+
return DataFrame(tm.getTimeSeriesData())
51+
52+
53+
@pytest.fixture
54+
def float_string_frame():
55+
"""
56+
Fixture for DataFrame of floats and strings with index of unique strings
57+
58+
Columns are ['A', 'B', 'C', 'D', 'foo'].
59+
"""
60+
df = DataFrame(tm.getSeriesData())
61+
df['foo'] = 'bar'
62+
return df
63+
64+
65+
@pytest.fixture
66+
def mixed_float_frame():
67+
"""
68+
Fixture for DataFrame of different float types with index of unique strings
69+
70+
Columns are ['A', 'B', 'C', 'D'].
71+
"""
72+
df = DataFrame(tm.getSeriesData())
73+
df.A = df.A.astype('float16')
74+
df.B = df.B.astype('float32')
75+
df.C = df.C.astype('float64')
76+
return df
77+
78+
79+
@pytest.fixture
80+
def mixed_float_frame2():
81+
"""
82+
Fixture for DataFrame of different float types with index of unique strings
83+
84+
Columns are ['A', 'B', 'C', 'D'].
85+
"""
86+
df = DataFrame(tm.getSeriesData())
87+
df.D = df.D.astype('float16')
88+
df.C = df.C.astype('float32')
89+
df.B = df.B.astype('float64')
90+
return df
91+
92+
93+
@pytest.fixture
94+
def mixed_int_frame():
95+
"""
96+
Fixture for DataFrame of different int types with index of unique strings
97+
98+
Columns are ['A', 'B', 'C', 'D'].
99+
"""
100+
df = DataFrame({k: v.astype(int)
101+
for k, v in compat.iteritems(tm.getSeriesData())})
102+
df.A = df.A.astype('uint8')
103+
df.B = df.B.astype('int32')
104+
df.C = df.C.astype('int64')
105+
df.D = np.ones(len(df.D), dtype='uint64')
106+
return df
107+
108+
109+
@pytest.fixture
110+
def mixed_type_frame():
111+
"""
112+
Fixture for DataFrame of float/int/string columns with an integer index (0-9)
113+
114+
Columns are ['a', 'b', 'c', 'float32', 'int32'].
115+
"""
116+
return DataFrame({'a': 1., 'b': 2, 'c': 'foo',
117+
'float32': np.array([1.] * 10, dtype='float32'),
118+
'int32': np.array([1] * 10, dtype='int32')},
119+
index=np.arange(10))
120+
121+
122+
@pytest.fixture
123+
def timezone_frame():
124+
"""
125+
Fixture for DataFrame of date_range Series with different time zones
126+
127+
Columns are ['A', 'B', 'C']; some entries are missing.
128+
"""
129+
df = DataFrame({'A': date_range('20130101', periods=3),
130+
'B': date_range('20130101', periods=3,
131+
tz='US/Eastern'),
132+
'C': date_range('20130101', periods=3,
133+
tz='CET')})
134+
df.iloc[1, 1] = NaT
135+
df.iloc[1, 2] = NaT
136+
return df
137+
138+
139+
@pytest.fixture
140+
def empty_frame():
141+
"""
142+
Fixture for empty DataFrame
143+
"""
144+
return DataFrame({})
145+
146+
147+
@pytest.fixture
148+
def datetime_series():
149+
"""
150+
Fixture for Series of floats with DatetimeIndex
151+
"""
152+
return tm.makeTimeSeries(nper=30)
153+
154+
155+
@pytest.fixture
156+
def datetime_series_short():
157+
"""
158+
Fixture for Series of floats with DatetimeIndex (first 5 entries dropped)
159+
"""
160+
return tm.makeTimeSeries(nper=30)[5:]
161+
162+
163+
@pytest.fixture
164+
def simple_frame():
165+
"""
166+
Fixture for simple 3x3 DataFrame
167+
168+
Columns are ['one', 'two', 'three'], index is ['a', 'b', 'c'].
169+
"""
170+
arr = np.array([[1., 2., 3.],
171+
[4., 5., 6.],
172+
[7., 8., 9.]])
173+
174+
return DataFrame(arr, columns=['one', 'two', 'three'],
175+
index=['a', 'b', 'c'])
176+
177+
178+
@pytest.fixture
179+
def frame_of_index_cols():
180+
"""
181+
Fixture for DataFrame of columns that can be used for indexing
182+
183+
Columns are ['A', 'B', 'C', 'D', 'E']; 'A' & 'B' contain duplicates (but
184+
are jointly unique), the rest are unique.
185+
"""
186+
df = DataFrame({'A': ['foo', 'foo', 'foo', 'bar', 'bar'],
187+
'B': ['one', 'two', 'three', 'one', 'two'],
188+
'C': ['a', 'b', 'c', 'd', 'e'],
189+
'D': np.random.randn(5),
190+
'E': np.random.randn(5)})
191+
return df

0 commit comments

Comments
 (0)