Skip to content

Commit 200349c

Browse files
authored
[flake8-comprehensions] Skip C417 when lambda contains yield/yield from (#20201)
## Summary Fixes #20198
1 parent 0d4f7dd commit 200349c

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

crates/ruff_linter/resources/test/fixtures/flake8_comprehensions/C417.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,7 @@ def func(arg1: int, arg2: int = 4):
7575
_ = t"{set(map(lambda x: x % 2 == 0, nums))}"
7676
_ = t"{dict(map(lambda v: (v, v**2), nums))}"
7777

78+
79+
# See https://github.com/astral-sh/ruff/issues/20198
80+
# No error: lambda contains `yield`, so map() should not be rewritten
81+
map(lambda x: (yield x), [1, 2, 3])

crates/ruff_linter/src/rules/flake8_comprehensions/rules/unnecessary_map.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ pub(crate) fn unnecessary_map(checker: &Checker, call: &ast::ExprCall) {
122122
}
123123
};
124124

125+
// If the lambda body contains a `yield` or `yield from`, rewriting `map(lambda ...)` to a
126+
// generator expression or any comprehension is invalid Python syntax
127+
// (e.g., `yield` is not allowed inside generator or comprehension expressions). In such cases, skip.
128+
if lambda_contains_yield(&lambda.body) {
129+
return;
130+
}
131+
125132
for iterable in iterables {
126133
// For example, (x+1 for x in (c:=a)) is invalid syntax
127134
// so we can't suggest it.
@@ -183,6 +190,13 @@ fn map_lambda_and_iterables<'a>(
183190
Some((lambda, iterables))
184191
}
185192

193+
/// Returns true if the expression tree contains a `yield` or `yield from` expression.
194+
fn lambda_contains_yield(expr: &Expr) -> bool {
195+
any_over_expr(expr, &|expr| {
196+
matches!(expr, Expr::Yield(_) | Expr::YieldFrom(_))
197+
})
198+
}
199+
186200
/// A lambda as the first argument to `map()` has the "expected" arity when:
187201
///
188202
/// * It has exactly one parameter

crates/ruff_linter/src/rules/flake8_comprehensions/snapshots/ruff_linter__rules__flake8_comprehensions__tests__C417_C417.py.snap

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ help: Replace `map()` with a set comprehension
342342
75 + _ = t"{ {x % 2 == 0 for x in nums} }"
343343
76 | _ = t"{dict(map(lambda v: (v, v**2), nums))}"
344344
77 |
345+
78 |
345346
note: This is an unsafe fix and may change runtime behavior
346347

347348
C417 [*] Unnecessary `map()` usage (rewrite using a dict comprehension)
@@ -359,4 +360,6 @@ help: Replace `map()` with a dict comprehension
359360
- _ = t"{dict(map(lambda v: (v, v**2), nums))}"
360361
76 + _ = t"{ {v: v**2 for v in nums} }"
361362
77 |
363+
78 |
364+
79 | # See https://github.com/astral-sh/ruff/issues/20198
362365
note: This is an unsafe fix and may change runtime behavior

0 commit comments

Comments
 (0)