Skip to content

Commit ea186c0

Browse files
committed
add more tests
1 parent 4f1ee6d commit ea186c0

File tree

10 files changed: 165 additions (+165) and 75 deletions (-75).

pandas/core/indexes/interval.py

+6 -2
Original file line number | Diff line number | Diff line change
@@ -1036,7 +1036,7 @@ def overlaps(self, other):
10361036
return self._data.overlaps(other)
10371037

10381038
def _setop(op_name):
1039-
def func(self, other):
1039+
def func(self, other, sort=True):
10401040
other = self._as_like_interval_index(other)
10411041

10421042
# GH 19016: ensure set op will not return a prohibited dtype
@@ -1047,7 +1047,11 @@ def func(self, other):
10471047
'objects that have compatible dtypes')
10481048
raise TypeError(msg.format(op=op_name))
10491049

1050-
result = getattr(self._multiindex, op_name)(other._multiindex)
1050+
if op_name == 'difference':
1051+
result = getattr(self._multiindex, op_name)(other._multiindex,
1052+
sort)
1053+
else:
1054+
result = getattr(self._multiindex, op_name)(other._multiindex)
10511055
result_name = self.name if self.name == other.name else None
10521056

10531057
# GH 19101: ensure empty results have correct dtype

pandas/core/indexes/multi.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -2796,8 +2796,14 @@ def difference(self, other, sort=True):
27962796
labels=[[]] * self.nlevels,
27972797
names=result_names, verify_integrity=False)
27982798

2799-
difference = set(self._ndarray_values) - set(other._ndarray_values)
2799+
this = self._get_unique_index()
28002800

2801+
indexer = this.get_indexer(other)
2802+
indexer = indexer.take((indexer != -1).nonzero()[0])
2803+
2804+
label_diff = np.setdiff1d(np.arange(this.size), indexer,
2805+
assume_unique=True)
2806+
difference = this.values.take(label_diff)
28012807
if sort:
28022808
difference = sorted(difference)
28032809

pandas/tests/indexes/common.py

+6 -5
Original file line number | Diff line number | Diff line change
@@ -668,12 +668,13 @@ def test_union_base(self):
668668
with tm.assert_raises_regex(TypeError, msg):
669669
result = first.union([1, 2, 3])
670670

671-
def test_difference_base(self):
671+
@pytest.mark.parametrize("sort", [True, False])
672+
def test_difference_base(self, sort):
672673
for name, idx in compat.iteritems(self.indices):
673674
first = idx[2:]
674675
second = idx[:4]
675676
answer = idx[4:]
676-
result = first.difference(second)
677+
result = first.difference(second, sort)
677678

678679
if isinstance(idx, CategoricalIndex):
679680
pass
@@ -687,21 +688,21 @@ def test_difference_base(self):
687688
if isinstance(idx, PeriodIndex):
688689
msg = "can only call with other PeriodIndex-ed objects"
689690
with tm.assert_raises_regex(ValueError, msg):
690-
result = first.difference(case)
691+
result = first.difference(case, sort)
691692
elif isinstance(idx, CategoricalIndex):
692693
pass
693694
elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)):
694695
assert result.__class__ == answer.__class__
695696
tm.assert_numpy_array_equal(result.sort_values().asi8,
696697
answer.sort_values().asi8)
697698
else:
698-
result = first.difference(case)
699+
result = first.difference(case, sort)
699700
assert tm.equalContents(result, answer)
700701

701702
if isinstance(idx, MultiIndex):
702703
msg = "other must be a MultiIndex or a list of tuples"
703704
with tm.assert_raises_regex(TypeError, msg):
704-
result = first.difference([1, 2, 3])
705+
result = first.difference([1, 2, 3], sort)
705706

706707
def test_symmetric_difference(self):
707708
for name, idx in compat.iteritems(self.indices):

pandas/tests/indexes/datetimes/test_setops.py

+21 -13
Original file line number | Diff line number | Diff line change
@@ -209,47 +209,55 @@ def test_intersection_bug_1708(self):
209209
assert len(result) == 0
210210

211211
@pytest.mark.parametrize("tz", tz)
212-
def test_difference(self, tz):
213-
rng1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
212+
@pytest.mark.parametrize("sort", [True, False])
213+
def test_difference(self, tz, sort):
214+
rng_dates = ['1/2/2000', '1/3/2000', '1/1/2000', '1/4/2000',
215+
'1/5/2000']
216+
217+
rng1 = pd.DatetimeIndex(rng_dates, tz=tz)
214218
other1 = pd.date_range('1/6/2000', freq='D', periods=5, tz=tz)
215-
expected1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
219+
expected1 = pd.DatetimeIndex(rng_dates, tz=tz)
216220

217-
rng2 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
221+
rng2 = pd.DatetimeIndex(rng_dates, tz=tz)
218222
other2 = pd.date_range('1/4/2000', freq='D', periods=5, tz=tz)
219-
expected2 = pd.date_range('1/1/2000', freq='D', periods=3, tz=tz)
223+
expected2 = pd.DatetimeIndex(rng_dates[:3], tz=tz)
220224

221-
rng3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
225+
rng3 = pd.DatetimeIndex(rng_dates, tz=tz)
222226
other3 = pd.DatetimeIndex([], tz=tz)
223-
expected3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
227+
expected3 = pd.DatetimeIndex(rng_dates, tz=tz)
224228

225229
for rng, other, expected in [(rng1, other1, expected1),
226230
(rng2, other2, expected2),
227231
(rng3, other3, expected3)]:
228-
result_diff = rng.difference(other)
232+
result_diff = rng.difference(other, sort)
233+
if sort:
234+
expected = expected.sort_values()
229235
tm.assert_index_equal(result_diff, expected)
230236

231-
def test_difference_freq(self):
237+
@pytest.mark.parametrize("sort", [True, False])
238+
def test_difference_freq(self, sort):
232239
# GH14323: difference of DatetimeIndex should not preserve frequency
233240

234241
index = date_range("20160920", "20160925", freq="D")
235242
other = date_range("20160921", "20160924", freq="D")
236243
expected = DatetimeIndex(["20160920", "20160925"], freq=None)
237-
idx_diff = index.difference(other)
244+
idx_diff = index.difference(other, sort)
238245
tm.assert_index_equal(idx_diff, expected)
239246
tm.assert_attr_equal('freq', idx_diff, expected)
240247

241248
other = date_range("20160922", "20160925", freq="D")
242-
idx_diff = index.difference(other)
249+
idx_diff = index.difference(other, sort)
243250
expected = DatetimeIndex(["20160920", "20160921"], freq=None)
244251
tm.assert_index_equal(idx_diff, expected)
245252
tm.assert_attr_equal('freq', idx_diff, expected)
246253

247-
def test_datetimeindex_diff(self):
254+
@pytest.mark.parametrize("sort", [True, False])
255+
def test_datetimeindex_diff(self, sort):
248256
dti1 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
249257
periods=100)
250258
dti2 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
251259
periods=98)
252-
assert len(dti1.difference(dti2)) == 2
260+
assert len(dti1.difference(dti2, sort)) == 2
253261

254262
def test_datetimeindex_union_join_empty(self):
255263
dti = DatetimeIndex(start='1/1/2001', end='2/1/2001', freq='D')

pandas/tests/indexes/interval/test_interval.py

+12 -5
Original file line number | Diff line number | Diff line change
@@ -801,19 +801,26 @@ def test_intersection(self, closed):
801801
result = index.intersection(other)
802802
tm.assert_index_equal(result, expected)
803803

804-
def test_difference(self, closed):
805-
index = self.create_index(closed=closed)
806-
tm.assert_index_equal(index.difference(index[:1]), index[1:])
804+
@pytest.mark.parametrize("sort", [True, False])
805+
def test_difference(self, closed, sort):
806+
index = IntervalIndex.from_arrays([1, 0, 3, 2],
807+
[1, 2, 3, 4],
808+
closed=closed)
809+
result = index.difference(index[:1], sort)
810+
expected = index[1:]
811+
if sort:
812+
expected = expected.sort_values()
813+
tm.assert_index_equal(result, expected)
807814

808815
# GH 19101: empty result, same dtype
809-
result = index.difference(index)
816+
result = index.difference(index, sort)
810817
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
811818
tm.assert_index_equal(result, expected)
812819

813820
# GH 19101: empty result, different dtypes
814821
other = IntervalIndex.from_arrays(index.left.astype('float64'),
815822
index.right, closed=closed)
816-
result = index.difference(other)
823+
result = index.difference(other, sort)
817824
tm.assert_index_equal(result, expected)
818825

819826
def test_symmetric_difference(self, closed):

pandas/tests/indexes/multi/test_set_ops.py

+23 -15
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import pandas as pd
66
import pandas.util.testing as tm
77
from pandas import MultiIndex, Series
8+
import pytest
89

910

1011
def test_setops_errorcases(idx):
@@ -59,24 +60,25 @@ def test_union_base(idx):
5960
result = first.union([1, 2, 3])
6061

6162

62-
def test_difference_base(idx):
63+
@pytest.mark.parametrize("sort", [True, False])
64+
def test_difference_base(idx, sort):
6365
first = idx[2:]
6466
second = idx[:4]
6567
answer = idx[4:]
66-
result = first.difference(second)
68+
result = first.difference(second, sort)
6769

6870
assert tm.equalContents(result, answer)
6971

7072
# GH 10149
7173
cases = [klass(second.values)
7274
for klass in [np.array, Series, list]]
7375
for case in cases:
74-
result = first.difference(case)
76+
result = first.difference(case, sort)
7577
assert tm.equalContents(result, answer)
7678

7779
msg = "other must be a MultiIndex or a list of tuples"
7880
with tm.assert_raises_regex(TypeError, msg):
79-
result = first.difference([1, 2, 3])
81+
result = first.difference([1, 2, 3], sort)
8082

8183

8284
def test_symmetric_difference(idx):
@@ -104,11 +106,17 @@ def test_empty(idx):
104106
assert idx[:0].empty
105107

106108

107-
def test_difference(idx):
109+
@pytest.mark.parametrize("sort", [True, False])
110+
def test_difference(idx, sort):
108111

109112
first = idx
110-
result = first.difference(idx[-3:])
111-
expected = MultiIndex.from_tuples(sorted(idx[:-3].values),
113+
result = first.difference(idx[-3:], sort)
114+
vals = idx[:-3].values
115+
116+
if sort:
117+
vals = sorted(vals)
118+
119+
expected = MultiIndex.from_tuples(vals,
112120
sortorder=0,
113121
names=idx.names)
114122

@@ -117,44 +125,44 @@ def test_difference(idx):
117125
assert result.names == idx.names
118126

119127
# empty difference: reflexive
120-
result = idx.difference(idx)
128+
result = idx.difference(idx, sort)
121129
expected = idx[:0]
122130
assert result.equals(expected)
123131
assert result.names == idx.names
124132

125133
# empty difference: superset
126-
result = idx[-3:].difference(idx)
134+
result = idx[-3:].difference(idx, sort)
127135
expected = idx[:0]
128136
assert result.equals(expected)
129137
assert result.names == idx.names
130138

131139
# empty difference: degenerate
132-
result = idx[:0].difference(idx)
140+
result = idx[:0].difference(idx, sort)
133141
expected = idx[:0]
134142
assert result.equals(expected)
135143
assert result.names == idx.names
136144

137145
# names not the same
138146
chunklet = idx[-3:]
139147
chunklet.names = ['foo', 'baz']
140-
result = first.difference(chunklet)
148+
result = first.difference(chunklet, sort)
141149
assert result.names == (None, None)
142150

143151
# empty, but non-equal
144-
result = idx.difference(idx.sortlevel(1)[0])
152+
result = idx.difference(idx.sortlevel(1)[0], sort)
145153
assert len(result) == 0
146154

147155
# raise Exception called with non-MultiIndex
148-
result = first.difference(first.values)
156+
result = first.difference(first.values, sort)
149157
assert result.equals(first[:0])
150158

151159
# name from empty array
152-
result = first.difference([])
160+
result = first.difference([], sort)
153161
assert first.equals(result)
154162
assert first.names == result.names
155163

156164
# name from non-empty array
157-
result = first.difference([('foo', 'one')])
165+
result = first.difference([('foo', 'one')], sort)
158166
expected = pd.MultiIndex.from_tuples([('bar', 'one'), ('baz', 'two'), (
159167
'foo', 'two'), ('qux', 'one'), ('qux', 'two')])
160168
expected.names = first.names

pandas/tests/indexes/period/test_period.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -72,20 +72,21 @@ def test_no_millisecond_field(self):
7272
with pytest.raises(AttributeError):
7373
DatetimeIndex([]).millisecond
7474

75-
def test_difference_freq(self):
75+
@pytest.mark.parametrize("sort", [True, False])
76+
def test_difference_freq(self, sort):
7677
# GH14323: difference of Period MUST preserve frequency
7778
# but the ability to union results must be preserved
7879

7980
index = period_range("20160920", "20160925", freq="D")
8081

8182
other = period_range("20160921", "20160924", freq="D")
8283
expected = PeriodIndex(["20160920", "20160925"], freq='D')
83-
idx_diff = index.difference(other)
84+
idx_diff = index.difference(other, sort)
8485
tm.assert_index_equal(idx_diff, expected)
8586
tm.assert_attr_equal('freq', idx_diff, expected)
8687

8788
other = period_range("20160922", "20160925", freq="D")
88-
idx_diff = index.difference(other)
89+
idx_diff = index.difference(other, sort)
8990
expected = PeriodIndex(["20160920", "20160921"], freq='D')
9091
tm.assert_index_equal(idx_diff, expected)
9192
tm.assert_attr_equal('freq', idx_diff, expected)

pandas/tests/indexes/period/test_setops.py

+28 -14
Original file line number | Diff line number | Diff line change
@@ -203,37 +203,49 @@ def test_intersection_cases(self):
203203
result = rng.intersection(rng[0:0])
204204
assert len(result) == 0
205205

206-
def test_difference(self):
206+
@pytest.mark.parametrize("sort", [True, False])
207+
def test_difference(self, sort):
207208
# diff
208-
rng1 = pd.period_range('1/1/2000', freq='D', periods=5)
209+
period_rng = ['1/3/2000', '1/2/2000', '1/1/2000', '1/5/2000',
210+
'1/4/2000']
211+
rng1 = pd.PeriodIndex(period_rng, freq='D')
209212
other1 = pd.period_range('1/6/2000', freq='D', periods=5)
210-
expected1 = pd.period_range('1/1/2000', freq='D', periods=5)
213+
expected1 = rng1
211214

212-
rng2 = pd.period_range('1/1/2000', freq='D', periods=5)
215+
rng2 = pd.PeriodIndex(period_rng, freq='D')
213216
other2 = pd.period_range('1/4/2000', freq='D', periods=5)
214-
expected2 = pd.period_range('1/1/2000', freq='D', periods=3)
217+
expected2 = pd.PeriodIndex(['1/3/2000', '1/2/2000', '1/1/2000'],
218+
freq='D')
215219

216-
rng3 = pd.period_range('1/1/2000', freq='D', periods=5)
220+
rng3 = pd.PeriodIndex(period_rng, freq='D')
217221
other3 = pd.PeriodIndex([], freq='D')
218-
expected3 = pd.period_range('1/1/2000', freq='D', periods=5)
222+
expected3 = rng3
219223

220-
rng4 = pd.period_range('2000-01-01 09:00', freq='H', periods=5)
224+
period_rng = ['2000-01-01 10:00', '2000-01-01 09:00',
225+
'2000-01-01 12:00', '2000-01-01 11:00',
226+
'2000-01-01 13:00']
227+
rng4 = pd.PeriodIndex(period_rng, freq='H')
221228
other4 = pd.period_range('2000-01-02 09:00', freq='H', periods=5)
222229
expected4 = rng4
223230

224-
rng5 = pd.PeriodIndex(['2000-01-01 09:01', '2000-01-01 09:03',
231+
rng5 = pd.PeriodIndex(['2000-01-01 09:03', '2000-01-01 09:01',
225232
'2000-01-01 09:05'], freq='T')
226233
other5 = pd.PeriodIndex(
227234
['2000-01-01 09:01', '2000-01-01 09:05'], freq='T')
228235
expected5 = pd.PeriodIndex(['2000-01-01 09:03'], freq='T')
229236

230-
rng6 = pd.period_range('2000-01-01', freq='M', periods=7)
237+
period_rng = ['2000-02-01', '2000-01-01', '2000-06-01',
238+
'2000-07-01', '2000-05-01', '2000-03-01',
239+
'2000-04-01']
240+
rng6 = pd.PeriodIndex(period_rng, freq='M')
231241
other6 = pd.period_range('2000-04-01', freq='M', periods=7)
232-
expected6 = pd.period_range('2000-01-01', freq='M', periods=3)
242+
expected6 = pd.PeriodIndex(['2000-02-01', '2000-01-01', '2000-03-01'],
243+
freq='M')
233244

234-
rng7 = pd.period_range('2003-01-01', freq='A', periods=5)
245+
period_rng = ['2003', '2007', '2006', '2005', '2004']
246+
rng7 = pd.PeriodIndex(period_rng, freq='A')
235247
other7 = pd.period_range('1998-01-01', freq='A', periods=8)
236-
expected7 = pd.period_range('2006-01-01', freq='A', periods=2)
248+
expected7 = pd.PeriodIndex(['2007', '2006'], freq='A')
237249

238250
for rng, other, expected in [(rng1, other1, expected1),
239251
(rng2, other2, expected2),
@@ -242,5 +254,7 @@ def test_difference(self):
242254
(rng5, other5, expected5),
243255
(rng6, other6, expected6),
244256
(rng7, other7, expected7), ]:
245-
result_union = rng.difference(other)
257+
result_union = rng.difference(other, sort)
258+
if sort:
259+
expected = expected.sort_values()
246260
tm.assert_index_equal(result_union, expected)

0 commit comments

Comments
 (0)