diff --git a/pandas/_libs/index.pyx b/pandas/_libs/index.pyx index 4185cc2084469..6141e2b78e9f4 100644 --- a/pandas/_libs/index.pyx +++ b/pandas/_libs/index.pyx @@ -12,6 +12,7 @@ cnp.import_array() cimport pandas._libs.util as util +from pandas._libs.tslibs import Period from pandas._libs.tslibs.nattype cimport c_NaT as NaT from pandas._libs.tslibs.c_timestamp cimport _Timestamp @@ -466,6 +467,28 @@ cdef class TimedeltaEngine(DatetimeEngine): cdef class PeriodEngine(Int64Engine): + cdef int64_t _unbox_scalar(self, scalar) except? -1: + if scalar is NaT: + return scalar.value + if isinstance(scalar, Period): + # NB: we assume that we have the correct freq here. + # TODO: potential optimize by checking for _Period? + return scalar.ordinal + raise TypeError(scalar) + + cpdef get_loc(self, object val): + # NB: the caller is responsible for ensuring that we are called + # with either a Period or NaT + cdef: + int64_t conv + + try: + conv = self._unbox_scalar(val) + except TypeError: + raise KeyError(val) + + return Int64Engine.get_loc(self, conv) + cdef _get_index_values(self): return super(PeriodEngine, self).vgetter().view("i8") diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index 0e0eb249562d7..987725bb4b70b 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -468,6 +468,10 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None): if tolerance is not None: tolerance = self._convert_tolerance(tolerance, target) + if self_index is not self: + # convert tolerance to i8 + tolerance = self._maybe_convert_timedelta(tolerance) + return Index.get_indexer(self_index, target, method, limit, tolerance) @Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs) @@ -504,6 +508,7 @@ def get_loc(self, key, method=None, tolerance=None): TypeError If key is listlike or otherwise not hashable. """ + orig_key = key if not is_scalar(key): raise InvalidIndexError(key) @@ -545,20 +550,12 @@ def get_loc(self, key, method=None, tolerance=None): key = Period(key, freq=self.freq) except ValueError: # we cannot construct the Period - raise KeyError(key) + raise KeyError(orig_key) - ordinal = self._data._unbox_scalar(key) try: - return self._engine.get_loc(ordinal) + return Index.get_loc(self, key, method, tolerance) except KeyError: - - try: - if tolerance is not None: - tolerance = self._convert_tolerance(tolerance, np.asarray(key)) - return self._int64index.get_loc(ordinal, method, tolerance) - - except KeyError: - raise KeyError(key) + raise KeyError(orig_key) def _maybe_cast_slice_bound(self, label, side: str, kind: str): """ @@ -625,12 +622,6 @@ def _get_string_slice(self, key: str, use_lhs: bool = True, use_rhs: bool = True except KeyError: raise KeyError(key) - def _convert_tolerance(self, tolerance, target): - tolerance = DatetimeIndexOpsMixin._convert_tolerance(self, tolerance, target) - if target.size != tolerance.size and tolerance.size > 1: - raise ValueError("list-like tolerance size must match target index size") - return self._maybe_convert_timedelta(tolerance) - def insert(self, loc, item): if not isinstance(item, Period) or self.freq != item.freq: return self.astype(object).insert(loc, item)