Skip to content

Commit d1c43dc

Browse files
authored
REF/TYP: avoid need for NDFrameTb (#51969)
1 parent a5a3300 commit d1c43dc

File tree

3 files changed

+35
-56
lines changed

3 files changed

+35
-56
lines changed

pandas/_typing.py

-3
Original file line number | Diff line number | Diff line change
@@ -131,9 +131,6 @@
131131
# Series is passed into a function, a Series is always returned and if a DataFrame is
132132
# passed in, a DataFrame is always returned.
133133
NDFrameT = TypeVar("NDFrameT", bound="NDFrame")
134-
# same as NDFrameT, needed when binding two pairs of parameters to potentially
135-
# separate NDFrame-subclasses (see NDFrame.align)
136-
NDFrameTb = TypeVar("NDFrameTb", bound="NDFrame")
137134

138135
NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index")
139136

pandas/core/generic.py

+33 -53
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,6 @@
6868
Manager,
6969
NaPosition,
7070
NDFrameT,
71-
NDFrameTb,
7271
RandomState,
7372
Renamer,
7473
Scalar,
@@ -9445,7 +9444,7 @@ def align(
94459444
{c: self for c in other.columns}, **other._construct_axes_dict()
94469445
)
94479446
# error: Incompatible return value type (got "Tuple[DataFrame,
9448-
# DataFrame]", expected "Tuple[NDFrameT, NDFrameTb]")
9447+
# DataFrame]", expected "Tuple[Self, NDFrameT]")
94499448
return df._align_frame( # type: ignore[return-value]
94509449
other, # type: ignore[arg-type]
94519450
join=join,
@@ -9456,7 +9455,7 @@ def align(
94569455
method=method,
94579456
limit=limit,
94589457
fill_axis=fill_axis,
9459-
)
9458+
)[:2]
94609459
elif isinstance(other, ABCSeries):
94619460
# this means self is a DataFrame, and we need to broadcast
94629461
# other
@@ -9465,7 +9464,7 @@ def align(
94659464
{c: other for c in self.columns}, **self._construct_axes_dict()
94669465
)
94679466
# error: Incompatible return value type (got "Tuple[NDFrameT,
9468-
# DataFrame]", expected "Tuple[NDFrameT, NDFrameTb]")
9467+
# DataFrame]", expected "Tuple[Self, NDFrameT]")
94699468
return self._align_frame( # type: ignore[return-value]
94709469
df,
94719470
join=join,
@@ -9476,14 +9475,13 @@ def align(
94769475
method=method,
94779476
limit=limit,
94789477
fill_axis=fill_axis,
9479-
)
9478+
)[:2]
94809479

9480+
_right: DataFrame | Series
94819481
if axis is not None:
94829482
axis = self._get_axis_number(axis)
94839483
if isinstance(other, ABCDataFrame):
9484-
# error: Incompatible return value type (got "Tuple[NDFrameT, DataFrame]",
9485-
# expected "Tuple[NDFrameT, NDFrameTb]")
9486-
return self._align_frame( # type: ignore[return-value]
9484+
left, _right, join_index = self._align_frame(
94879485
other,
94889486
join=join,
94899487
axis=axis,
@@ -9494,10 +9492,9 @@ def align(
94949492
limit=limit,
94959493
fill_axis=fill_axis,
94969494
)
9495+
94979496
elif isinstance(other, ABCSeries):
9498-
# error: Incompatible return value type (got "Tuple[NDFrameT, Series]",
9499-
# expected "Tuple[NDFrameT, NDFrameTb]")
9500-
return self._align_series( # type: ignore[return-value]
9497+
left, _right, join_index = self._align_series(
95019498
other,
95029499
join=join,
95039500
axis=axis,
@@ -9511,9 +9508,27 @@ def align(
95119508
else: # pragma: no cover
95129509
raise TypeError(f"unsupported type: {type(other)}")
95139510

9511+
right = cast(NDFrameT, _right)
9512+
if self.ndim == 1 or axis == 0:
9513+
# If we are aligning timezone-aware DatetimeIndexes and the timezones
9514+
# do not match, convert both to UTC.
9515+
if is_datetime64tz_dtype(left.index.dtype):
9516+
if left.index.tz != right.index.tz:
9517+
if join_index is not None:
9518+
# GH#33671 copy to ensure we don't change the index on
9519+
# our original Series
9520+
left = left.copy(deep=False)
9521+
right = right.copy(deep=False)
9522+
left.index = join_index
9523+
right.index = join_index
9524+
9525+
left = left.__finalize__(self)
9526+
right = right.__finalize__(other)
9527+
return left, right
9528+
95149529
@final
95159530
def _align_frame(
9516-
self: NDFrameT,
9531+
self,
95179532
other: DataFrame,
95189533
join: AlignJoin = "outer",
95199534
axis: Axis | None = None,
@@ -9523,7 +9538,7 @@ def _align_frame(
95239538
method=None,
95249539
limit=None,
95259540
fill_axis: Axis = 0,
9526-
) -> tuple[NDFrameT, DataFrame]:
9541+
) -> tuple[Self, DataFrame, Index | None]:
95279542
# defaults
95289543
join_index, join_columns = None, None
95299544
ilidx, iridx = None, None
@@ -9562,22 +9577,14 @@ def _align_frame(
95629577
)
95639578

95649579
if method is not None:
9565-
_left = left.fillna(method=method, axis=fill_axis, limit=limit)
9566-
assert _left is not None # needed for mypy
9567-
left = _left
9580+
left = left.fillna(method=method, axis=fill_axis, limit=limit)
95689581
right = right.fillna(method=method, axis=fill_axis, limit=limit)
95699582

9570-
# if DatetimeIndex have different tz, convert to UTC
9571-
left, right = _align_as_utc(left, right, join_index)
9572-
9573-
return (
9574-
left.__finalize__(self),
9575-
right.__finalize__(other),
9576-
)
9583+
return left, right, join_index
95779584

95789585
@final
95799586
def _align_series(
9580-
self: NDFrameT,
9587+
self,
95819588
other: Series,
95829589
join: AlignJoin = "outer",
95839590
axis: Axis | None = None,
@@ -9587,7 +9594,7 @@ def _align_series(
95879594
method=None,
95889595
limit=None,
95899596
fill_axis: Axis = 0,
9590-
) -> tuple[NDFrameT, Series]:
9597+
) -> tuple[Self, Series, Index | None]:
95919598
is_series = isinstance(self, ABCSeries)
95929599
if copy and using_copy_on_write():
95939600
copy = False
@@ -9649,14 +9656,7 @@ def _align_series(
96499656
left = left.fillna(fill_value, method=method, limit=limit, axis=fill_axis)
96509657
right = right.fillna(fill_value, method=method, limit=limit)
96519658

9652-
# if DatetimeIndex have different tz, convert to UTC
9653-
if is_series or (not is_series and axis == 0):
9654-
left, right = _align_as_utc(left, right, join_index)
9655-
9656-
return (
9657-
left.__finalize__(self),
9658-
right.__finalize__(other),
9659-
)
9659+
return left, right, join_index
96609660

96619661
@final
96629662
def _where(
@@ -12819,23 +12819,3 @@ def _doc_params(cls):
1281912819
The required number of valid values to perform the operation. If fewer than
1282012820
``min_count`` non-NA values are present the result will be NA.
1282112821
"""
12822-
12823-
12824-
def _align_as_utc(
12825-
left: NDFrameT, right: NDFrameTb, join_index: Index | None
12826-
) -> tuple[NDFrameT, NDFrameTb]:
12827-
"""
12828-
If we are aligning timezone-aware DatetimeIndexes and the timezones
12829-
do not match, convert both to UTC.
12830-
"""
12831-
if is_datetime64tz_dtype(left.index.dtype):
12832-
if left.index.tz != right.index.tz:
12833-
if join_index is not None:
12834-
# GH#33671 ensure we don't change the index on
12835-
# our original Series (NB: by default deep=False)
12836-
left = left.copy()
12837-
right = right.copy()
12838-
left.index = join_index
12839-
right.index = join_index
12840-
12841-
return left, right

pandas/tests/frame/methods/test_align.py

+2
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,7 @@ def test_align_float(self, float_frame, using_copy_on_write):
101101
with pytest.raises(ValueError, match=msg):
102102
float_frame.align(af.iloc[0, :3], join="inner", axis=2)
103103

104+
def test_align_frame_with_series(self, float_frame):
104105
# align dataframe to series with broadcast or not
105106
idx = float_frame.index
106107
s = Series(range(len(idx)), index=idx)
@@ -118,6 +119,7 @@ def test_align_float(self, float_frame, using_copy_on_write):
118119
)
119120
tm.assert_frame_equal(right, expected)
120121

122+
def test_align_series_condition(self):
121123
# see gh-9558
122124
df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
123125
result = df[df["a"] == 2]

0 commit comments

Comments
 (0)