Skip to content

Special-case deriving PartialOrd for enums with dataless variants #103659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 73 additions & 9 deletions compiler/rustc_builtin_macros/src/deriving/cmp/partial_ord.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::deriving::generic::ty::*;
use crate::deriving::generic::*;
use crate::deriving::{path_std, pathvec_std};
use rustc_ast::MetaItem;
use rustc_ast::{ExprKind, ItemKind, MetaItem, PatKind};
use rustc_expand::base::{Annotatable, ExtCtxt};
use rustc_span::symbol::{sym, Ident};
use rustc_span::Span;
Expand All @@ -21,6 +21,27 @@ pub fn expand_deriving_partial_ord(

let attrs = thin_vec![cx.attr_word(sym::inline, span)];

// Order in which to perform matching
let tag_then_data = if let Annotatable::Item(item) = item
&& let ItemKind::Enum(def, _) = &item.kind {
let dataful: Vec<bool> = def.variants.iter().map(|v| !v.data.fields().is_empty()).collect();
match dataful.iter().filter(|&&b| b).count() {
// No data, placing the tag check first makes codegen simpler
0 => true,
1..=2 => false,
_ => {
(0..dataful.len()-1).any(|i| {
if dataful[i] && let Some(idx) = dataful[i+1..].iter().position(|v| *v) {
idx >= 2
} else {
false
}
})
}
}
} else {
true
};
let partial_cmp_def = MethodDef {
name: sym::partial_cmp,
generics: Bounds::empty(),
Expand All @@ -30,7 +51,7 @@ pub fn expand_deriving_partial_ord(
attributes: attrs,
unify_fieldless_variants: true,
combine_substructure: combine_substructure(Box::new(|cx, span, substr| {
cs_partial_cmp(cx, span, substr)
cs_partial_cmp(cx, span, substr, tag_then_data)
})),
};

Expand All @@ -47,7 +68,12 @@ pub fn expand_deriving_partial_ord(
trait_def.expand(cx, mitem, item, push)
}

pub fn cs_partial_cmp(cx: &mut ExtCtxt<'_>, span: Span, substr: &Substructure<'_>) -> BlockOrExpr {
fn cs_partial_cmp(
cx: &mut ExtCtxt<'_>,
span: Span,
substr: &Substructure<'_>,
tag_then_data: bool,
) -> BlockOrExpr {
let test_id = Ident::new(sym::cmp, span);
let equal_path = cx.path_global(span, cx.std_path(&[sym::cmp, sym::Ordering, sym::Equal]));
let partial_cmp_path = cx.std_path(&[sym::cmp, sym::PartialOrd, sym::partial_cmp]);
Expand All @@ -74,12 +100,50 @@ pub fn cs_partial_cmp(cx: &mut ExtCtxt<'_>, span: Span, substr: &Substructure<'_
let args = vec![field.self_expr.clone(), other_expr.clone()];
cx.expr_call_global(field.span, partial_cmp_path.clone(), args)
}
CsFold::Combine(span, expr1, expr2) => {
let eq_arm =
cx.arm(span, cx.pat_some(span, cx.pat_path(span, equal_path.clone())), expr1);
let neq_arm =
cx.arm(span, cx.pat_ident(span, test_id), cx.expr_ident(span, test_id));
cx.expr_match(span, expr2, vec![eq_arm, neq_arm])
CsFold::Combine(span, mut expr1, expr2) => {
// When the item is an enum, this expands to
// ```
// match (expr2) {
// Some(Ordering::Equal) => expr1,
// cmp => cmp
// }
// ```
// where `expr2` is `partial_cmp(self_tag, other_tag)`, and `expr1` is a `match`
// against the enum variants. This means that we begin by comparing the enum tags,
// before either inspecting their contents (if they match), or returning
// the `cmp::Ordering` of comparing the enum tags.
// ```
// match partial_cmp(self_tag, other_tag) {
// Some(Ordering::Equal) => match (self, other) {
// (Self::A(self_0), Self::A(other_0)) => partial_cmp(self_0, other_0),
// (Self::B(self_0), Self::B(other_0)) => partial_cmp(self_0, other_0),
// _ => Some(Ordering::Equal)
// }
// cmp => cmp
// }
// ```
// If we have any certain enum layouts, flipping this results in better codegen
// ```
// match (self, other) {
// (Self::A(self_0), Self::A(other_0)) => partial_cmp(self_0, other_0),
// _ => partial_cmp(self_tag, other_tag)
// }
// ```
// Reference: https://github.com/rust-lang/rust/pull/103659#issuecomment-1328126354

if !tag_then_data
&& let ExprKind::Match(_, arms) = &mut expr1.kind
&& let Some(last) = arms.last_mut()
&& let PatKind::Wild = last.pat.kind {
last.body = expr2;
expr1
} else {
let eq_arm =
cx.arm(span, cx.pat_some(span, cx.pat_path(span, equal_path.clone())), expr1);
let neq_arm =
cx.arm(span, cx.pat_ident(span, test_id), cx.expr_ident(span, test_id));
cx.expr_match(span, expr2, vec![eq_arm, neq_arm])
}
}
CsFold::Fieldless => cx.expr_some(span, cx.expr_path(equal_path.clone())),
},
Expand Down
51 changes: 23 additions & 28 deletions tests/ui/deriving/deriving-all-codegen.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -888,23 +888,20 @@ impl ::core::cmp::PartialOrd for Mixed {
-> ::core::option::Option<::core::cmp::Ordering> {
let __self_tag = ::core::intrinsics::discriminant_value(self);
let __arg1_tag = ::core::intrinsics::discriminant_value(other);
match ::core::cmp::PartialOrd::partial_cmp(&__self_tag, &__arg1_tag) {
::core::option::Option::Some(::core::cmp::Ordering::Equal) =>
match (self, other) {
(Mixed::R(__self_0), Mixed::R(__arg1_0)) =>
::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
(Mixed::S { d1: __self_0, d2: __self_1 }, Mixed::S {
d1: __arg1_0, d2: __arg1_1 }) =>
match ::core::cmp::PartialOrd::partial_cmp(__self_0,
__arg1_0) {
::core::option::Option::Some(::core::cmp::Ordering::Equal)
=> ::core::cmp::PartialOrd::partial_cmp(__self_1, __arg1_1),
cmp => cmp,
},
_ =>
::core::option::Option::Some(::core::cmp::Ordering::Equal),
match (self, other) {
(Mixed::R(__self_0), Mixed::R(__arg1_0)) =>
::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
(Mixed::S { d1: __self_0, d2: __self_1 }, Mixed::S {
d1: __arg1_0, d2: __arg1_1 }) =>
match ::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0)
{
::core::option::Option::Some(::core::cmp::Ordering::Equal)
=> ::core::cmp::PartialOrd::partial_cmp(__self_1, __arg1_1),
cmp => cmp,
},
cmp => cmp,
_ =>
::core::cmp::PartialOrd::partial_cmp(&__self_tag,
&__arg1_tag),
}
}
}
Expand Down Expand Up @@ -1018,18 +1015,16 @@ impl ::core::cmp::PartialOrd for Fielded {
-> ::core::option::Option<::core::cmp::Ordering> {
let __self_tag = ::core::intrinsics::discriminant_value(self);
let __arg1_tag = ::core::intrinsics::discriminant_value(other);
match ::core::cmp::PartialOrd::partial_cmp(&__self_tag, &__arg1_tag) {
::core::option::Option::Some(::core::cmp::Ordering::Equal) =>
match (self, other) {
(Fielded::X(__self_0), Fielded::X(__arg1_0)) =>
::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
(Fielded::Y(__self_0), Fielded::Y(__arg1_0)) =>
::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
(Fielded::Z(__self_0), Fielded::Z(__arg1_0)) =>
::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
_ => unsafe { ::core::intrinsics::unreachable() }
},
cmp => cmp,
match (self, other) {
(Fielded::X(__self_0), Fielded::X(__arg1_0)) =>
::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
(Fielded::Y(__self_0), Fielded::Y(__arg1_0)) =>
::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
(Fielded::Z(__self_0), Fielded::Z(__arg1_0)) =>
::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
_ =>
::core::cmp::PartialOrd::partial_cmp(&__self_tag,
&__arg1_tag),
}
}
}
Expand Down