Skip to content

Commit 379f9c0

Browse files
authored
BUG: DataFrame.corrwith raising for pyarrow-backed dtypes (#52314)
* BUG: DataFrame.corrwith raising for pyarrow-backed dtypes
* whatsnew
1 parent 0428536 commit 379f9c0

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ Timezones
181181
Numeric
182182
^^^^^^^
183183
- Bug in :meth:`Series.corr` and :meth:`Series.cov` raising ``AttributeError`` for masked dtypes (:issue:`51422`)
184+
- Bug in :meth:`DataFrame.corrwith` raising ``NotImplementedError`` for pyarrow-backed dtypes (:issue:`52314`)
184185
-
185186

186187
Conversion

pandas/core/internals/ops.py

+8 -1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
NamedTuple,
77
)
88

9+
from pandas.core.dtypes.common import is_1d_only_ea_dtype
10+
911
if TYPE_CHECKING:
1012
from pandas._libs.internals import BlockPlacement
1113
from pandas._typing import ArrayLike
@@ -60,7 +62,12 @@ def operate_blockwise(
6062
res_blks: list[Block] = []
6163
for lvals, rvals, locs, left_ea, right_ea, rblk in _iter_block_pairs(left, right):
6264
res_values = array_op(lvals, rvals)
63-
if left_ea and not right_ea and hasattr(res_values, "reshape"):
65+
if (
66+
left_ea
67+
and not right_ea
68+
and hasattr(res_values, "reshape")
69+
and not is_1d_only_ea_dtype(res_values.dtype)
70+
):
6471
res_values = res_values.reshape(1, -1)
6572
nbs = rblk._split_op_result(res_values)
6673

pandas/tests/frame/methods/test_cov_corr.py

+11 -1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,17 @@ def test_corr_numeric_only(self, meth, numeric_only):
274274

275275

276276
class TestDataFrameCorrWith:
277-
def test_corrwith(self, datetime_frame):
277+
@pytest.mark.parametrize(
278+
"dtype",
279+
[
280+
"float64",
281+
"Float64",
282+
pytest.param("float64[pyarrow]", marks=td.skip_if_no("pyarrow")),
283+
],
284+
)
285+
def test_corrwith(self, datetime_frame, dtype):
286+
datetime_frame = datetime_frame.astype(dtype)
287+
278288
a = datetime_frame
279289
noise = Series(np.random.randn(len(a)), index=a.index)
280290

0 commit comments

Comments
 (0)