Skip to content

Split fallible infallible folding #772

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 222 additions & 66 deletions chalk-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@ extern crate proc_macro;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use quote::ToTokens;
use syn::{parse_quote, DeriveInput, GenericParam, Ident, TypeParamBound};
use syn::{parse_quote, DeriveInput, Ident, TypeParam, TypeParamBound};

use synstructure::decl_derive;

/// Checks whether a generic parameter has a `: HasInterner` bound
fn has_interner(param: &GenericParam) -> Option<&Ident> {
fn has_interner(param: &TypeParam) -> Option<&Ident> {
bounded_by_trait(param, "HasInterner")
}

/// Checks whether a generic parameter has a `: Interner` bound
fn is_interner(param: &GenericParam) -> Option<&Ident> {
fn is_interner(param: &TypeParam) -> Option<&Ident> {
bounded_by_trait(param, "Interner")
}

Expand All @@ -28,48 +28,44 @@ fn has_interner_attr(input: &DeriveInput) -> Option<TokenStream> {
)
}

fn bounded_by_trait<'p>(param: &'p GenericParam, name: &str) -> Option<&'p Ident> {
fn bounded_by_trait<'p>(param: &'p TypeParam, name: &str) -> Option<&'p Ident> {
let name = Some(String::from(name));
match param {
GenericParam::Type(ref t) => t.bounds.iter().find_map(|b| {
if let TypeParamBound::Trait(trait_bound) = b {
if trait_bound
.path
.segments
.last()
.map(|s| s.ident.to_string())
== name
{
return Some(&t.ident);
}
param.bounds.iter().find_map(|b| {
if let TypeParamBound::Trait(trait_bound) = b {
if trait_bound
.path
.segments
.last()
.map(|s| s.ident.to_string())
== name
{
return Some(&param.ident);
}
None
}),
_ => None,
}
}
None
})
}

fn get_generic_param(input: &DeriveInput) -> &GenericParam {
match input.generics.params.len() {
1 => {}
fn get_intern_param(input: &DeriveInput) -> Option<(DeriveKind, &Ident)> {
let mut params = input.generics.type_params().filter_map(|param| {
has_interner(param)
.map(|ident| (DeriveKind::FromHasInterner, ident))
.or_else(|| is_interner(param).map(|ident| (DeriveKind::FromInterner, ident)))
});

0 => panic!(
"deriving this trait requires a single type parameter or a `#[has_interner]` attr"
),
let param = params.next();
assert!(params.next().is_none(), "deriving this trait only works with at most one type parameter that implements HasInterner or Interner");

_ => panic!("deriving this trait only works with a single type parameter"),
};
&input.generics.params[0]
param
}

fn get_generic_param_name(input: &DeriveInput) -> Option<&Ident> {
match get_generic_param(input) {
GenericParam::Type(t) => Some(&t.ident),
_ => None,
}
fn get_intern_param_name(input: &DeriveInput) -> &Ident {
get_intern_param(input)
.expect("deriving this trait requires a parameter that implements HasInterner or Interner")
.1
}

fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
fn try_find_interner(s: &mut synstructure::Structure) -> Option<(TokenStream, DeriveKind)> {
let input = s.ast();

if let Some(arg) = has_interner_attr(input) {
Expand All @@ -79,35 +75,40 @@ fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
// struct S {
//
// }
return (arg, DeriveKind::FromHasInternerAttr);
return Some((arg, DeriveKind::FromHasInternerAttr));
}

let generic_param0 = get_generic_param(input);

if let Some(param) = has_interner(generic_param0) {
// HasInterner bound:
//
// Example:
//
// struct Binders<T: HasInterner> { }
s.add_impl_generic(parse_quote! { _I });

s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
s.add_where_predicate(
parse_quote! { #param: ::chalk_ir::interner::HasInterner<Interner = _I> },
);
get_intern_param(input).map(|generic_param0| match generic_param0 {
(DeriveKind::FromHasInterner, param) => {
// HasInterner bound:
//
// Example:
//
// struct Binders<T: HasInterner> { }
s.add_impl_generic(parse_quote! { _I });

s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
s.add_where_predicate(
parse_quote! { #param: ::chalk_ir::interner::HasInterner<Interner = _I> },
);

(quote! { _I }, DeriveKind::FromHasInterner)
}
(DeriveKind::FromInterner, i) => {
// Interner bound:
//
// Example:
//
// struct Foo<I: Interner> { }
(quote! { #i }, DeriveKind::FromInterner)
}
_ => unreachable!(),
})
}

(quote! { _I }, DeriveKind::FromHasInterner)
} else if let Some(i) = is_interner(generic_param0) {
// Interner bound:
//
// Example:
//
// struct Foo<I: Interner> { }
(quote! { #i }, DeriveKind::FromInterner)
} else {
panic!("deriving this trait requires a parameter that implements HasInterner or Interner",);
}
fn find_interner(s: &mut synstructure::Structure) -> (TokenStream, DeriveKind) {
try_find_interner(s)
.expect("deriving this trait requires a `#[has_interner]` attr or a parameter that implements HasInterner or Interner")
}

#[derive(Copy, Clone, PartialEq)]
Expand All @@ -117,6 +118,7 @@ enum DeriveKind {
FromInterner,
}

decl_derive!([FallibleTypeFolder, attributes(has_interner)] => derive_fallible_type_folder);
decl_derive!([HasInterner, attributes(has_interner)] => derive_has_interner);
decl_derive!([TypeVisitable, attributes(has_interner)] => derive_type_visitable);
decl_derive!([TypeSuperVisitable, attributes(has_interner)] => derive_type_super_visitable);
Expand Down Expand Up @@ -173,7 +175,7 @@ fn derive_any_type_visitable(
});

if kind == DeriveKind::FromHasInterner {
let param = get_generic_param_name(input).unwrap();
let param = get_intern_param_name(input);
s.add_where_predicate(parse_quote! { #param: ::chalk_ir::visit::TypeVisitable<#interner> });
}

Expand Down Expand Up @@ -269,29 +271,183 @@ fn derive_type_foldable(mut s: synstructure::Structure) -> TokenStream {
vi.construct(|_, index| {
let bind = &bindings[index];
quote! {
::chalk_ir::fold::TypeFoldable::fold_with(#bind, folder, outer_binder)?
::chalk_ir::fold::TypeFoldable::try_fold_with(#bind, folder, outer_binder)?
}
})
});

let input = s.ast();

if kind == DeriveKind::FromHasInterner {
let param = get_generic_param_name(input).unwrap();
let param = get_intern_param_name(input);
s.add_where_predicate(parse_quote! { #param: ::chalk_ir::fold::TypeFoldable<#interner> });
};

s.add_bounds(synstructure::AddBounds::None);
s.bound_impl(
quote!(::chalk_ir::fold::TypeFoldable<#interner>),
quote! {
fn fold_with<E>(
fn try_fold_with<E>(
self,
folder: &mut dyn ::chalk_ir::fold::TypeFolder < #interner, Error = E >,
folder: &mut dyn ::chalk_ir::fold::FallibleTypeFolder < #interner, Error = E >,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::std::result::Result<Self, E> {
Ok(match self { #body })
}
},
)
}

fn derive_fallible_type_folder(mut s: synstructure::Structure) -> TokenStream {
let interner = try_find_interner(&mut s).map_or_else(
|| {
s.add_impl_generic(parse_quote! { _I });
s.add_where_predicate(parse_quote! { _I: ::chalk_ir::interner::Interner });
quote! { _I }
},
|(interner, _)| interner,
);
s.underscore_const(true);
s.unbound_impl(
quote!(::chalk_ir::fold::FallibleTypeFolder<#interner>),
quote! {
type Error = ::core::convert::Infallible;

fn as_dyn(&mut self) -> &mut dyn ::chalk_ir::fold::FallibleTypeFolder<I, Error = Self::Error> {
self
}

fn try_fold_ty(
&mut self,
ty: ::chalk_ir::Ty<#interner>,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_ty(self, ty, outer_binder))
}

fn try_fold_lifetime(
&mut self,
lifetime: ::chalk_ir::Lifetime<#interner>,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_lifetime(self, lifetime, outer_binder))
}

fn try_fold_const(
&mut self,
constant: ::chalk_ir::Const<#interner>,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_const(self, constant, outer_binder))
}

fn try_fold_program_clause(
&mut self,
clause: ::chalk_ir::ProgramClause<#interner>,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::ProgramClause<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_program_clause(self, clause, outer_binder))
}

fn try_fold_goal(
&mut self,
goal: ::chalk_ir::Goal<#interner>,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Goal<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_goal(self, goal, outer_binder))
}

fn forbid_free_vars(&self) -> bool {
::chalk_ir::fold::TypeFolder::forbid_free_vars(self)
}

fn try_fold_free_var_ty(
&mut self,
bound_var: ::chalk_ir::BoundVar,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_var_ty(self, bound_var, outer_binder))
}

fn try_fold_free_var_lifetime(
&mut self,
bound_var: ::chalk_ir::BoundVar,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_var_lifetime(self, bound_var, outer_binder))
}

fn try_fold_free_var_const(
&mut self,
ty: ::chalk_ir::Ty<#interner>,
bound_var: ::chalk_ir::BoundVar,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_var_const(self, ty, bound_var, outer_binder))
}

fn forbid_free_placeholders(&self) -> bool {
::chalk_ir::fold::TypeFolder::forbid_free_placeholders(self)
}

fn try_fold_free_placeholder_ty(
&mut self,
universe: ::chalk_ir::PlaceholderIndex,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_placeholder_ty(self, universe, outer_binder))
}

fn try_fold_free_placeholder_lifetime(
&mut self,
universe: ::chalk_ir::PlaceholderIndex,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_placeholder_lifetime(self, universe, outer_binder))
}

fn try_fold_free_placeholder_const(
&mut self,
ty: ::chalk_ir::Ty<#interner>,
universe: ::chalk_ir::PlaceholderIndex,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_free_placeholder_const(self, ty, universe, outer_binder))
}

fn forbid_inference_vars(&self) -> bool {
::chalk_ir::fold::TypeFolder::forbid_inference_vars(self)
}

fn try_fold_inference_ty(
&mut self,
var: ::chalk_ir::InferenceVar,
kind: ::chalk_ir::TyVariableKind,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Ty<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_inference_ty(self, var, kind, outer_binder))
}

fn try_fold_inference_lifetime(
&mut self,
var: ::chalk_ir::InferenceVar,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Lifetime<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_inference_lifetime(self, var, outer_binder))
}

fn try_fold_inference_const(
&mut self,
ty: ::chalk_ir::Ty<#interner>,
var: ::chalk_ir::InferenceVar,
outer_binder: ::chalk_ir::DebruijnIndex,
) -> ::core::result::Result<::chalk_ir::Const<#interner>, Self::Error> {
::core::result::Result::Ok(::chalk_ir::fold::TypeFolder::fold_inference_const(self, ty, var, outer_binder))
}

fn interner(&self) -> #interner {
::chalk_ir::fold::TypeFolder::interner(self)
}
},
)
}
Loading