From 2b4a2b95dd1a6b85678537332045d122afda682f Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Tue, 30 Jan 2024 14:22:05 +0000
Subject: [PATCH] Check normalized call signature for WF in mir typeck

---
 compiler/rustc_borrowck/src/type_check/mod.rs | 24 ++++++++--
 tests/ui/nll/check-normalized-sig-for-wf.rs   | 27 +++++++++++
 .../ui/nll/check-normalized-sig-for-wf.stderr | 47 +++++++++++++++++++
 3 files changed, 94 insertions(+), 4 deletions(-)
 create mode 100644 tests/ui/nll/check-normalized-sig-for-wf.rs
 create mode 100644 tests/ui/nll/check-normalized-sig-for-wf.stderr

diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs
index cfb46f3ac8a96..ae4000f02a7d6 100644
--- a/compiler/rustc_borrowck/src/type_check/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/mod.rs
@@ -1432,7 +1432,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                         return;
                     }
                 };
-                let (sig, map) = tcx.instantiate_bound_regions(sig, |br| {
+                let (unnormalized_sig, map) = tcx.instantiate_bound_regions(sig, |br| {
                     use crate::renumber::RegionCtxt;
 
                     let region_ctxt_fn = || {
@@ -1454,7 +1454,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                         region_ctxt_fn,
                     )
                 });
-                debug!(?sig);
+                debug!(?unnormalized_sig);
                 // IMPORTANT: We have to prove well formed for the function signature before
                 // we normalize it, as otherwise types like `<&'a &'b () as Trait>::Assoc`
                 // get normalized away, causing us to ignore the `'b: 'a` bound used by the function.
@@ -1464,7 +1464,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                 //
                 // See #91068 for an example.
                 self.prove_predicates(
-                    sig.inputs_and_output.iter().map(|ty| {
+                    unnormalized_sig.inputs_and_output.iter().map(|ty| {
                         ty::Binder::dummy(ty::PredicateKind::Clause(ty::ClauseKind::WellFormed(
                             ty.into(),
                         )))
@@ -1472,7 +1472,23 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                     term_location.to_locations(),
                     ConstraintCategory::Boring,
                 );
-                let sig = self.normalize(sig, term_location);
+
+                let sig = self.normalize(unnormalized_sig, term_location);
+                // HACK(#114936): `WF(sig)` does not imply `WF(normalized(sig))`
+                // with built-in `Fn` implementations, since the impl may not be
+                // well-formed itself.
+                if sig != unnormalized_sig {
+                    self.prove_predicates(
+                        sig.inputs_and_output.iter().map(|ty| {
+                            ty::Binder::dummy(ty::PredicateKind::Clause(
+                                ty::ClauseKind::WellFormed(ty.into()),
+                            ))
+                        }),
+                        term_location.to_locations(),
+                        ConstraintCategory::Boring,
+                    );
+                }
+
                 self.check_call_dest(body, term, &sig, *destination, *target, term_location);
 
                 // The ordinary liveness rules will ensure that all
diff --git a/tests/ui/nll/check-normalized-sig-for-wf.rs b/tests/ui/nll/check-normalized-sig-for-wf.rs
new file mode 100644
index 0000000000000..cb0f34ce02f7a
--- /dev/null
+++ b/tests/ui/nll/check-normalized-sig-for-wf.rs
@@ -0,0 +1,27 @@
+// <https://github.com/rust-lang/rust/issues/114936>
+fn whoops(
+    s: String,
+    f: impl for<'s> FnOnce(&'s str) -> (&'static str, [&'static &'s (); 0]),
+) -> &'static str
+{
+    f(&s).0
+    //~^ ERROR `s` does not live long enough
+}
+
+// <https://github.com/rust-lang/rust/issues/118876>
+fn extend<T>(input: &T) -> &'static T {
+    struct Bounded<'a, 'b: 'static, T>(&'a T, [&'b (); 0]);
+    let n: Box<dyn FnOnce(&T) -> Bounded<'static, '_, T>> = Box::new(|x| Bounded(x, []));
+    n(input).0
+    //~^ ERROR borrowed data escapes outside of function
+}
+
+// <https://github.com/rust-lang/rust/issues/118876>
+fn extend_mut<'a, T>(input: &'a mut T) -> &'static mut T {
+    struct Bounded<'a, 'b: 'static, T>(&'a mut T, [&'b (); 0]);
+    let mut n: Box<dyn FnMut(&mut T) -> Bounded<'static, '_, T>> = Box::new(|x| Bounded(x, []));
+    n(input).0
+    //~^ ERROR borrowed data escapes outside of function
+}
+
+fn main() {}
diff --git a/tests/ui/nll/check-normalized-sig-for-wf.stderr b/tests/ui/nll/check-normalized-sig-for-wf.stderr
new file mode 100644
index 0000000000000..5c96b0c6561a1
--- /dev/null
+++ b/tests/ui/nll/check-normalized-sig-for-wf.stderr
@@ -0,0 +1,47 @@
+error[E0597]: `s` does not live long enough
+  --> $DIR/check-normalized-sig-for-wf.rs:7:7
+   |
+LL |     s: String,
+   |     - binding `s` declared here
+...
+LL |     f(&s).0
+   |     --^^-
+   |     | |
+   |     | borrowed value does not live long enough
+   |     argument requires that `s` is borrowed for `'static`
+LL |
+LL | }
+   | - `s` dropped here while still borrowed
+
+error[E0521]: borrowed data escapes outside of function
+  --> $DIR/check-normalized-sig-for-wf.rs:15:5
+   |
+LL | fn extend<T>(input: &T) -> &'static T {
+   |              -----  - let's call the lifetime of this reference `'1`
+   |              |
+   |              `input` is a reference that is only valid in the function body
+...
+LL |     n(input).0
+   |     ^^^^^^^^
+   |     |
+   |     `input` escapes the function body here
+   |     argument requires that `'1` must outlive `'static`
+
+error[E0521]: borrowed data escapes outside of function
+  --> $DIR/check-normalized-sig-for-wf.rs:23:5
+   |
+LL | fn extend_mut<'a, T>(input: &'a mut T) -> &'static mut T {
+   |               --     ----- `input` is a reference that is only valid in the function body
+   |               |
+   |               lifetime `'a` defined here
+...
+LL |     n(input).0
+   |     ^^^^^^^^
+   |     |
+   |     `input` escapes the function body here
+   |     argument requires that `'a` must outlive `'static`
+
+error: aborting due to 3 previous errors
+
+Some errors have detailed explanations: E0521, E0597.
+For more information about an error, try `rustc --explain E0521`.