Skip to content

Eliminate ObligationCauseData #91844

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions compiler/rustc_infer/src/infer/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
exp_found: Option<ty::error::ExpectedFound<Ty<'tcx>>>,
terr: &TypeError<'tcx>,
) {
match cause.code {
match *cause.code() {
ObligationCauseCode::Pattern { origin_expr: true, span: Some(span), root_ty } => {
let ty = self.resolve_vars_if_possible(root_ty);
if ty.is_suggestable() {
Expand Down Expand Up @@ -781,7 +781,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
}
_ => {
if let ObligationCauseCode::BindingObligation(_, binding_span) =
cause.code.peel_derives()
cause.code().peel_derives()
{
if matches!(terr, TypeError::RegionsPlaceholderMismatch) {
err.span_note(*binding_span, "the lifetime requirement is introduced here");
Expand Down Expand Up @@ -1729,10 +1729,10 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
}
_ => exp_found,
};
debug!("exp_found {:?} terr {:?} cause.code {:?}", exp_found, terr, cause.code);
debug!("exp_found {:?} terr {:?} cause.code {:?}", exp_found, terr, cause.code());
if let Some(exp_found) = exp_found {
let should_suggest_fixes = if let ObligationCauseCode::Pattern { root_ty, .. } =
&cause.code
cause.code()
{
// Skip if the root_ty of the pattern is not the same as the expected_ty.
// If these types aren't equal then we've probably peeled off a layer of arrays.
Expand Down Expand Up @@ -1827,15 +1827,15 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
exp_span, exp_found.expected, exp_found.found,
);

if let ObligationCauseCode::CompareImplMethodObligation { .. } = &cause.code {
if let ObligationCauseCode::CompareImplMethodObligation { .. } = cause.code() {
return;
}

match (
self.get_impl_future_output_ty(exp_found.expected),
self.get_impl_future_output_ty(exp_found.found),
) {
(Some(exp), Some(found)) if same_type_modulo_infer(exp, found) => match &cause.code {
(Some(exp), Some(found)) if same_type_modulo_infer(exp, found) => match cause.code() {
ObligationCauseCode::IfExpression(box IfExpressionCause { then, .. }) => {
diag.multipart_suggestion(
"consider `await`ing on both `Future`s",
Expand Down Expand Up @@ -1875,7 +1875,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
Applicability::MaybeIncorrect,
);
}
(Some(ty), _) if same_type_modulo_infer(ty, exp_found.found) => match cause.code {
(Some(ty), _) if same_type_modulo_infer(ty, exp_found.found) => match cause.code() {
ObligationCauseCode::Pattern { span: Some(span), .. }
| ObligationCauseCode::IfExpression(box IfExpressionCause { then: span, .. }) => {
diag.span_suggestion_verbose(
Expand Down Expand Up @@ -1927,7 +1927,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
.map(|field| (field.ident.name, field.ty(self.tcx, expected_substs)))
.find(|(_, ty)| same_type_modulo_infer(ty, exp_found.found))
{
if let ObligationCauseCode::Pattern { span: Some(span), .. } = cause.code {
if let ObligationCauseCode::Pattern { span: Some(span), .. } = *cause.code() {
if let Ok(snippet) = self.tcx.sess.source_map().span_to_snippet(span) {
let suggestion = if expected_def.is_struct() {
format!("{}.{}", snippet, name)
Expand Down Expand Up @@ -2064,7 +2064,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
}
}
if let MatchExpressionArm(box MatchExpressionArmCause { source, .. }) =
trace.cause.code
*trace.cause.code()
{
if let hir::MatchSource::TryDesugar = source {
if let Some((expected_ty, found_ty)) = self.values_str(trace.values) {
Expand Down Expand Up @@ -2659,7 +2659,7 @@ impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> {
fn as_failure_code(&self, terr: &TypeError<'tcx>) -> FailureCode {
use self::FailureCode::*;
use crate::traits::ObligationCauseCode::*;
match self.code {
match self.code() {
CompareImplMethodObligation { .. } => Error0308("method not compatible with trait"),
CompareImplTypeObligation { .. } => Error0308("type not compatible with trait"),
MatchExpressionArm(box MatchExpressionArmCause { source, .. }) => {
Expand Down Expand Up @@ -2694,7 +2694,7 @@ impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> {

fn as_requirement_str(&self) -> &'static str {
use crate::traits::ObligationCauseCode::*;
match self.code {
match self.code() {
CompareImplMethodObligation { .. } => "method type is compatible with trait",
CompareImplTypeObligation { .. } => "associated type is compatible with trait",
ExprAssignable => "expression is assignable",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
};
// If we added a "points at argument expression" obligation, we remove it here, we care
// about the original obligation only.
let code = match &cause.code {
let code = match cause.code() {
ObligationCauseCode::FunctionArgumentObligation { parent_code, .. } => &*parent_code,
_ => &cause.code,
_ => cause.code(),
};
let (parent, impl_def_id) = match code {
ObligationCauseCode::MatchImpl(parent, impl_def_id) => (parent, impl_def_id),
_ => return None,
};
let binding_span = match parent.code {
let binding_span = match *parent.code() {
ObligationCauseCode::BindingObligation(_def_id, binding_span) => binding_span,
_ => return None,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ impl<'tcx> NiceRegionError<'_, 'tcx> {
);
let mut err = self.tcx().sess.struct_span_err(span, &msg);

let leading_ellipsis = if let ObligationCauseCode::ItemObligation(def_id) = cause.code {
let leading_ellipsis = if let ObligationCauseCode::ItemObligation(def_id) = *cause.code() {
err.span_label(span, "doesn't satisfy where-clause");
err.span_label(
self.tcx().def_span(def_id),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
sup_r,
) if **sub_r == RegionKind::ReStatic => {
// This is for an implicit `'static` requirement coming from `impl dyn Trait {}`.
if let ObligationCauseCode::UnifyReceiver(ctxt) = &cause.code {
if let ObligationCauseCode::UnifyReceiver(ctxt) = cause.code() {
// This may have a closure and it would cause ICE
// through `find_param_with_region` (#78262).
let anon_reg_sup = tcx.is_suitable_region(sup_r)?;
Expand Down Expand Up @@ -184,7 +184,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
}
if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = sub_origin {
if let ObligationCauseCode::ReturnValue(hir_id)
| ObligationCauseCode::BlockTailExpression(hir_id) = &cause.code
| ObligationCauseCode::BlockTailExpression(hir_id) = cause.code()
{
let parent_id = tcx.hir().get_parent_item(*hir_id);
if let Some(fn_decl) = tcx.hir().fn_decl_by_hir_id(parent_id) {
Expand Down Expand Up @@ -226,7 +226,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {

let mut override_error_code = None;
if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = &sup_origin {
if let ObligationCauseCode::UnifyReceiver(ctxt) = &cause.code {
if let ObligationCauseCode::UnifyReceiver(ctxt) = cause.code() {
// Handle case of `impl Foo for dyn Bar { fn qux(&self) {} }` introducing a
// `'static` lifetime when called as a method on a binding: `bar.qux()`.
if self.find_impl_on_dyn_trait(&mut err, param.param_ty, &ctxt) {
Expand All @@ -235,9 +235,9 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
}
}
if let SubregionOrigin::Subtype(box TypeTrace { cause, .. }) = &sub_origin {
let code = match &cause.code {
ObligationCauseCode::MatchImpl(parent, ..) => &parent.code,
_ => &cause.code,
let code = match cause.code() {
ObligationCauseCode::MatchImpl(parent, ..) => parent.code(),
_ => cause.code(),
};
if let (ObligationCauseCode::ItemObligation(item_def_id), None) =
(code, override_error_code)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<'a, 'tcx> NiceRegionError<'a, 'tcx> {
ValuePairs::Types(sub_expected_found),
ValuePairs::Types(sup_expected_found),
CompareImplMethodObligation { trait_item_def_id, .. },
) = (&sub_trace.values, &sup_trace.values, &sub_trace.cause.code)
) = (&sub_trace.values, &sup_trace.values, sub_trace.cause.code())
{
if sup_expected_found == sub_expected_found {
self.emit_err(
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_infer/src/infer/error_reporting/note.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,13 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
match placeholder_origin {
infer::Subtype(box ref trace)
if matches!(
&trace.cause.code.peel_derives(),
&trace.cause.code().peel_derives(),
ObligationCauseCode::BindingObligation(..)
) =>
{
// Hack to get around the borrow checker because trace.cause has an `Rc`.
if let ObligationCauseCode::BindingObligation(_, span) =
&trace.cause.code.peel_derives()
&trace.cause.code().peel_derives()
{
let span = *span;
let mut err = self.report_concrete_failure(placeholder_origin, sub, sup);
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1824,7 +1824,7 @@ impl<'tcx> SubregionOrigin<'tcx> {
where
F: FnOnce() -> Self,
{
match cause.code {
match *cause.code() {
traits::ObligationCauseCode::ReferenceOutlivesReferent(ref_type) => {
SubregionOrigin::ReferenceOutlivesReferent(ref_type, cause.span)
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/infer/outlives/obligations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<'cx, 'tcx> InferCtxt<'cx, 'tcx> {
infer::RelateParamBound(
cause.span,
sup_type,
match cause.code.peel_derives() {
match cause.code().peel_derives() {
ObligationCauseCode::BindingObligation(_, span) => Some(*span),
_ => None,
},
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl TraitObligation<'_> {

// `PredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
#[cfg(all(target_arch = "x86_64", target_pointer_width = "64"))]
static_assert_size!(PredicateObligation<'_>, 32);
static_assert_size!(PredicateObligation<'_>, 48);

pub type PredicateObligations<'tcx> = Vec<PredicateObligation<'tcx>>;

Expand Down
82 changes: 40 additions & 42 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ use rustc_span::{Span, DUMMY_SP};
use smallvec::SmallVec;

use std::borrow::Cow;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::ops::Deref;

pub use self::select::{EvaluationCache, EvaluationResult, OverflowError, SelectionCache};

Expand Down Expand Up @@ -80,38 +78,14 @@ pub enum Reveal {

/// The reason why we incurred this obligation; used for error reporting.
///
/// As the happy path does not care about this struct, storing this on the heap
/// ends up increasing performance.
/// Non-misc `ObligationCauseCode`s are stored on the heap. This gives the
/// best trade-off between keeping the type small (which makes copies cheaper)
/// while not doing too many heap allocations.
///
/// We do not want to intern this as there are a lot of obligation causes which
/// only live for a short period of time.
#[derive(Clone, PartialEq, Eq, Hash, Lift)]
pub struct ObligationCause<'tcx> {
/// `None` for `ObligationCause::dummy`, `Some` otherwise.
data: Option<Lrc<ObligationCauseData<'tcx>>>,
}

const DUMMY_OBLIGATION_CAUSE_DATA: ObligationCauseData<'static> =
ObligationCauseData { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: MiscObligation };

// Correctly format `ObligationCause::dummy`.
impl<'tcx> fmt::Debug for ObligationCause<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
ObligationCauseData::fmt(self, f)
}
}

impl<'tcx> Deref for ObligationCause<'tcx> {
type Target = ObligationCauseData<'tcx>;

#[inline(always)]
fn deref(&self) -> &Self::Target {
self.data.as_deref().unwrap_or(&DUMMY_OBLIGATION_CAUSE_DATA)
}
}

#[derive(Clone, Debug, PartialEq, Eq, Lift)]
pub struct ObligationCauseData<'tcx> {
pub struct ObligationCause<'tcx> {
pub span: Span,

/// The ID of the fn body that triggered this obligation. This is
Expand All @@ -122,46 +96,58 @@ pub struct ObligationCauseData<'tcx> {
/// information.
pub body_id: hir::HirId,

pub code: ObligationCauseCode<'tcx>,
/// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of
/// the time). `Some` otherwise.
code: Option<Lrc<ObligationCauseCode<'tcx>>>,
}

impl Hash for ObligationCauseData<'_> {
// This custom hash function speeds up hashing for `Obligation` deduplication
// greatly by skipping the `code` field, which can be large and complex. That
// shouldn't affect hash quality much since there are several other fields in
// `Obligation` which should be unique enough, especially the predicate itself
// which is hashed as an interned pointer. See #90996.
impl Hash for ObligationCause<'_> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.body_id.hash(state);
self.span.hash(state);
std::mem::discriminant(&self.code).hash(state);
}
}

const MISC_OBLIGATION_CAUSE_CODE: ObligationCauseCode<'static> = MiscObligation;

impl<'tcx> ObligationCause<'tcx> {
#[inline]
pub fn new(
span: Span,
body_id: hir::HirId,
code: ObligationCauseCode<'tcx>,
) -> ObligationCause<'tcx> {
ObligationCause { data: Some(Lrc::new(ObligationCauseData { span, body_id, code })) }
ObligationCause {
span,
body_id,
code: if code == MISC_OBLIGATION_CAUSE_CODE { None } else { Some(Lrc::new(code)) },
}
}

pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> {
ObligationCause::new(span, body_id, MiscObligation)
}

pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
ObligationCause::new(span, hir::CRATE_HIR_ID, MiscObligation)
}

#[inline(always)]
pub fn dummy() -> ObligationCause<'tcx> {
ObligationCause { data: None }
ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: None }
}

pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: None }
}

pub fn make_mut(&mut self) -> &mut ObligationCauseData<'tcx> {
Lrc::make_mut(self.data.get_or_insert_with(|| Lrc::new(DUMMY_OBLIGATION_CAUSE_DATA)))
pub fn make_mut_code(&mut self) -> &mut ObligationCauseCode<'tcx> {
Lrc::make_mut(self.code.get_or_insert_with(|| Lrc::new(MISC_OBLIGATION_CAUSE_CODE)))
}

pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
match self.code {
match *self.code() {
ObligationCauseCode::CompareImplMethodObligation { .. }
| ObligationCauseCode::MainFunctionType
| ObligationCauseCode::StartFunctionType => {
Expand All @@ -174,6 +160,18 @@ impl<'tcx> ObligationCause<'tcx> {
_ => self.span,
}
}

#[inline]
pub fn code(&self) -> &ObligationCauseCode<'tcx> {
self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
}

pub fn clone_code(&self) -> Lrc<ObligationCauseCode<'tcx>> {
match &self.code {
Some(code) => code.clone(),
None => Lrc::new(MISC_OBLIGATION_CAUSE_CODE),
}
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ impl<T> Trait<T> for X {
proj_ty,
values,
body_owner_def_id,
&cause.code,
cause.code(),
);
}
(_, ty::Projection(proj_ty)) => {
Expand Down
Loading