Skip to content

Fix Assist "replace named generic type with impl trait" #14945

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 188 additions & 50 deletions crates/ide-assists/src/handlers/replace_named_generic_with_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ use hir::Semantics;
use ide_db::{
base_db::{FileId, FileRange},
defs::Definition,
search::SearchScope,
search::{SearchScope, UsageSearchResult},
RootDatabase,
};
use syntax::{
ast::{self, make::impl_trait_type, HasGenericParams, HasName, HasTypeBounds},
ted, AstNode,
ast::{
self, make::impl_trait_type, HasGenericParams, HasName, HasTypeBounds, Name, NameLike,
PathType,
},
match_ast, ted, AstNode,
};
use text_edit::TextRange;

use crate::{AssistContext, AssistId, AssistKind, Assists};

Expand Down Expand Up @@ -36,87 +40,131 @@ pub(crate) fn replace_named_generic_with_impl(
let type_bound_list = type_param.type_bound_list()?;

let fn_ = type_param.syntax().ancestors().find_map(ast::Fn::cast)?;
let params = fn_
.param_list()?
.params()
.filter_map(|param| {
// function parameter type needs to match generic type name
if let ast::Type::PathType(path_type) = param.ty()? {
let left = path_type.path()?.segment()?.name_ref()?.ident_token()?.to_string();
let right = type_param_name.to_string();
if left == right {
Some(param)
} else {
None
}
} else {
None
}
})
.collect::<Vec<_>>();

if params.is_empty() {
return None;
}
let param_list_text_range = fn_.param_list()?.syntax().text_range();

let type_param_hir_def = ctx.sema.to_def(&type_param)?;
let type_param_def = Definition::GenericParam(hir::GenericParam::TypeParam(type_param_hir_def));

if is_referenced_outside(&ctx.sema, type_param_def, &fn_, ctx.file_id()) {
// get all usage references for the type param
let usage_refs = find_usages(&ctx.sema, &fn_, type_param_def, ctx.file_id());
if usage_refs.is_empty() {
return None;
}

// All usage references need to be valid (inside the function param list)
if !check_valid_usages(&usage_refs, param_list_text_range) {
return None;
}

let mut path_types_to_replace = Vec::new();
for (_a, refs) in usage_refs.iter() {
for usage_ref in refs {
let param_node = find_path_type(&ctx.sema, &type_param_name, &usage_ref.name)?;
path_types_to_replace.push(param_node);
}
}

let target = type_param.syntax().text_range();

acc.add(
AssistId("replace_named_generic_with_impl", AssistKind::RefactorRewrite),
"Replace named generic with impl",
"Replace named generic with impl trait",
target,
|edit| {
let type_param = edit.make_mut(type_param);
let fn_ = edit.make_mut(fn_);

// get all params
let param_types = params
.iter()
.filter_map(|param| match param.ty() {
Some(ast::Type::PathType(param_type)) => Some(edit.make_mut(param_type)),
_ => None,
})
let path_types_to_replace = path_types_to_replace
.into_iter()
.map(|param| edit.make_mut(param))
.collect::<Vec<_>>();

// remove trait from generic param list
if let Some(generic_params) = fn_.generic_param_list() {
generic_params.remove_generic_param(ast::GenericParam::TypeParam(type_param));
if generic_params.generic_params().count() == 0 {
ted::remove(generic_params.syntax());
}
}

// get type bounds in signature type: `P` -> `impl AsRef<Path>`
let new_bounds = impl_trait_type(type_bound_list);
for param_type in param_types.iter().rev() {
ted::replace(param_type.syntax(), new_bounds.clone_for_update().syntax());
for path_type in path_types_to_replace.iter().rev() {
ted::replace(path_type.syntax(), new_bounds.clone_for_update().syntax());
}
},
)
}

fn is_referenced_outside(
fn find_path_type(
sema: &Semantics<'_, RootDatabase>,
type_param_name: &Name,
param: &NameLike,
) -> Option<PathType> {
let path_type =
sema.ancestors_with_macros(param.syntax().clone()).find_map(ast::PathType::cast)?;

// Ignore any path types that look like `P::Assoc`
if path_type.path()?.as_single_name_ref()?.text() != type_param_name.text() {
return None;
}

let ancestors = sema.ancestors_with_macros(path_type.syntax().clone());

let mut in_generic_arg_list = false;
let mut is_associated_type = false;

// walking the ancestors checks them in a heuristic way until the `Fn` node is reached.
for ancestor in ancestors {
match_ast! {
match ancestor {
ast::PathSegment(ps) => {
match ps.kind()? {
ast::PathSegmentKind::Name(_name_ref) => (),
ast::PathSegmentKind::Type { .. } => return None,
_ => return None,
}
},
ast::GenericArgList(_) => {
in_generic_arg_list = true;
},
ast::AssocTypeArg(_) => {
is_associated_type = true;
},
ast::ImplTraitType(_) => {
if in_generic_arg_list && !is_associated_type {
return None;
}
},
ast::DynTraitType(_) => {
if !is_associated_type {
return None;
}
},
ast::Fn(_) => return Some(path_type),
_ => (),
}
}
}

None
}

/// Returns all usage references for the given type parameter definition.
fn find_usages(
sema: &Semantics<'_, RootDatabase>,
type_param_def: Definition,
fn_: &ast::Fn,
type_param_def: Definition,
file_id: FileId,
) -> bool {
// limit search scope to function body & return type
let search_ranges = vec![
fn_.body().map(|body| body.syntax().text_range()),
fn_.ret_type().map(|ret_type| ret_type.syntax().text_range()),
];

search_ranges.into_iter().flatten().any(|search_range| {
let file_range = FileRange { file_id, range: search_range };
!type_param_def.usages(sema).in_scope(SearchScope::file_range(file_range)).all().is_empty()
})
) -> UsageSearchResult {
let file_range = FileRange { file_id, range: fn_.syntax().text_range() };
type_param_def.usages(sema).in_scope(SearchScope::file_range(file_range)).all()
}

fn check_valid_usages(usages: &UsageSearchResult, param_list_range: TextRange) -> bool {
usages
.iter()
.flat_map(|(_, usage_refs)| usage_refs)
.all(|usage_ref| param_list_range.contains_range(usage_ref.range))
}

#[cfg(test)]
Expand Down Expand Up @@ -152,6 +200,96 @@ mod tests {
);
}

#[test]
fn replace_generic_trait_applies_to_generic_arguments_in_params() {
check_assist(
replace_named_generic_with_impl,
r#"
fn foo<P$0: Trait>(
_: P,
_: Option<P>,
_: Option<Option<P>>,
_: impl Iterator<Item = P>,
_: &dyn Iterator<Item = P>,
) {}
"#,
r#"
fn foo(
_: impl Trait,
_: Option<impl Trait>,
_: Option<Option<impl Trait>>,
_: impl Iterator<Item = impl Trait>,
_: &dyn Iterator<Item = impl Trait>,
) {}
"#,
);
}

#[test]
fn replace_generic_not_applicable_when_one_param_type_is_invalid() {
check_assist_not_applicable(
replace_named_generic_with_impl,
r#"
fn foo<P$0: Trait>(
_: i32,
_: Option<P>,
_: Option<Option<P>>,
_: impl Iterator<Item = P>,
_: &dyn Iterator<Item = P>,
_: <P as Trait>::Assoc,
) {}
"#,
);
}

#[test]
fn replace_generic_not_applicable_when_referenced_in_where_clause() {
check_assist_not_applicable(
replace_named_generic_with_impl,
r#"fn foo<P$0: Trait, I>() where I: FromRef<P> {}"#,
);
}

#[test]
fn replace_generic_not_applicable_when_used_with_type_alias() {
check_assist_not_applicable(
replace_named_generic_with_impl,
r#"fn foo<P$0: Trait>(p: <P as Trait>::Assoc) {}"#,
);
}

#[test]
fn replace_generic_not_applicable_when_used_as_argument_in_outer_trait_alias() {
check_assist_not_applicable(
replace_named_generic_with_impl,
r#"fn foo<P$0: Trait>(_: <() as OtherTrait<P>>::Assoc) {}"#,
);
}

#[test]
fn replace_generic_not_applicable_with_inner_associated_type() {
check_assist_not_applicable(
replace_named_generic_with_impl,
r#"fn foo<P$0: Trait>(_: P::Assoc) {}"#,
);
}

#[test]
fn replace_generic_not_applicable_when_passed_into_outer_impl_trait() {
check_assist_not_applicable(
replace_named_generic_with_impl,
r#"fn foo<P$0: Trait>(_: impl OtherTrait<P>) {}"#,
);
}

#[test]
fn replace_generic_not_applicable_when_used_in_passed_function_parameter() {
check_assist_not_applicable(
replace_named_generic_with_impl,
r#"fn foo<P$0: Trait>(_: &dyn Fn(P)) {}"#,
);
}

#[test]
fn replace_generic_with_multiple_generic_params() {
check_assist(
Expand Down