diff --git a/compiler/rustc_middle/src/ty/fold.rs b/compiler/rustc_middle/src/ty/fold.rs index 4c7db4e803b8e..a6a1d1f73bb62 100644 --- a/compiler/rustc_middle/src/ty/fold.rs +++ b/compiler/rustc_middle/src/ty/fold.rs @@ -439,18 +439,18 @@ struct BoundVarReplacer<'a, 'tcx> { /// the ones we have visited. current_index: ty::DebruijnIndex, - fld_r: &'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a), - fld_t: &'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a), - fld_c: &'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> &'tcx ty::Const<'tcx> + 'a), + fld_r: Option<&'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a)>, + fld_t: Option<&'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a)>, + fld_c: Option<&'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> &'tcx ty::Const<'tcx> + 'a)>, } impl<'a, 'tcx> BoundVarReplacer<'a, 'tcx> { - fn new(tcx: TyCtxt<'tcx>, fld_r: &'a mut F, fld_t: &'a mut G, fld_c: &'a mut H) -> Self - where - F: FnMut(ty::BoundRegion) -> ty::Region<'tcx>, - G: FnMut(ty::BoundTy) -> Ty<'tcx>, - H: FnMut(ty::BoundVar, Ty<'tcx>) -> &'tcx ty::Const<'tcx>, - { + fn new( + tcx: TyCtxt<'tcx>, + fld_r: Option<&'a mut (dyn FnMut(ty::BoundRegion) -> ty::Region<'tcx> + 'a)>, + fld_t: Option<&'a mut (dyn FnMut(ty::BoundTy) -> Ty<'tcx> + 'a)>, + fld_c: Option<&'a mut (dyn FnMut(ty::BoundVar, Ty<'tcx>) -> &'tcx ty::Const<'tcx> + 'a)>, + ) -> Self { BoundVarReplacer { tcx, current_index: ty::INNERMOST, fld_r, fld_t, fld_c } } } @@ -469,63 +469,58 @@ impl<'a, 'tcx> TypeFolder<'tcx> for BoundVarReplacer<'a, 'tcx> { fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> { match *t.kind() { - ty::Bound(debruijn, bound_ty) => { - if debruijn == self.current_index { - let fld_t = &mut self.fld_t; + ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => { + if let Some(fld_t) = self.fld_t.as_mut() { let ty = fld_t(bound_ty); - ty::fold::shift_vars(self.tcx, &ty, self.current_index.as_u32()) - } else { - t + return ty::fold::shift_vars(self.tcx, &ty, self.current_index.as_u32()); } } - _ => { - if !t.has_vars_bound_at_or_above(self.current_index) { - // Nothing more to substitute. - t - } else { - t.super_fold_with(self) - } + _ if t.has_vars_bound_at_or_above(self.current_index) => { + return t.super_fold_with(self); } + _ => {} } + t } fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> { match *r { ty::ReLateBound(debruijn, br) if debruijn == self.current_index => { - let fld_r = &mut self.fld_r; - let region = fld_r(br); - if let ty::ReLateBound(debruijn1, br) = *region { - // If the callback returns a late-bound region, - // that region should always use the INNERMOST - // debruijn index. Then we adjust it to the - // correct depth. - assert_eq!(debruijn1, ty::INNERMOST); - self.tcx.mk_region(ty::ReLateBound(debruijn, br)) - } else { - region + if let Some(fld_r) = self.fld_r.as_mut() { + let region = fld_r(br); + return if let ty::ReLateBound(debruijn1, br) = *region { + // If the callback returns a late-bound region, + // that region should always use the INNERMOST + // debruijn index. Then we adjust it to the + // correct depth. + assert_eq!(debruijn1, ty::INNERMOST); + self.tcx.mk_region(ty::ReLateBound(debruijn, br)) + } else { + region + }; } } - _ => r, + _ => {} } + r } fn fold_const(&mut self, ct: &'tcx ty::Const<'tcx>) -> &'tcx ty::Const<'tcx> { - if let ty::Const { val: ty::ConstKind::Bound(debruijn, bound_const), ty } = *ct { - if debruijn == self.current_index { - let fld_c = &mut self.fld_c; - let ct = fld_c(bound_const, ty); - ty::fold::shift_vars(self.tcx, &ct, self.current_index.as_u32()) - } else { - ct + match *ct { + ty::Const { val: ty::ConstKind::Bound(debruijn, bound_const), ty } + if debruijn == self.current_index => + { + if let Some(fld_c) = self.fld_c.as_mut() { + let ct = fld_c(bound_const, ty); + return ty::fold::shift_vars(self.tcx, &ct, self.current_index.as_u32()); + } } - } else { - if !ct.has_vars_bound_at_or_above(self.current_index) { - // Nothing more to substitute. - ct - } else { - ct.super_fold_with(self) + _ if ct.has_vars_bound_at_or_above(self.current_index) => { + return ct.super_fold_with(self); } + _ => {} } + ct } } @@ -550,14 +545,16 @@ impl<'tcx> TyCtxt<'tcx> { F: FnMut(ty::BoundRegion) -> ty::Region<'tcx>, T: TypeFoldable<'tcx>, { - // identity for bound types and consts - let fld_t = |bound_ty| self.mk_ty(ty::Bound(ty::INNERMOST, bound_ty)); - let fld_c = |bound_ct, ty| { - self.mk_const(ty::Const { val: ty::ConstKind::Bound(ty::INNERMOST, bound_ct), ty }) - }; let mut region_map = BTreeMap::new(); - let real_fld_r = |br: ty::BoundRegion| *region_map.entry(br).or_insert_with(|| fld_r(br)); - let value = self.replace_escaping_bound_vars(value.skip_binder(), real_fld_r, fld_t, fld_c); + let mut real_fld_r = + |br: ty::BoundRegion| *region_map.entry(br).or_insert_with(|| fld_r(br)); + let value = value.skip_binder(); + let value = if !value.has_escaping_bound_vars() { + value + } else { + let mut replacer = BoundVarReplacer::new(self, Some(&mut real_fld_r), None, None); + value.fold_with(&mut replacer) + }; (value, region_map) } @@ -580,7 +577,8 @@ impl<'tcx> TyCtxt<'tcx> { if !value.has_escaping_bound_vars() { value } else { - let mut replacer = BoundVarReplacer::new(self, &mut fld_r, &mut fld_t, &mut fld_c); + let mut replacer = + BoundVarReplacer::new(self, Some(&mut fld_r), Some(&mut fld_t), Some(&mut fld_c)); value.fold_with(&mut replacer) } }