Skip to content

Fix pd.merge to preserve ExtensionArrays dtypes #20745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,7 @@ def _get_dtype(arr_or_dtype):
return arr_or_dtype
elif isinstance(arr_or_dtype, type):
return np.dtype(arr_or_dtype)
elif isinstance(arr_or_dtype, CategoricalDtype):
elif isinstance(arr_or_dtype, ExtensionDtype):
return arr_or_dtype
elif isinstance(arr_or_dtype, DatetimeTZDtype):
return arr_or_dtype
Expand Down
12 changes: 9 additions & 3 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5541,8 +5541,14 @@ def concatenate_join_units(join_units, concat_axis, copy):
if len(to_concat) == 1:
# Only one block, nothing to concatenate.
concat_values = to_concat[0]
if copy and concat_values.base is not None:
concat_values = concat_values.copy()
if copy:
if isinstance(concat_values, np.ndarray):
# non-reindexed (=not yet copied) arrays are made into a view
# in JoinUnit.get_reindexed_values
if concat_values.base is not None:
concat_values = concat_values.copy()
else:
concat_values = concat_values.copy()
else:
concat_values = _concat._concat_compat(to_concat, axis=concat_axis)

Expand Down Expand Up @@ -5823,7 +5829,7 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
# External code requested filling/upcasting, bool values must
# be upcasted to object to avoid being upcasted to numeric.
values = self.block.astype(np.object_).values
elif self.block.is_categorical:
elif self.block.is_extension:
values = self.block.values
else:
# No dtype upcasting is done here, it will be performed during
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/extension/base/reshaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,24 @@ def test_set_frame_overwrite_object(self, data):
df = pd.DataFrame({"A": [1] * len(data)}, dtype=object)
df['A'] = data
assert df.dtypes['A'] == data.dtype

def test_merge(self, data, na_value):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you prob should test with with the how=join_type fixture.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They both need a different expected result, I don't think it is really worth it here in this case?
(the test is also not meant to be a full cover of the merge function (for that we already have other tests), just to test that basic use cases of concatting works with extension arrays)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, its worth doing way more tests than a single usecase. but ok here I guess.

# GH-20743
df1 = pd.DataFrame({'ext': data[:3], 'int1': [1, 2, 3],
'key': [0, 1, 2]})
df2 = pd.DataFrame({'int2': [1, 2, 3, 4], 'key': [0, 0, 1, 3]})

res = pd.merge(df1, df2)
exp = pd.DataFrame(
{'int1': [1, 1, 2], 'int2': [1, 2, 3], 'key': [0, 0, 1],
'ext': data._constructor_from_sequence(
[data[0], data[0], data[1]])})
self.assert_frame_equal(res, exp[['ext', 'int1', 'key', 'int2']])

res = pd.merge(df1, df2, how='outer')
exp = pd.DataFrame(
{'int1': [1, 1, 2, 3, np.nan], 'int2': [1, 2, 3, np.nan, 4],
'key': [0, 0, 1, 2, 3],
'ext': data._constructor_from_sequence(
[data[0], data[0], data[1], data[2], na_value])})
self.assert_frame_equal(res, exp[['ext', 'int1', 'key', 'int2']])
4 changes: 4 additions & 0 deletions pandas/tests/extension/category/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def test_align(self, data, na_value):
def test_align_frame(self, data, na_value):
pass

@pytest.mark.skip(reason="Unobserved categories preseved in concat.")
def test_merge(self, data, na_value):
pass


class TestGetitem(base.BaseGetitemTests):
@pytest.mark.skip(reason="Backwards compatibility")
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def assert_series_equal(self, left, right, *args, **kwargs):

def assert_frame_equal(self, left, right, *args, **kwargs):
# TODO(EA): select_dtypes
tm.assert_index_equal(
left.columns, right.columns,
exact=kwargs.get('check_column_type', 'equiv'),
check_names=kwargs.get('check_names', True),
check_exact=kwargs.get('check_exact', False),
check_categorical=kwargs.get('check_categorical', True),
obj='{obj}.columns'.format(obj=kwargs.get('obj', 'DataFrame')))

decimals = (left.dtypes == 'decimal').index

for col in decimals:
Expand Down