Skip to content

Commit f36ea01

Browse files
committed
Fix async iterator body stripping (#15491)
Fixes #15489
1 parent ba7887b commit f36ea01

File tree

2 files changed

+43
-28
lines changed

2 files changed

+43
-28
lines changed

mypy/fastparse.py

+33-25
Original file line numberDiff line numberDiff line change
@@ -521,40 +521,48 @@ def translate_stmt_list(
521521
return [block]
522522

523523
stack = self.class_and_function_stack
524-
if self.strip_function_bodies and len(stack) == 1 and stack[0] == "F":
524+
# Fast case for stripping function bodies
525+
if (
526+
can_strip
527+
and self.strip_function_bodies
528+
and len(stack) == 1
529+
and stack[0] == "F"
530+
and not is_coroutine
531+
):
525532
return []
526533

527534
res: list[Statement] = []
528535
for stmt in stmts:
529536
node = self.visit(stmt)
530537
res.append(node)
531538

532-
if (
533-
self.strip_function_bodies
534-
and can_strip
535-
and stack[-2:] == ["C", "F"]
536-
and not is_possible_trivial_body(res)
537-
):
538-
# We only strip method bodies if they don't assign to an attribute, as
539-
# this may define an attribute which has an externally visible effect.
540-
visitor = FindAttributeAssign()
541-
for s in res:
542-
s.accept(visitor)
543-
if visitor.found:
544-
break
545-
else:
546-
if is_coroutine:
547-
# Yields inside an async function affect the return type and should not
548-
# be stripped.
549-
yield_visitor = FindYield()
539+
# Slow case for stripping function bodies
540+
if can_strip and self.strip_function_bodies:
541+
if stack[-2:] == ["C", "F"]:
542+
if is_possible_trivial_body(res):
543+
can_strip = False
544+
else:
545+
# We only strip method bodies if they don't assign to an attribute, as
546+
# this may define an attribute which has an externally visible effect.
547+
visitor = FindAttributeAssign()
550548
for s in res:
551-
s.accept(yield_visitor)
552-
if yield_visitor.found:
549+
s.accept(visitor)
550+
if visitor.found:
551+
can_strip = False
553552
break
554-
else:
555-
return []
556-
else:
557-
return []
553+
554+
if can_strip and stack[-1] == "F" and is_coroutine:
555+
# Yields inside an async function affect the return type and should not
556+
# be stripped.
557+
yield_visitor = FindYield()
558+
for s in res:
559+
s.accept(yield_visitor)
560+
if yield_visitor.found:
561+
can_strip = False
562+
break
563+
564+
if can_strip:
565+
return []
558566
return res
559567

560568
def translate_type_comment(

test-data/unit/check-async-await.test

+10-3
Original file line numberDiff line numberDiff line change
@@ -945,17 +945,21 @@ async def bar(x: Union[A, B]) -> None:
945945
[typing fixtures/typing-async.pyi]
946946

947947
[case testAsyncIteratorWithIgnoredErrors]
948-
from m import L
948+
import m
949949

950-
async def func(l: L) -> None:
950+
async def func(l: m.L) -> None:
951951
reveal_type(l.get_iterator) # N: Revealed type is "def () -> typing.AsyncIterator[builtins.str]"
952952
reveal_type(l.get_iterator2) # N: Revealed type is "def () -> typing.AsyncIterator[builtins.str]"
953953
async for i in l.get_iterator():
954954
reveal_type(i) # N: Revealed type is "builtins.str"
955955

956+
reveal_type(m.get_generator) # N: Revealed type is "def () -> typing.AsyncGenerator[builtins.int, None]"
957+
async for i2 in m.get_generator():
958+
reveal_type(i2) # N: Revealed type is "builtins.int"
959+
956960
[file m.py]
957961
# mypy: ignore-errors=True
958-
from typing import AsyncIterator
962+
from typing import AsyncIterator, AsyncGenerator
959963

960964
class L:
961965
async def some_func(self, i: int) -> str:
@@ -968,6 +972,9 @@ class L:
968972
if self:
969973
a = (yield 'x')
970974

975+
async def get_generator() -> AsyncGenerator[int, None]:
976+
yield 1
977+
971978
[builtins fixtures/async_await.pyi]
972979
[typing fixtures/typing-async.pyi]
973980

0 commit comments

Comments
 (0)