Skip to content

Commit 15e2e51

Browse files
authored
Rollup merge of #100473 - compiler-errors:normalize-the-fn-def-sig-plz, r=lcnr
Attempt to normalize `FnDef` signature in `InferCtxt::cmp` Stashes a normalization callback in `InferCtxt` so that the signature we get from `tcx.fn_sig(..).subst(..)` in `InferCtxt::cmp` can be properly normalized, since we cannot expect for it to have normalized types since it comes straight from astconv. This is kind of a hack, but I will say that `@jyn514` found the fact that we present unnormalized types to be very confusing in real life code, and I agree with that feeling. Though altogether I am still a bit unsure about whether this PR is worth the effort, so I'm open to alternatives and/or just closing it outright. On the other hand, this isn't a ridiculously heavy implementation anyways -- it's less than a hundred lines of changes, and half of that is just miscellaneous cleanup. This is stacked onto #100471 which is basically unrelated, and it can be rebased off of that when that lands or if needed. --- The code: ```rust trait Foo { type Bar; } impl<T> Foo for T { type Bar = i32; } fn foo<T>(_: <T as Foo>::Bar) {} fn needs_i32_ref_fn(f: fn(&'static i32)) {} fn main() { needs_i32_ref_fn(foo::<()>); } ``` Before: ``` = note: expected fn pointer `fn(&'static i32)` found fn item `fn(<() as Foo>::Bar) {foo::<()>}` ``` After: ``` = note: expected fn pointer `fn(&'static i32)` found fn item `fn(i32) {foo::<()>}` ```
2 parents 9cfd161 + e5602cb commit 15e2e51

File tree

8 files changed

+125
-14
lines changed

8 files changed

+125
-14
lines changed

compiler/rustc_infer/src/infer/at.rs

+4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
7878
err_count_on_creation: self.err_count_on_creation,
7979
in_snapshot: self.in_snapshot.clone(),
8080
universe: self.universe.clone(),
81+
normalize_fn_sig_for_diagnostic: self
82+
.normalize_fn_sig_for_diagnostic
83+
.as_ref()
84+
.map(|f| f.clone()),
8185
}
8286
}
8387
}

compiler/rustc_infer/src/infer/error_reporting/mod.rs

+11
Original file line numberDiff line numberDiff line change
@@ -961,12 +961,23 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
961961
}
962962
}
963963

964+
fn normalize_fn_sig_for_diagnostic(&self, sig: ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx> {
965+
if let Some(normalize) = &self.normalize_fn_sig_for_diagnostic {
966+
normalize(self, sig)
967+
} else {
968+
sig
969+
}
970+
}
971+
964972
/// Given two `fn` signatures highlight only sub-parts that are different.
965973
fn cmp_fn_sig(
966974
&self,
967975
sig1: &ty::PolyFnSig<'tcx>,
968976
sig2: &ty::PolyFnSig<'tcx>,
969977
) -> (DiagnosticStyledString, DiagnosticStyledString) {
978+
let sig1 = &self.normalize_fn_sig_for_diagnostic(*sig1);
979+
let sig2 = &self.normalize_fn_sig_for_diagnostic(*sig2);
980+
970981
let get_lifetimes = |sig| {
971982
use rustc_hir::def::Namespace;
972983
let (_, sig, reg) = ty::print::FmtPrinter::new(self.tcx, Namespace::TypeNS)

compiler/rustc_infer/src/infer/mod.rs

+18
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,9 @@ pub struct InferCtxt<'a, 'tcx> {
337337
/// when we enter into a higher-ranked (`for<..>`) type or trait
338338
/// bound.
339339
universe: Cell<ty::UniverseIndex>,
340+
341+
normalize_fn_sig_for_diagnostic:
342+
Option<Lrc<dyn Fn(&InferCtxt<'_, 'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>>,
340343
}
341344

342345
/// See the `error_reporting` module for more details.
@@ -540,6 +543,8 @@ pub struct InferCtxtBuilder<'tcx> {
540543
defining_use_anchor: DefiningAnchor,
541544
considering_regions: bool,
542545
fresh_typeck_results: Option<RefCell<ty::TypeckResults<'tcx>>>,
546+
normalize_fn_sig_for_diagnostic:
547+
Option<Lrc<dyn Fn(&InferCtxt<'_, 'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>>,
543548
}
544549

545550
pub trait TyCtxtInferExt<'tcx> {
@@ -553,6 +558,7 @@ impl<'tcx> TyCtxtInferExt<'tcx> for TyCtxt<'tcx> {
553558
defining_use_anchor: DefiningAnchor::Error,
554559
considering_regions: true,
555560
fresh_typeck_results: None,
561+
normalize_fn_sig_for_diagnostic: None,
556562
}
557563
}
558564
}
@@ -582,6 +588,14 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
582588
self
583589
}
584590

591+
pub fn with_normalize_fn_sig_for_diagnostic(
592+
mut self,
593+
fun: Lrc<dyn Fn(&InferCtxt<'_, 'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>,
594+
) -> Self {
595+
self.normalize_fn_sig_for_diagnostic = Some(fun);
596+
self
597+
}
598+
585599
/// Given a canonical value `C` as a starting point, create an
586600
/// inference context that contains each of the bound values
587601
/// within instantiated as a fresh variable. The `f` closure is
@@ -611,6 +625,7 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
611625
defining_use_anchor,
612626
considering_regions,
613627
ref fresh_typeck_results,
628+
ref normalize_fn_sig_for_diagnostic,
614629
} = *self;
615630
let in_progress_typeck_results = fresh_typeck_results.as_ref();
616631
f(InferCtxt {
@@ -629,6 +644,9 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
629644
in_snapshot: Cell::new(false),
630645
skip_leak_check: Cell::new(false),
631646
universe: Cell::new(ty::UniverseIndex::ROOT),
647+
normalize_fn_sig_for_diagnostic: normalize_fn_sig_for_diagnostic
648+
.as_ref()
649+
.map(|f| f.clone()),
632650
})
633651
}
634652
}

compiler/rustc_trait_selection/src/traits/engine.rs

+13
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use rustc_span::Span;
1717

1818
pub trait TraitEngineExt<'tcx> {
1919
fn new(tcx: TyCtxt<'tcx>) -> Box<Self>;
20+
fn new_in_snapshot(tcx: TyCtxt<'tcx>) -> Box<Self>;
2021
}
2122

2223
impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> {
@@ -27,6 +28,14 @@ impl<'tcx> TraitEngineExt<'tcx> for dyn TraitEngine<'tcx> {
2728
Box::new(FulfillmentContext::new())
2829
}
2930
}
31+
32+
fn new_in_snapshot(tcx: TyCtxt<'tcx>) -> Box<Self> {
33+
if tcx.sess.opts.unstable_opts.chalk {
34+
Box::new(ChalkFulfillmentContext::new())
35+
} else {
36+
Box::new(FulfillmentContext::new_in_snapshot())
37+
}
38+
}
3039
}
3140

3241
/// Used if you want to have pleasant experience when dealing
@@ -41,6 +50,10 @@ impl<'a, 'tcx> ObligationCtxt<'a, 'tcx> {
4150
Self { infcx, engine: RefCell::new(<dyn TraitEngine<'_>>::new(infcx.tcx)) }
4251
}
4352

53+
pub fn new_in_snapshot(infcx: &'a InferCtxt<'a, 'tcx>) -> Self {
54+
Self { infcx, engine: RefCell::new(<dyn TraitEngine<'_>>::new_in_snapshot(infcx.tcx)) }
55+
}
56+
4457
pub fn register_obligation(&self, obligation: PredicateObligation<'tcx>) {
4558
self.engine.borrow_mut().register_predicate_obligation(self.infcx, obligation);
4659
}

compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs

+17-12
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use rustc_hir::def_id::DefId;
2020
use rustc_hir::intravisit::Visitor;
2121
use rustc_hir::lang_items::LangItem;
2222
use rustc_hir::{AsyncGeneratorKind, GeneratorKind, Node};
23-
use rustc_infer::infer::TyCtxtInferExt;
23+
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
2424
use rustc_middle::hir::map;
2525
use rustc_middle::ty::{
2626
self, suggest_arbitrary_trait_bound, suggest_constraining_type_param, AdtKind, DefIdTree,
@@ -1589,32 +1589,38 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
15891589
expected: ty::PolyTraitRef<'tcx>,
15901590
) -> DiagnosticBuilder<'tcx, ErrorGuaranteed> {
15911591
pub(crate) fn build_fn_sig_ty<'tcx>(
1592-
tcx: TyCtxt<'tcx>,
1592+
infcx: &InferCtxt<'_, 'tcx>,
15931593
trait_ref: ty::PolyTraitRef<'tcx>,
15941594
) -> Ty<'tcx> {
15951595
let inputs = trait_ref.skip_binder().substs.type_at(1);
15961596
let sig = match inputs.kind() {
15971597
ty::Tuple(inputs)
1598-
if tcx.fn_trait_kind_from_lang_item(trait_ref.def_id()).is_some() =>
1598+
if infcx.tcx.fn_trait_kind_from_lang_item(trait_ref.def_id()).is_some() =>
15991599
{
1600-
tcx.mk_fn_sig(
1600+
infcx.tcx.mk_fn_sig(
16011601
inputs.iter(),
1602-
tcx.mk_ty_infer(ty::TyVar(ty::TyVid::from_u32(0))),
1602+
infcx.next_ty_var(TypeVariableOrigin {
1603+
span: DUMMY_SP,
1604+
kind: TypeVariableOriginKind::MiscVariable,
1605+
}),
16031606
false,
16041607
hir::Unsafety::Normal,
16051608
abi::Abi::Rust,
16061609
)
16071610
}
1608-
_ => tcx.mk_fn_sig(
1611+
_ => infcx.tcx.mk_fn_sig(
16091612
std::iter::once(inputs),
1610-
tcx.mk_ty_infer(ty::TyVar(ty::TyVid::from_u32(0))),
1613+
infcx.next_ty_var(TypeVariableOrigin {
1614+
span: DUMMY_SP,
1615+
kind: TypeVariableOriginKind::MiscVariable,
1616+
}),
16111617
false,
16121618
hir::Unsafety::Normal,
16131619
abi::Abi::Rust,
16141620
),
16151621
};
16161622

1617-
tcx.mk_fn_ptr(trait_ref.rebind(sig))
1623+
infcx.tcx.mk_fn_ptr(trait_ref.rebind(sig))
16181624
}
16191625

16201626
let argument_kind = match expected.skip_binder().self_ty().kind() {
@@ -1634,11 +1640,10 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
16341640
let found_span = found_span.unwrap_or(span);
16351641
err.span_label(found_span, "found signature defined here");
16361642

1637-
let expected = build_fn_sig_ty(self.tcx, expected);
1638-
let found = build_fn_sig_ty(self.tcx, found);
1643+
let expected = build_fn_sig_ty(self, expected);
1644+
let found = build_fn_sig_ty(self, found);
16391645

1640-
let (expected_str, found_str) =
1641-
self.tcx.infer_ctxt().enter(|infcx| infcx.cmp(expected, found));
1646+
let (expected_str, found_str) = self.cmp(expected, found);
16421647

16431648
let signature_kind = format!("{argument_kind} signature");
16441649
err.note_expected_found(&signature_kind, expected_str, &signature_kind, found_str);

compiler/rustc_typeck/src/check/inherited.rs

+27-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use super::callee::DeferredCallResolution;
22

33
use rustc_data_structures::fx::FxHashSet;
4+
use rustc_data_structures::sync::Lrc;
45
use rustc_hir as hir;
56
use rustc_hir::def_id::LocalDefId;
67
use rustc_hir::HirIdMap;
@@ -12,7 +13,9 @@ use rustc_middle::ty::{self, Ty, TyCtxt};
1213
use rustc_span::def_id::LocalDefIdMap;
1314
use rustc_span::{self, Span};
1415
use rustc_trait_selection::infer::InferCtxtExt as _;
15-
use rustc_trait_selection::traits::{self, ObligationCause, TraitEngine, TraitEngineExt};
16+
use rustc_trait_selection::traits::{
17+
self, ObligationCause, ObligationCtxt, TraitEngine, TraitEngineExt as _,
18+
};
1619

1720
use std::cell::RefCell;
1821
use std::ops::Deref;
@@ -84,7 +87,29 @@ impl<'tcx> Inherited<'_, 'tcx> {
8487
infcx: tcx
8588
.infer_ctxt()
8689
.ignoring_regions()
87-
.with_fresh_in_progress_typeck_results(hir_owner),
90+
.with_fresh_in_progress_typeck_results(hir_owner)
91+
.with_normalize_fn_sig_for_diagnostic(Lrc::new(move |infcx, fn_sig| {
92+
if fn_sig.has_escaping_bound_vars() {
93+
return fn_sig;
94+
}
95+
infcx.probe(|_| {
96+
let ocx = ObligationCtxt::new_in_snapshot(infcx);
97+
let normalized_fn_sig = ocx.normalize(
98+
ObligationCause::dummy(),
99+
// FIXME(compiler-errors): This is probably not the right param-env...
100+
infcx.tcx.param_env(def_id),
101+
fn_sig,
102+
);
103+
if ocx.select_all_or_error().is_empty() {
104+
let normalized_fn_sig =
105+
infcx.resolve_vars_if_possible(normalized_fn_sig);
106+
if !normalized_fn_sig.needs_infer() {
107+
return normalized_fn_sig;
108+
}
109+
}
110+
fn_sig
111+
})
112+
})),
88113
def_id,
89114
}
90115
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
trait Foo {
2+
type Bar;
3+
}
4+
5+
impl<T> Foo for T {
6+
type Bar = i32;
7+
}
8+
9+
fn foo<T>(_: <T as Foo>::Bar, _: &'static <T as Foo>::Bar) {}
10+
11+
fn needs_i32_ref_fn(_: fn(&'static i32, i32)) {}
12+
13+
fn main() {
14+
needs_i32_ref_fn(foo::<()>);
15+
//~^ ERROR mismatched types
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
error[E0308]: mismatched types
2+
--> $DIR/normalize-fn-sig.rs:14:22
3+
|
4+
LL | needs_i32_ref_fn(foo::<()>);
5+
| ---------------- ^^^^^^^^^ expected `&i32`, found `i32`
6+
| |
7+
| arguments to this function are incorrect
8+
|
9+
= note: expected fn pointer `fn(&'static i32, i32)`
10+
found fn item `fn(i32, &'static i32) {foo::<()>}`
11+
note: function defined here
12+
--> $DIR/normalize-fn-sig.rs:11:4
13+
|
14+
LL | fn needs_i32_ref_fn(_: fn(&'static i32, i32)) {}
15+
| ^^^^^^^^^^^^^^^^ ------------------------
16+
17+
error: aborting due to previous error
18+
19+
For more information about this error, try `rustc --explain E0308`.

0 commit comments

Comments
 (0)