From 589d30af8a0c8e32a978f8bf6cdb0ff1b68789a8 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Fri, 20 Jun 2014 05:07:03 +0900 Subject: [PATCH] BUG: DatetimeIndex comparison handles NaT incorrectly --- doc/source/v0.14.1.txt | 1 + pandas/tseries/index.py | 33 +++++++--- pandas/tseries/period.py | 15 ++--- pandas/tseries/tests/test_timeseries.py | 87 +++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 20 deletions(-) diff --git a/doc/source/v0.14.1.txt b/doc/source/v0.14.1.txt index 03caf47dc7127..197bc9bae5c9d 100644 --- a/doc/source/v0.14.1.txt +++ b/doc/source/v0.14.1.txt @@ -237,6 +237,7 @@ Bug Fixes - Bug in when writing Stata files where the encoding was ignored (:issue:`7286`) +- Bug in ``DatetimeIndex`` comparison doesn't handle ``NaT`` properly (:issue:`7529`) - Bug in passing input with ``tzinfo`` to some offsets ``apply``, ``rollforward`` or ``rollback`` resets ``tzinfo`` or raises ``ValueError`` (:issue:`7465`) diff --git a/pandas/tseries/index.py b/pandas/tseries/index.py index 16468f24a0ee1..50296a417479e 100644 --- a/pandas/tseries/index.py +++ b/pandas/tseries/index.py @@ -74,22 +74,35 @@ def wrapper(left, right): return wrapper -def _dt_index_cmp(opname): +def _dt_index_cmp(opname, nat_result=False): """ Wrap comparison operations to convert datetime-like to datetime64 """ def wrapper(self, other): func = getattr(super(DatetimeIndex, self), opname) - if isinstance(other, datetime): + if isinstance(other, datetime) or isinstance(other, compat.string_types): other = _to_m8(other, tz=self.tz) - elif isinstance(other, list): - other = DatetimeIndex(other) - elif isinstance(other, compat.string_types): - other = _to_m8(other, tz=self.tz) - elif not isinstance(other, (np.ndarray, ABCSeries)): - other = _ensure_datetime64(other) - result = func(other) + result = func(other) + if com.isnull(other): + result.fill(nat_result) + else: + if isinstance(other, list): + other = DatetimeIndex(other) + elif not isinstance(other, (np.ndarray, ABCSeries)): + other = _ensure_datetime64(other) + result = func(other) + if isinstance(other, Index): + o_mask = other.values.view('i8') == tslib.iNaT + else: + o_mask = other.view('i8') == tslib.iNaT + + if o_mask.any(): + result[o_mask] = nat_result + + mask = self.asi8 == tslib.iNaT + if mask.any(): + result[mask] = nat_result return result.view(np.ndarray) return wrapper @@ -142,7 +155,7 @@ class DatetimeIndex(DatetimeIndexOpsMixin, Int64Index): _arrmap = None __eq__ = _dt_index_cmp('__eq__') - __ne__ = _dt_index_cmp('__ne__') + __ne__ = _dt_index_cmp('__ne__', nat_result=True) __lt__ = _dt_index_cmp('__lt__') __gt__ = _dt_index_cmp('__gt__') __le__ = _dt_index_cmp('__le__') diff --git a/pandas/tseries/period.py b/pandas/tseries/period.py index c44c3c9272f6a..b3a29ab4110d7 100644 --- a/pandas/tseries/period.py +++ b/pandas/tseries/period.py @@ -498,16 +498,11 @@ def dt64arr_to_periodarr(data, freq, tz): # --- Period index sketch -def _period_index_cmp(opname): +def _period_index_cmp(opname, nat_result=False): """ Wrap comparison operations to convert datetime-like to datetime64 """ def wrapper(self, other): - if opname == '__ne__': - fill_value = True - else: - fill_value = False - if isinstance(other, Period): func = getattr(self.values, opname) if other.freq != self.freq: @@ -523,7 +518,7 @@ def wrapper(self, other): mask = (com.mask_missing(self.values, tslib.iNaT) | com.mask_missing(other.values, tslib.iNaT)) if mask.any(): - result[mask] = fill_value + result[mask] = nat_result return result else: @@ -532,10 +527,10 @@ def wrapper(self, other): result = func(other.ordinal) if other.ordinal == tslib.iNaT: - result.fill(fill_value) + result.fill(nat_result) mask = self.values == tslib.iNaT if mask.any(): - result[mask] = fill_value + result[mask] = nat_result return result return wrapper @@ -595,7 +590,7 @@ class PeriodIndex(DatetimeIndexOpsMixin, Int64Index): _allow_period_index_ops = True __eq__ = _period_index_cmp('__eq__') - __ne__ = _period_index_cmp('__ne__') + __ne__ = _period_index_cmp('__ne__', nat_result=True) __lt__ = _period_index_cmp('__lt__') __gt__ = _period_index_cmp('__gt__') __le__ = _period_index_cmp('__le__') diff --git a/pandas/tseries/tests/test_timeseries.py b/pandas/tseries/tests/test_timeseries.py index d32efe0d777f7..d2cfdff2b003d 100644 --- a/pandas/tseries/tests/test_timeseries.py +++ b/pandas/tseries/tests/test_timeseries.py @@ -2179,6 +2179,93 @@ def test_comparisons_coverage(self): exp = rng == rng self.assert_numpy_array_equal(result, exp) + def test_comparisons_nat(self): + fidx1 = pd.Index([1.0, np.nan, 3.0, np.nan, 5.0, 7.0]) + fidx2 = pd.Index([2.0, 3.0, np.nan, np.nan, 6.0, 7.0]) + + didx1 = pd.DatetimeIndex(['2014-01-01', pd.NaT, '2014-03-01', pd.NaT, + '2014-05-01', '2014-07-01']) + didx2 = pd.DatetimeIndex(['2014-02-01', '2014-03-01', pd.NaT, pd.NaT, + '2014-06-01', '2014-07-01']) + darr = np.array([np.datetime64('2014-02-01 00:00Z'), + np.datetime64('2014-03-01 00:00Z'), + np.datetime64('nat'), np.datetime64('nat'), + np.datetime64('2014-06-01 00:00Z'), + np.datetime64('2014-07-01 00:00Z')]) + + if _np_version_under1p7: + # cannot test array because np.datetime('nat') returns today's date + cases = [(fidx1, fidx2), (didx1, didx2)] + else: + cases = [(fidx1, fidx2), (didx1, didx2), (didx1, darr)] + + # Check pd.NaT is handles as the same as np.nan + for idx1, idx2 in cases: + result = idx1 < idx2 + expected = np.array([True, False, False, False, True, False]) + self.assert_numpy_array_equal(result, expected) + result = idx2 > idx1 + expected = np.array([True, False, False, False, True, False]) + self.assert_numpy_array_equal(result, expected) + + result = idx1 <= idx2 + expected = np.array([True, False, False, False, True, True]) + self.assert_numpy_array_equal(result, expected) + result = idx2 >= idx1 + expected = np.array([True, False, False, False, True, True]) + self.assert_numpy_array_equal(result, expected) + + result = idx1 == idx2 + expected = np.array([False, False, False, False, False, True]) + self.assert_numpy_array_equal(result, expected) + + result = idx1 != idx2 + expected = np.array([True, True, True, True, True, False]) + self.assert_numpy_array_equal(result, expected) + + for idx1, val in [(fidx1, np.nan), (didx1, pd.NaT)]: + result = idx1 < val + expected = np.array([False, False, False, False, False, False]) + self.assert_numpy_array_equal(result, expected) + result = idx1 > val + self.assert_numpy_array_equal(result, expected) + + result = idx1 <= val + self.assert_numpy_array_equal(result, expected) + result = idx1 >= val + self.assert_numpy_array_equal(result, expected) + + result = idx1 == val + self.assert_numpy_array_equal(result, expected) + + result = idx1 != val + expected = np.array([True, True, True, True, True, True]) + self.assert_numpy_array_equal(result, expected) + + # Check pd.NaT is handles as the same as np.nan + for idx1, val in [(fidx1, 3), (didx1, datetime(2014, 3, 1))]: + result = idx1 < val + expected = np.array([True, False, False, False, False, False]) + self.assert_numpy_array_equal(result, expected) + result = idx1 > val + expected = np.array([False, False, False, False, True, True]) + self.assert_numpy_array_equal(result, expected) + + result = idx1 <= val + expected = np.array([True, False, True, False, False, False]) + self.assert_numpy_array_equal(result, expected) + result = idx1 >= val + expected = np.array([False, False, True, False, True, True]) + self.assert_numpy_array_equal(result, expected) + + result = idx1 == val + expected = np.array([False, False, True, False, False, False]) + self.assert_numpy_array_equal(result, expected) + + result = idx1 != val + expected = np.array([True, True, False, True, True, True]) + self.assert_numpy_array_equal(result, expected) + def test_map(self): rng = date_range('1/1/2000', periods=10)