Skip to content

Commit d4ec2d9

Browse files
committed
add more tests
1 parent 7aa0ccf commit d4ec2d9

File tree

10 files changed

+164
-75
lines changed

10 files changed

+164
-75
lines changed

pandas/core/indexes/interval.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ def overlaps(self, other):
10361036
return self._data.overlaps(other)
10371037

10381038
def _setop(op_name):
1039-
def func(self, other):
1039+
def func(self, other, sort=True):
10401040
other = self._as_like_interval_index(other)
10411041

10421042
# GH 19016: ensure set op will not return a prohibited dtype
@@ -1047,7 +1047,11 @@ def func(self, other):
10471047
'objects that have compatible dtypes')
10481048
raise TypeError(msg.format(op=op_name))
10491049

1050-
result = getattr(self._multiindex, op_name)(other._multiindex)
1050+
if op_name == 'difference':
1051+
result = getattr(self._multiindex, op_name)(other._multiindex,
1052+
sort)
1053+
else:
1054+
result = getattr(self._multiindex, op_name)(other._multiindex)
10511055
result_name = self.name if self.name == other.name else None
10521056

10531057
# GH 19101: ensure empty results have correct dtype

pandas/core/indexes/multi.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -2794,8 +2794,14 @@ def difference(self, other, sort=True):
27942794
labels=[[]] * self.nlevels,
27952795
names=result_names, verify_integrity=False)
27962796

2797-
difference = set(self._ndarray_values) - set(other._ndarray_values)
2797+
this = self._get_unique_index()
27982798

2799+
indexer = this.get_indexer(other)
2800+
indexer = indexer.take((indexer != -1).nonzero()[0])
2801+
2802+
label_diff = np.setdiff1d(np.arange(this.size), indexer,
2803+
assume_unique=True)
2804+
difference = this.values.take(label_diff)
27992805
if sort:
28002806
difference = sorted(difference)
28012807

pandas/tests/indexes/common.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -666,12 +666,13 @@ def test_union_base(self):
666666
with tm.assert_raises_regex(TypeError, msg):
667667
result = first.union([1, 2, 3])
668668

669-
def test_difference_base(self):
669+
@pytest.mark.parametrize("sort", [True, False])
670+
def test_difference_base(self, sort):
670671
for name, idx in compat.iteritems(self.indices):
671672
first = idx[2:]
672673
second = idx[:4]
673674
answer = idx[4:]
674-
result = first.difference(second)
675+
result = first.difference(second, sort)
675676

676677
if isinstance(idx, CategoricalIndex):
677678
pass
@@ -685,21 +686,21 @@ def test_difference_base(self):
685686
if isinstance(idx, PeriodIndex):
686687
msg = "can only call with other PeriodIndex-ed objects"
687688
with tm.assert_raises_regex(ValueError, msg):
688-
result = first.difference(case)
689+
result = first.difference(case, sort)
689690
elif isinstance(idx, CategoricalIndex):
690691
pass
691692
elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)):
692693
assert result.__class__ == answer.__class__
693694
tm.assert_numpy_array_equal(result.sort_values().asi8,
694695
answer.sort_values().asi8)
695696
else:
696-
result = first.difference(case)
697+
result = first.difference(case, sort)
697698
assert tm.equalContents(result, answer)
698699

699700
if isinstance(idx, MultiIndex):
700701
msg = "other must be a MultiIndex or a list of tuples"
701702
with tm.assert_raises_regex(TypeError, msg):
702-
result = first.difference([1, 2, 3])
703+
result = first.difference([1, 2, 3], sort)
703704

704705
def test_symmetric_difference(self):
705706
for name, idx in compat.iteritems(self.indices):

pandas/tests/indexes/datetimes/test_setops.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -208,47 +208,55 @@ def test_intersection_bug_1708(self):
208208
assert len(result) == 0
209209

210210
@pytest.mark.parametrize("tz", tz)
211-
def test_difference(self, tz):
212-
rng1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
211+
@pytest.mark.parametrize("sort", [True, False])
212+
def test_difference(self, tz, sort):
213+
rng_dates = ['1/2/2000', '1/3/2000', '1/1/2000', '1/4/2000',
214+
'1/5/2000']
215+
216+
rng1 = pd.DatetimeIndex(rng_dates, tz=tz)
213217
other1 = pd.date_range('1/6/2000', freq='D', periods=5, tz=tz)
214-
expected1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
218+
expected1 = pd.DatetimeIndex(rng_dates, tz=tz)
215219

216-
rng2 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
220+
rng2 = pd.DatetimeIndex(rng_dates, tz=tz)
217221
other2 = pd.date_range('1/4/2000', freq='D', periods=5, tz=tz)
218-
expected2 = pd.date_range('1/1/2000', freq='D', periods=3, tz=tz)
222+
expected2 = pd.DatetimeIndex(rng_dates[:3], tz=tz)
219223

220-
rng3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
224+
rng3 = pd.DatetimeIndex(rng_dates, tz=tz)
221225
other3 = pd.DatetimeIndex([], tz=tz)
222-
expected3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
226+
expected3 = pd.DatetimeIndex(rng_dates, tz=tz)
223227

224228
for rng, other, expected in [(rng1, other1, expected1),
225229
(rng2, other2, expected2),
226230
(rng3, other3, expected3)]:
227-
result_diff = rng.difference(other)
231+
result_diff = rng.difference(other, sort)
232+
if sort:
233+
expected = expected.sort_values()
228234
tm.assert_index_equal(result_diff, expected)
229235

230-
def test_difference_freq(self):
236+
@pytest.mark.parametrize("sort", [True, False])
237+
def test_difference_freq(self, sort):
231238
# GH14323: difference of DatetimeIndex should not preserve frequency
232239

233240
index = date_range("20160920", "20160925", freq="D")
234241
other = date_range("20160921", "20160924", freq="D")
235242
expected = DatetimeIndex(["20160920", "20160925"], freq=None)
236-
idx_diff = index.difference(other)
243+
idx_diff = index.difference(other, sort)
237244
tm.assert_index_equal(idx_diff, expected)
238245
tm.assert_attr_equal('freq', idx_diff, expected)
239246

240247
other = date_range("20160922", "20160925", freq="D")
241-
idx_diff = index.difference(other)
248+
idx_diff = index.difference(other, sort)
242249
expected = DatetimeIndex(["20160920", "20160921"], freq=None)
243250
tm.assert_index_equal(idx_diff, expected)
244251
tm.assert_attr_equal('freq', idx_diff, expected)
245252

246-
def test_datetimeindex_diff(self):
253+
@pytest.mark.parametrize("sort", [True, False])
254+
def test_datetimeindex_diff(self, sort):
247255
dti1 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
248256
periods=100)
249257
dti2 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
250258
periods=98)
251-
assert len(dti1.difference(dti2)) == 2
259+
assert len(dti1.difference(dti2, sort)) == 2
252260

253261
def test_datetimeindex_union_join_empty(self):
254262
dti = DatetimeIndex(start='1/1/2001', end='2/1/2001', freq='D')

pandas/tests/indexes/interval/test_interval.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -801,19 +801,26 @@ def test_intersection(self, closed):
801801
result = index.intersection(other)
802802
tm.assert_index_equal(result, expected)
803803

804-
def test_difference(self, closed):
805-
index = self.create_index(closed=closed)
806-
tm.assert_index_equal(index.difference(index[:1]), index[1:])
804+
@pytest.mark.parametrize("sort", [True, False])
805+
def test_difference(self, closed, sort):
806+
index = IntervalIndex.from_arrays([1, 0, 3, 2],
807+
[1, 2, 3, 4],
808+
closed=closed)
809+
result = index.difference(index[:1], sort)
810+
expected = index[1:]
811+
if sort:
812+
expected = expected.sort_values()
813+
tm.assert_index_equal(result, expected)
807814

808815
# GH 19101: empty result, same dtype
809-
result = index.difference(index)
816+
result = index.difference(index, sort)
810817
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
811818
tm.assert_index_equal(result, expected)
812819

813820
# GH 19101: empty result, different dtypes
814821
other = IntervalIndex.from_arrays(index.left.astype('float64'),
815822
index.right, closed=closed)
816-
result = index.difference(other)
823+
result = index.difference(other, sort)
817824
tm.assert_index_equal(result, expected)
818825

819826
def test_symmetric_difference(self, closed):

pandas/tests/indexes/multi/test_set_ops.py

+23-15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pandas as pd
66
import pandas.util.testing as tm
77
from pandas import MultiIndex, Series
8+
import pytest
89

910

1011
def test_setops_errorcases(idx):
@@ -59,24 +60,25 @@ def test_union_base(idx):
5960
result = first.union([1, 2, 3])
6061

6162

62-
def test_difference_base(idx):
63+
@pytest.mark.parametrize("sort", [True, False])
64+
def test_difference_base(idx, sort):
6365
first = idx[2:]
6466
second = idx[:4]
6567
answer = idx[4:]
66-
result = first.difference(second)
68+
result = first.difference(second, sort)
6769

6870
assert tm.equalContents(result, answer)
6971

7072
# GH 10149
7173
cases = [klass(second.values)
7274
for klass in [np.array, Series, list]]
7375
for case in cases:
74-
result = first.difference(case)
76+
result = first.difference(case, sort)
7577
assert tm.equalContents(result, answer)
7678

7779
msg = "other must be a MultiIndex or a list of tuples"
7880
with tm.assert_raises_regex(TypeError, msg):
79-
result = first.difference([1, 2, 3])
81+
result = first.difference([1, 2, 3], sort)
8082

8183

8284
def test_symmetric_difference(idx):
@@ -104,11 +106,17 @@ def test_empty(idx):
104106
assert idx[:0].empty
105107

106108

107-
def test_difference(idx):
109+
@pytest.mark.parametrize("sort", [True, False])
110+
def test_difference(idx, sort):
108111

109112
first = idx
110-
result = first.difference(idx[-3:])
111-
expected = MultiIndex.from_tuples(sorted(idx[:-3].values),
113+
result = first.difference(idx[-3:], sort)
114+
vals = idx[:-3].values
115+
116+
if sort:
117+
vals = sorted(vals)
118+
119+
expected = MultiIndex.from_tuples(vals,
112120
sortorder=0,
113121
names=idx.names)
114122

@@ -117,44 +125,44 @@ def test_difference(idx):
117125
assert result.names == idx.names
118126

119127
# empty difference: reflexive
120-
result = idx.difference(idx)
128+
result = idx.difference(idx, sort)
121129
expected = idx[:0]
122130
assert result.equals(expected)
123131
assert result.names == idx.names
124132

125133
# empty difference: superset
126-
result = idx[-3:].difference(idx)
134+
result = idx[-3:].difference(idx, sort)
127135
expected = idx[:0]
128136
assert result.equals(expected)
129137
assert result.names == idx.names
130138

131139
# empty difference: degenerate
132-
result = idx[:0].difference(idx)
140+
result = idx[:0].difference(idx, sort)
133141
expected = idx[:0]
134142
assert result.equals(expected)
135143
assert result.names == idx.names
136144

137145
# names not the same
138146
chunklet = idx[-3:]
139147
chunklet.names = ['foo', 'baz']
140-
result = first.difference(chunklet)
148+
result = first.difference(chunklet, sort)
141149
assert result.names == (None, None)
142150

143151
# empty, but non-equal
144-
result = idx.difference(idx.sortlevel(1)[0])
152+
result = idx.difference(idx.sortlevel(1)[0], sort)
145153
assert len(result) == 0
146154

147155
# raise Exception called with non-MultiIndex
148-
result = first.difference(first.values)
156+
result = first.difference(first.values, sort)
149157
assert result.equals(first[:0])
150158

151159
# name from empty array
152-
result = first.difference([])
160+
result = first.difference([], sort)
153161
assert first.equals(result)
154162
assert first.names == result.names
155163

156164
# name from non-empty array
157-
result = first.difference([('foo', 'one')])
165+
result = first.difference([('foo', 'one')], sort)
158166
expected = pd.MultiIndex.from_tuples([('bar', 'one'), ('baz', 'two'), (
159167
'foo', 'two'), ('qux', 'one'), ('qux', 'two')])
160168
expected.names = first.names

pandas/tests/indexes/period/test_period.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -72,20 +72,21 @@ def test_no_millisecond_field(self):
7272
with pytest.raises(AttributeError):
7373
DatetimeIndex([]).millisecond
7474

75-
def test_difference_freq(self):
75+
@pytest.mark.parametrize("sort", [True, False])
76+
def test_difference_freq(self, sort):
7677
# GH14323: difference of Period MUST preserve frequency
7778
# but the ability to union results must be preserved
7879

7980
index = period_range("20160920", "20160925", freq="D")
8081

8182
other = period_range("20160921", "20160924", freq="D")
8283
expected = PeriodIndex(["20160920", "20160925"], freq='D')
83-
idx_diff = index.difference(other)
84+
idx_diff = index.difference(other, sort)
8485
tm.assert_index_equal(idx_diff, expected)
8586
tm.assert_attr_equal('freq', idx_diff, expected)
8687

8788
other = period_range("20160922", "20160925", freq="D")
88-
idx_diff = index.difference(other)
89+
idx_diff = index.difference(other, sort)
8990
expected = PeriodIndex(["20160920", "20160921"], freq='D')
9091
tm.assert_index_equal(idx_diff, expected)
9192
tm.assert_attr_equal('freq', idx_diff, expected)

pandas/tests/indexes/period/test_setops.py

+28-14
Original file line numberDiff line numberDiff line change
@@ -203,37 +203,49 @@ def test_intersection_cases(self):
203203
result = rng.intersection(rng[0:0])
204204
assert len(result) == 0
205205

206-
def test_difference(self):
206+
@pytest.mark.parametrize("sort", [True, False])
207+
def test_difference(self, sort):
207208
# diff
208-
rng1 = pd.period_range('1/1/2000', freq='D', periods=5)
209+
period_rng = ['1/3/2000', '1/2/2000', '1/1/2000', '1/5/2000',
210+
'1/4/2000']
211+
rng1 = pd.PeriodIndex(period_rng, freq='D')
209212
other1 = pd.period_range('1/6/2000', freq='D', periods=5)
210-
expected1 = pd.period_range('1/1/2000', freq='D', periods=5)
213+
expected1 = rng1
211214

212-
rng2 = pd.period_range('1/1/2000', freq='D', periods=5)
215+
rng2 = pd.PeriodIndex(period_rng, freq='D')
213216
other2 = pd.period_range('1/4/2000', freq='D', periods=5)
214-
expected2 = pd.period_range('1/1/2000', freq='D', periods=3)
217+
expected2 = pd.PeriodIndex(['1/3/2000', '1/2/2000', '1/1/2000'],
218+
freq='D')
215219

216-
rng3 = pd.period_range('1/1/2000', freq='D', periods=5)
220+
rng3 = pd.PeriodIndex(period_rng, freq='D')
217221
other3 = pd.PeriodIndex([], freq='D')
218-
expected3 = pd.period_range('1/1/2000', freq='D', periods=5)
222+
expected3 = rng3
219223

220-
rng4 = pd.period_range('2000-01-01 09:00', freq='H', periods=5)
224+
period_rng = ['2000-01-01 10:00', '2000-01-01 09:00',
225+
'2000-01-01 12:00', '2000-01-01 11:00',
226+
'2000-01-01 13:00']
227+
rng4 = pd.PeriodIndex(period_rng, freq='H')
221228
other4 = pd.period_range('2000-01-02 09:00', freq='H', periods=5)
222229
expected4 = rng4
223230

224-
rng5 = pd.PeriodIndex(['2000-01-01 09:01', '2000-01-01 09:03',
231+
rng5 = pd.PeriodIndex(['2000-01-01 09:03', '2000-01-01 09:01',
225232
'2000-01-01 09:05'], freq='T')
226233
other5 = pd.PeriodIndex(
227234
['2000-01-01 09:01', '2000-01-01 09:05'], freq='T')
228235
expected5 = pd.PeriodIndex(['2000-01-01 09:03'], freq='T')
229236

230-
rng6 = pd.period_range('2000-01-01', freq='M', periods=7)
237+
period_rng = ['2000-02-01', '2000-01-01', '2000-06-01',
238+
'2000-07-01', '2000-05-01', '2000-03-01',
239+
'2000-04-01']
240+
rng6 = pd.PeriodIndex(period_rng, freq='M')
231241
other6 = pd.period_range('2000-04-01', freq='M', periods=7)
232-
expected6 = pd.period_range('2000-01-01', freq='M', periods=3)
242+
expected6 = pd.PeriodIndex(['2000-02-01', '2000-01-01', '2000-03-01'],
243+
freq='M')
233244

234-
rng7 = pd.period_range('2003-01-01', freq='A', periods=5)
245+
period_rng = ['2003', '2007', '2006', '2005', '2004']
246+
rng7 = pd.PeriodIndex(period_rng, freq='A')
235247
other7 = pd.period_range('1998-01-01', freq='A', periods=8)
236-
expected7 = pd.period_range('2006-01-01', freq='A', periods=2)
248+
expected7 = pd.PeriodIndex(['2007', '2006'], freq='A')
237249

238250
for rng, other, expected in [(rng1, other1, expected1),
239251
(rng2, other2, expected2),
@@ -242,5 +254,7 @@ def test_difference(self):
242254
(rng5, other5, expected5),
243255
(rng6, other6, expected6),
244256
(rng7, other7, expected7), ]:
245-
result_union = rng.difference(other)
257+
result_union = rng.difference(other, sort)
258+
if sort:
259+
expected = expected.sort_values()
246260
tm.assert_index_equal(result_union, expected)

0 commit comments

Comments
 (0)