Skip to content

Commit 4b78141

Browse files
Generate one fix per statement for flake8-type-checking rules (#4915)
1 parent 5235977 commit 4b78141

11 files changed

Lines changed: 731 additions & 304 deletions

crates/ruff/src/checkers/ast/mod.rs

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5069,22 +5069,19 @@ impl<'a> Checker<'a> {
50695069
.copied()
50705070
.collect()
50715071
};
5072-
for binding_id in scope.binding_ids() {
5073-
let binding = &self.semantic_model.bindings[binding_id];
50745072

5075-
flake8_type_checking::rules::runtime_import_in_type_checking_block(
5076-
self,
5077-
binding,
5078-
&mut diagnostics,
5079-
);
5073+
flake8_type_checking::rules::runtime_import_in_type_checking_block(
5074+
self,
5075+
scope,
5076+
&mut diagnostics,
5077+
);
50805078

5081-
flake8_type_checking::rules::typing_only_runtime_import(
5082-
self,
5083-
binding,
5084-
&runtime_imports,
5085-
&mut diagnostics,
5086-
);
5087-
}
5079+
flake8_type_checking::rules::typing_only_runtime_import(
5080+
self,
5081+
scope,
5082+
&runtime_imports,
5083+
&mut diagnostics,
5084+
);
50885085
}
50895086

50905087
if self.enabled(Rule::UnusedImport) {

crates/ruff/src/importer/mod.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@ impl<'a> Importer<'a> {
8383
/// import statement.
8484
pub(crate) fn runtime_import_edit(
8585
&self,
86-
import: &StmtImport,
86+
import: &StmtImports,
8787
at: TextSize,
8888
) -> Result<RuntimeImportEdit> {
8989
// Generate the modified import statement.
9090
let content = autofix::codemods::retain_imports(
91-
&[import.qualified_name],
91+
&import.qualified_names,
9292
import.stmt,
9393
self.locator,
9494
self.stylist,
@@ -114,13 +114,13 @@ impl<'a> Importer<'a> {
114114
/// `TYPE_CHECKING` block.
115115
pub(crate) fn typing_import_edit(
116116
&self,
117-
import: &StmtImport,
117+
import: &StmtImports,
118118
at: TextSize,
119119
semantic_model: &SemanticModel,
120120
) -> Result<TypingImportEdit> {
121121
// Generate the modified import statement.
122122
let content = autofix::codemods::retain_imports(
123-
&[import.qualified_name],
123+
&import.qualified_names,
124124
import.stmt,
125125
self.locator,
126126
self.stylist,
@@ -442,12 +442,12 @@ impl<'a> ImportRequest<'a> {
442442
}
443443
}
444444

445-
/// An existing module or member import, located within an import statement.
446-
pub(crate) struct StmtImport<'a> {
445+
/// An existing list of module or member imports, located within an import statement.
446+
pub(crate) struct StmtImports<'a> {
447447
/// The import statement.
448448
pub(crate) stmt: &'a Stmt,
449-
/// The "full name" of the imported module or member.
450-
pub(crate) qualified_name: &'a str,
449+
/// The "qualified names" of the imported modules or members.
450+
pub(crate) qualified_names: Vec<&'a str>,
451451
}
452452

453453
/// The result of an [`Importer::get_or_import_symbol`] call.

crates/ruff/src/rules/flake8_type_checking/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,48 @@ mod tests {
282282
"#,
283283
"import_from_type_checking_block"
284284
)]
285+
#[test_case(
286+
r#"
287+
from __future__ import annotations
288+
289+
from typing import TYPE_CHECKING
290+
291+
from pandas import (
292+
DataFrame, # DataFrame
293+
Series, # Series
294+
)
295+
296+
def f(x: DataFrame, y: Series):
297+
pass
298+
"#,
299+
"multiple_members"
300+
)]
301+
#[test_case(
302+
r#"
303+
from __future__ import annotations
304+
305+
from typing import TYPE_CHECKING
306+
307+
import os, sys
308+
309+
def f(x: os, y: sys):
310+
pass
311+
"#,
312+
"multiple_modules_same_type"
313+
)]
314+
#[test_case(
315+
r#"
316+
from __future__ import annotations
317+
318+
from typing import TYPE_CHECKING
319+
320+
import os, pandas
321+
322+
def f(x: os, y: pandas):
323+
pass
324+
"#,
325+
"multiple_modules_different_types"
326+
)]
285327
fn contents(contents: &str, snapshot: &str) {
286328
let diagnostics = test_snippet(
287329
contents,
Lines changed: 163 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
use anyhow::Result;
2+
use ruff_text_size::TextRange;
3+
use rustc_hash::FxHashMap;
4+
15
use ruff_diagnostics::{AutofixKind, Diagnostic, Fix, Violation};
26
use ruff_macros::{derive_message_formats, violation};
3-
use ruff_python_semantic::binding::Binding;
7+
use ruff_python_semantic::node::NodeId;
8+
use ruff_python_semantic::reference::ReferenceId;
9+
use ruff_python_semantic::scope::Scope;
410

511
use crate::autofix;
612
use crate::checkers::ast::Checker;
7-
use crate::importer::StmtImport;
8-
use crate::registry::AsRule;
13+
use crate::codes::Rule;
14+
use crate::importer::StmtImports;
915

1016
/// ## What it does
1117
/// Checks for runtime imports defined in a type-checking block.
@@ -61,72 +67,172 @@ impl Violation for RuntimeImportInTypeCheckingBlock {
6167
/// TCH004
6268
pub(crate) fn runtime_import_in_type_checking_block(
6369
checker: &Checker,
64-
binding: &Binding,
70+
scope: &Scope,
6571
diagnostics: &mut Vec<Diagnostic>,
6672
) {
67-
let Some(qualified_name) = binding.qualified_name() else {
68-
return;
69-
};
73+
// Collect all runtime imports by statement.
74+
let mut errors_by_statement: FxHashMap<NodeId, Vec<Import>> = FxHashMap::default();
75+
let mut ignores_by_statement: FxHashMap<NodeId, Vec<Import>> = FxHashMap::default();
7076

71-
let Some(reference_id) = binding.references.first() else {
72-
return;
73-
};
77+
for binding_id in scope.binding_ids() {
78+
let binding = &checker.semantic_model().bindings[binding_id];
7479

75-
if binding.context.is_typing()
76-
&& binding.references().any(|reference_id| {
77-
checker
78-
.semantic_model()
79-
.references
80-
.resolve(reference_id)
81-
.context()
82-
.is_runtime()
83-
})
80+
let Some(qualified_name) = binding.qualified_name() else {
81+
continue;
82+
};
83+
84+
let Some(reference_id) = binding.references.first().copied() else {
85+
continue;
86+
};
87+
88+
if binding.context.is_typing()
89+
&& binding.references().any(|reference_id| {
90+
checker
91+
.semantic_model()
92+
.references
93+
.resolve(reference_id)
94+
.context()
95+
.is_runtime()
96+
})
97+
{
98+
let Some(stmt_id) = binding.source else {
99+
continue;
100+
};
101+
102+
let import = Import {
103+
qualified_name,
104+
reference_id,
105+
trimmed_range: binding.trimmed_range(checker.semantic_model(), checker.locator),
106+
parent_range: binding.parent_range(checker.semantic_model()),
107+
};
108+
109+
if checker.rule_is_ignored(
110+
Rule::RuntimeImportInTypeCheckingBlock,
111+
import.trimmed_range.start(),
112+
) || import.parent_range.map_or(false, |parent_range| {
113+
checker
114+
.rule_is_ignored(Rule::RuntimeImportInTypeCheckingBlock, parent_range.start())
115+
}) {
116+
ignores_by_statement
117+
.entry(stmt_id)
118+
.or_default()
119+
.push(import);
120+
} else {
121+
errors_by_statement.entry(stmt_id).or_default().push(import);
122+
}
123+
}
124+
}
125+
126+
// Generate a diagnostic for every import, but share a fix across all imports within the same
127+
// statement (excluding those that are ignored).
128+
for (stmt_id, imports) in errors_by_statement {
129+
let fix = if checker.patch(Rule::RuntimeImportInTypeCheckingBlock) {
130+
fix_imports(checker, stmt_id, &imports).ok()
131+
} else {
132+
None
133+
};
134+
135+
for Import {
136+
qualified_name,
137+
trimmed_range,
138+
parent_range,
139+
..
140+
} in imports
141+
{
142+
let mut diagnostic = Diagnostic::new(
143+
RuntimeImportInTypeCheckingBlock {
144+
qualified_name: qualified_name.to_string(),
145+
},
146+
trimmed_range,
147+
);
148+
if let Some(range) = parent_range {
149+
diagnostic.set_parent(range.start());
150+
}
151+
if let Some(fix) = fix.as_ref() {
152+
diagnostic.set_fix(fix.clone());
153+
}
154+
diagnostics.push(diagnostic);
155+
}
156+
}
157+
158+
// Separately, generate a diagnostic for every _ignored_ import, to ensure that the
159+
// suppression comments aren't marked as unused.
160+
for Import {
161+
qualified_name,
162+
trimmed_range,
163+
parent_range,
164+
..
165+
} in ignores_by_statement.into_values().flatten()
84166
{
85167
let mut diagnostic = Diagnostic::new(
86168
RuntimeImportInTypeCheckingBlock {
87169
qualified_name: qualified_name.to_string(),
88170
},
89-
binding.trimmed_range(checker.semantic_model(), checker.locator),
171+
trimmed_range,
90172
);
91-
if let Some(range) = binding.parent_range(checker.semantic_model()) {
173+
if let Some(range) = parent_range {
92174
diagnostic.set_parent(range.start());
93175
}
176+
diagnostics.push(diagnostic);
177+
}
178+
}
94179

95-
if checker.patch(diagnostic.kind.rule()) {
96-
diagnostic.try_set_fix(|| {
97-
// Step 1) Remove the import.
98-
// SAFETY: All non-builtin bindings have a source.
99-
let source = binding.source.unwrap();
100-
let stmt = checker.semantic_model().stmts[source];
101-
let parent = checker.semantic_model().stmts.parent(stmt);
102-
let remove_import_edit = autofix::edits::remove_unused_imports(
103-
std::iter::once(qualified_name),
104-
stmt,
105-
parent,
106-
checker.locator,
107-
checker.indexer,
108-
checker.stylist,
109-
)?;
110-
111-
// Step 2) Add the import to the top-level.
112-
let reference = checker.semantic_model().references.resolve(*reference_id);
113-
let add_import_edit = checker.importer.runtime_import_edit(
114-
&StmtImport {
115-
stmt,
116-
qualified_name,
117-
},
118-
reference.range().start(),
119-
)?;
120-
121-
Ok(
122-
Fix::suggested_edits(remove_import_edit, add_import_edit.into_edits())
123-
.isolate(checker.isolation(parent)),
124-
)
125-
});
126-
}
180+
/// A runtime-required import with its surrounding context.
181+
struct Import<'a> {
182+
/// The qualified name of the import (e.g., `typing.List` for `from typing import List`).
183+
qualified_name: &'a str,
184+
/// The first reference to the imported symbol.
185+
reference_id: ReferenceId,
186+
/// The trimmed range of the import (e.g., `List` in `from typing import List`).
187+
trimmed_range: TextRange,
188+
/// The range of the import's parent statement.
189+
parent_range: Option<TextRange>,
190+
}
127191

128-
if checker.enabled(diagnostic.kind.rule()) {
129-
diagnostics.push(diagnostic);
130-
}
131-
}
192+
/// Generate a [`Fix`] to remove runtime imports from a type-checking block.
193+
fn fix_imports(checker: &Checker, stmt_id: NodeId, imports: &[Import]) -> Result<Fix> {
194+
let stmt = checker.semantic_model().stmts[stmt_id];
195+
let parent = checker.semantic_model().stmts.parent(stmt);
196+
let qualified_names: Vec<&str> = imports
197+
.iter()
198+
.map(|Import { qualified_name, .. }| *qualified_name)
199+
.collect();
200+
201+
// Find the first reference across all imports.
202+
let at = imports
203+
.iter()
204+
.map(|Import { reference_id, .. }| {
205+
checker
206+
.semantic_model()
207+
.references
208+
.resolve(*reference_id)
209+
.range()
210+
.start()
211+
})
212+
.min()
213+
.expect("Expected at least one import");
214+
215+
// Step 1) Remove the import.
216+
let remove_import_edit = autofix::edits::remove_unused_imports(
217+
qualified_names.iter().copied(),
218+
stmt,
219+
parent,
220+
checker.locator,
221+
checker.indexer,
222+
checker.stylist,
223+
)?;
224+
225+
// Step 2) Add the import to the top-level.
226+
let add_import_edit = checker.importer.runtime_import_edit(
227+
&StmtImports {
228+
stmt,
229+
qualified_names,
230+
},
231+
at,
232+
)?;
233+
234+
Ok(
235+
Fix::suggested_edits(remove_import_edit, add_import_edit.into_edits())
236+
.isolate(checker.isolation(parent)),
237+
)
132238
}

0 commit comments

Comments
 (0)