Skip to content

Commit e6b96db

Browse files
committed
Auto merge of rust-lang#16590 - davidsemakula:unnecessary-else-diagnostic-fix, r=Veykril
fix: Fix false positives for "unnecessary else" diagnostic Completes rust-lang/rust-analyzer#16567 by `@ShoyuVanilla` (see rust-lang/rust-analyzer#16567 (comment)) Fixes rust-lang#16556
2 parents ac1029f + f2218e7 commit e6b96db

File tree

2 files changed

+151
-8
lines changed

2 files changed

+151
-8
lines changed

crates/hir-ty/src/diagnostics/expr.rs

+34-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use hir_expand::name;
1212
use itertools::Itertools;
1313
use rustc_hash::FxHashSet;
1414
use rustc_pattern_analysis::usefulness::{compute_match_usefulness, ValidityConstraint};
15+
use syntax::{ast, AstNode};
1516
use tracing::debug;
1617
use triomphe::Arc;
1718
use typed_arena::Arena;
@@ -108,7 +109,7 @@ impl ExprValidator {
108109
self.check_for_trailing_return(*body_expr, &body);
109110
}
110111
Expr::If { .. } => {
111-
self.check_for_unnecessary_else(id, expr, &body);
112+
self.check_for_unnecessary_else(id, expr, db);
112113
}
113114
Expr::Block { .. } => {
114115
self.validate_block(db, expr);
@@ -336,19 +337,49 @@ impl ExprValidator {
336337
}
337338
}
338339

339-
fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, body: &Body) {
340+
fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, db: &dyn HirDatabase) {
340341
if let Expr::If { condition: _, then_branch, else_branch } = expr {
341342
if else_branch.is_none() {
342343
return;
343344
}
344-
if let Expr::Block { statements, tail, .. } = &body.exprs[*then_branch] {
345+
if let Expr::Block { statements, tail, .. } = &self.body.exprs[*then_branch] {
345346
let last_then_expr = tail.or_else(|| match statements.last()? {
346347
Statement::Expr { expr, .. } => Some(*expr),
347348
_ => None,
348349
});
349350
if let Some(last_then_expr) = last_then_expr {
350351
let last_then_expr_ty = &self.infer[last_then_expr];
351352
if last_then_expr_ty.is_never() {
353+
// Only look at sources if the then branch diverges and we have an else branch.
354+
let (_, source_map) = db.body_with_source_map(self.owner);
355+
let Ok(source_ptr) = source_map.expr_syntax(id) else {
356+
return;
357+
};
358+
let root = source_ptr.file_syntax(db.upcast());
359+
let ast::Expr::IfExpr(if_expr) = source_ptr.value.to_node(&root) else {
360+
return;
361+
};
362+
let mut top_if_expr = if_expr;
363+
loop {
364+
let parent = top_if_expr.syntax().parent();
365+
let has_parent_expr_stmt_or_stmt_list =
366+
parent.as_ref().map_or(false, |node| {
367+
ast::ExprStmt::can_cast(node.kind())
368+
| ast::StmtList::can_cast(node.kind())
369+
});
370+
if has_parent_expr_stmt_or_stmt_list {
371+
// Only emit diagnostic if parent or direct ancestor is either
372+
// an expr stmt or a stmt list.
373+
break;
374+
}
375+
let Some(parent_if_expr) = parent.and_then(ast::IfExpr::cast) else {
376+
// Bail if parent is neither an if expr, an expr stmt nor a stmt list.
377+
return;
378+
};
379+
// Check parent if expr.
380+
top_if_expr = parent_if_expr;
381+
}
382+
352383
self.diagnostics
353384
.push(BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr: id })
354385
}

crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs

+117-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use hir::{db::ExpandDatabase, diagnostics::RemoveUnnecessaryElse, HirFileIdExt};
22
use ide_db::{assists::Assist, source_change::SourceChange};
33
use itertools::Itertools;
44
use syntax::{
5-
ast::{self, edit::IndentLevel},
5+
ast::{
6+
self,
7+
edit::{AstNodeEdit, IndentLevel},
8+
},
69
AstNode, SyntaxToken, TextRange,
710
};
811
use text_edit::TextEdit;
@@ -41,10 +44,15 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveUnnecessaryElse) -> Option<Vec<
4144
indent = indent + 1;
4245
}
4346
let else_replacement = match if_expr.else_branch()? {
44-
ast::ElseBranch::Block(ref block) => {
45-
block.statements().map(|stmt| format!("\n{indent}{stmt}")).join("")
46-
}
47-
ast::ElseBranch::IfExpr(ref nested_if_expr) => {
47+
ast::ElseBranch::Block(block) => block
48+
.statements()
49+
.map(|stmt| format!("\n{indent}{stmt}"))
50+
.chain(block.tail_expr().map(|tail| format!("\n{indent}{tail}")))
51+
.join(""),
52+
ast::ElseBranch::IfExpr(mut nested_if_expr) => {
53+
if has_parent_if_expr {
54+
nested_if_expr = nested_if_expr.indent(IndentLevel(1))
55+
}
4856
format!("\n{indent}{nested_if_expr}")
4957
}
5058
};
@@ -171,6 +179,41 @@ fn test() {
171179
);
172180
}
173181

182+
#[test]
183+
fn remove_unnecessary_else_for_return3() {
184+
check_diagnostics_with_needless_return_disabled(
185+
r#"
186+
fn test(a: bool) -> i32 {
187+
if a {
188+
return 1;
189+
} else {
190+
//^^^^ 💡 weak: remove unnecessary else block
191+
0
192+
}
193+
}
194+
"#,
195+
);
196+
check_fix(
197+
r#"
198+
fn test(a: bool) -> i32 {
199+
if a {
200+
return 1;
201+
} else$0 {
202+
0
203+
}
204+
}
205+
"#,
206+
r#"
207+
fn test(a: bool) -> i32 {
208+
if a {
209+
return 1;
210+
}
211+
0
212+
}
213+
"#,
214+
);
215+
}
216+
174217
#[test]
175218
fn remove_unnecessary_else_for_return_in_child_if_expr() {
176219
check_diagnostics_with_needless_return_disabled(
@@ -214,6 +257,41 @@ fn test() {
214257
);
215258
}
216259

260+
#[test]
261+
fn remove_unnecessary_else_for_return_in_child_if_expr2() {
262+
check_fix(
263+
r#"
264+
fn test() {
265+
if foo {
266+
do_something();
267+
} else if qux {
268+
return bar;
269+
} else$0 if quux {
270+
do_something_else();
271+
} else {
272+
do_something_else2();
273+
}
274+
}
275+
"#,
276+
r#"
277+
fn test() {
278+
if foo {
279+
do_something();
280+
} else {
281+
if qux {
282+
return bar;
283+
}
284+
if quux {
285+
do_something_else();
286+
} else {
287+
do_something_else2();
288+
}
289+
}
290+
}
291+
"#,
292+
);
293+
}
294+
217295
#[test]
218296
fn remove_unnecessary_else_for_break() {
219297
check_diagnostics(
@@ -384,6 +462,40 @@ fn test() {
384462
return bar;
385463
}
386464
}
465+
"#,
466+
);
467+
}
468+
469+
#[test]
470+
fn no_diagnostic_if_not_expr_stmt() {
471+
check_diagnostics_with_needless_return_disabled(
472+
r#"
473+
fn test1() {
474+
let _x = if a {
475+
return;
476+
} else {
477+
1
478+
};
479+
}
480+
481+
fn test2() {
482+
let _x = if a {
483+
return;
484+
} else if b {
485+
return;
486+
} else if c {
487+
1
488+
} else {
489+
return;
490+
};
491+
}
492+
"#,
493+
);
494+
check_diagnostics(
495+
r#"
496+
fn test3() -> u8 {
497+
foo(if a { return 1 } else { 0 })
498+
}
387499
"#,
388500
);
389501
}

0 commit comments

Comments
 (0)