Skip to content

BUG: .replace coerces incorrect dtype #15741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.20.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,7 @@ Bug Fixes


- Bug in the display of ``.info()`` where a qualifier (+) would always be displayed with a ``MultiIndex`` that contains only non-strings (:issue:`15245`)
- Bug in ``.replace()`` where the result could have an incorrect dtype (:issue:`12747`)

- Bug in ``.asfreq()``, where frequency was not set for empty ``Series`` (:issue:`14320`)

Expand Down
20 changes: 17 additions & 3 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,8 +1894,11 @@ def convert(self, *args, **kwargs):
blocks.append(newb)

else:
values = fn(
self.values.ravel(), **fn_kwargs).reshape(self.values.shape)
values = fn(self.values.ravel(), **fn_kwargs)
try:
values = values.reshape(self.values.shape)
except NotImplementedError:
pass
blocks.append(make_block(values, ndim=self.ndim,
placement=self.mgr_locs))

Expand Down Expand Up @@ -3238,6 +3241,16 @@ def comp(s):
return _possibly_compare(values, getattr(s, 'asm8', s),
operator.eq)

def _cast_scalar(block, scalar):
    """
    Pair ``block`` with the value to write for ``scalar``, upcasting the
    block when its dtype differs from the dtype inferred for the scalar.

    Parameters
    ----------
    block : object exposing ``dtype`` and ``astype``
    scalar : the replacement value

    Returns
    -------
    tuple of (block, value)
        When the dtypes already match, the block is returned unchanged
        together with the value produced by ``_infer_dtype_from_scalar``.
        Otherwise the block is cast to the common dtype of the two and
        the original ``scalar`` is returned as the value.
    """
    inferred, value = _infer_dtype_from_scalar(scalar, pandas_dtype=True)
    if is_dtype_equal(block.dtype, inferred):
        return block, value

    common = _find_common_type([block.dtype, inferred])
    # after the upcast, hand back the caller's scalar rather than the
    # value returned by dtype inference
    return block.astype(common), scalar

masks = [comp(s) for i, s in enumerate(src_list)]

result_blocks = []
Expand All @@ -3260,7 +3273,8 @@ def comp(s):
# particular block
m = masks[i][b.mgr_locs.indexer]
if m.any():
new_rb.extend(b.putmask(m, d, inplace=True))
b, val = _cast_scalar(b, d)
new_rb.extend(b.putmask(m, val, inplace=True))
else:
new_rb.append(b)
rb = new_rb
Expand Down
48 changes: 38 additions & 10 deletions pandas/tests/indexing/test_coercion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,12 +1153,27 @@ def setUp(self):
self.rep['float64'] = [1.1, 2.2]
self.rep['complex128'] = [1 + 1j, 2 + 2j]
self.rep['bool'] = [True, False]
self.rep['datetime64[ns]'] = [pd.Timestamp('2011-01-01'),
pd.Timestamp('2011-01-03')]

for tz in ['UTC', 'US/Eastern']:
# to test tz => different tz replacement
key = 'datetime64[ns, {0}]'.format(tz)
self.rep[key] = [pd.Timestamp('2011-01-01', tz=tz),
pd.Timestamp('2011-01-03', tz=tz)]

self.rep['timedelta64[ns]'] = [pd.Timedelta('1 day'),
pd.Timedelta('2 day')]

def _assert_replace_conversion(self, from_key, to_key, how):
index = pd.Index([3, 4], name='xxx')
obj = pd.Series(self.rep[from_key], index=index, name='yyy')
self.assertEqual(obj.dtype, from_key)

if (from_key.startswith('datetime') and to_key.startswith('datetime')):
# different tz, currently mask_missing raises SystemError
return

if how == 'dict':
replacer = dict(zip(self.rep[from_key], self.rep[to_key]))
elif how == 'series':
Expand All @@ -1175,17 +1190,12 @@ def _assert_replace_conversion(self, from_key, to_key, how):
pytest.skip("windows platform buggy: {0} -> {1}".format
(from_key, to_key))

if ((from_key == 'float64' and
to_key in ('bool', 'int64')) or

if ((from_key == 'float64' and to_key in ('bool', 'int64')) or
(from_key == 'complex128' and
to_key in ('bool', 'int64', 'float64')) or

(from_key == 'int64' and
to_key in ('bool')) or

# TODO_GH12747 The result must be int?
(from_key == 'bool' and to_key == 'int64')):
(from_key == 'int64' and to_key in ('bool'))):

# buggy on 32-bit
if tm.is_platform_32bit():
Expand Down Expand Up @@ -1248,13 +1258,31 @@ def test_replace_series_bool(self):
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_datetime64(self):
pass
from_key = 'datetime64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='dict')

from_key = 'datetime64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_datetime64tz(self):
pass
from_key = 'datetime64[ns, US/Eastern]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='dict')

from_key = 'datetime64[ns, US/Eastern]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_timedelta64(self):
pass
from_key = 'timedelta64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='dict')

from_key = 'timedelta64[ns]'
for to_key in self.rep:
self._assert_replace_conversion(from_key, to_key, how='series')

def test_replace_series_period(self):
pass
4 changes: 2 additions & 2 deletions pandas/tests/series/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def check_replace(to_rep, val, expected):
tm.assert_series_equal(expected, r)
tm.assert_series_equal(expected, sc)

# should NOT upcast to float
e = pd.Series([0, 1, 2, 3, 4])
# MUST upcast to float
e = pd.Series([0., 1., 2., 3., 4.])
tr, v = [3], [3.0]
check_replace(tr, v, e)

Expand Down
37 changes: 29 additions & 8 deletions pandas/types/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
_ensure_int32, _ensure_int64,
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
_POSSIBLY_CAST_DTYPES)
from .dtypes import ExtensionDtype
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
from .generic import ABCDatetimeIndex, ABCPeriodIndex, ABCSeries
from .missing import isnull, notnull
from .inference import is_list_like
Expand Down Expand Up @@ -312,8 +312,17 @@ def _maybe_promote(dtype, fill_value=np.nan):
return dtype, fill_value


def _infer_dtype_from_scalar(val):
""" interpret the dtype from a scalar """
def _infer_dtype_from_scalar(val, pandas_dtype=False):
"""
interpret the dtype from a scalar

Parameters
----------
pandas_dtype : bool, default False
Whether to also infer pandas extension dtypes.
If False, a scalar that belongs to a pandas extension type is inferred
as object.
"""

dtype = np.object_

Expand All @@ -336,13 +345,20 @@ def _infer_dtype_from_scalar(val):

dtype = np.object_

elif isinstance(val, (np.datetime64,
datetime)) and getattr(val, 'tzinfo', None) is None:
val = lib.Timestamp(val).value
dtype = np.dtype('M8[ns]')
elif isinstance(val, (np.datetime64, datetime)):
val = tslib.Timestamp(val)
if val is tslib.NaT or val.tz is None:
dtype = np.dtype('M8[ns]')
else:
if pandas_dtype:
dtype = DatetimeTZDtype(unit='ns', tz=val.tz)
else:
# return datetimetz as object
return np.object_, val
val = val.value

elif isinstance(val, (np.timedelta64, timedelta)):
val = lib.Timedelta(val).value
val = tslib.Timedelta(val).value
dtype = np.dtype('m8[ns]')

elif is_bool(val):
Expand All @@ -363,6 +379,11 @@ def _infer_dtype_from_scalar(val):
elif is_complex(val):
dtype = np.complex_

elif pandas_dtype:
if lib.is_period(val):
dtype = PeriodDtype(freq=val.freq)
val = val.ordinal

return dtype, val


Expand Down