Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,7 @@ def func(arg1: int, arg2: int = 4):
_ = t"{set(map(lambda x: x % 2 == 0, nums))}"
_ = t"{dict(map(lambda v: (v, v**2), nums))}"


# See https://github.com/astral-sh/ruff/issues/20198
# No error: lambda contains `yield`, so map() should not be rewritten
map(lambda x: (yield x), [1, 2, 3])
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ pub(crate) fn unnecessary_map(checker: &Checker, call: &ast::ExprCall) {
}
};

// If the lambda body contains a `yield` or `yield from`, rewriting `map(lambda ...)` to a
// generator expression or any comprehension is invalid Python syntax
// (e.g., `yield` is not allowed inside generator or comprehension expressions). In such cases, skip.
if lambda_contains_yield(&lambda.body) {
return;
}

for iterable in iterables {
// For example, (x+1 for x in (c:=a)) is invalid syntax
// so we can't suggest it.
Expand Down Expand Up @@ -183,6 +190,13 @@ fn map_lambda_and_iterables<'a>(
Some((lambda, iterables))
}

/// Returns true if the expression tree contains a `yield` or `yield from` expression.
fn lambda_contains_yield(expr: &Expr) -> bool {
any_over_expr(expr, &|expr| {
matches!(expr, Expr::Yield(_) | Expr::YieldFrom(_))
})
}

/// A lambda as the first argument to `map()` has the "expected" arity when:
///
/// * It has exactly one parameter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ help: Replace `map()` with a set comprehension
75 + _ = t"{ {x % 2 == 0 for x in nums} }"
76 | _ = t"{dict(map(lambda v: (v, v**2), nums))}"
77 |
78 |
note: This is an unsafe fix and may change runtime behavior

C417 [*] Unnecessary `map()` usage (rewrite using a dict comprehension)
Expand All @@ -359,4 +360,6 @@ help: Replace `map()` with a dict comprehension
- _ = t"{dict(map(lambda v: (v, v**2), nums))}"
76 + _ = t"{ {v: v**2 for v in nums} }"
77 |
78 |
79 | # See https://github.com/astral-sh/ruff/issues/20198
note: This is an unsafe fix and may change runtime behavior
Loading