Skip to content

Commit 3d76910

Browse files
committed
feat: trait_variant::make supports rewriting of the original trait.
1 parent f1e171e commit 3d76910

File tree

3 files changed

+87
-41
lines changed

3 files changed

+87
-41
lines changed

trait-variant/examples/variant.rs

+14
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,18 @@ where
4343
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
4444
}
4545

46+
#[trait_variant::make(Send + Sync)]
47+
pub trait GenericTraitWithBounds<'x, S: Sync, Y, const X: usize>
48+
where
49+
Y: Sync,
50+
{
51+
const CONST: usize = 3;
52+
type F;
53+
type A<const ANOTHER_CONST: u8>;
54+
type B<T: Display>: FromIterator<T>;
55+
56+
async fn take(&self, s: S);
57+
fn build<T: Display>(&self, items: impl Iterator<Item = T>) -> Self::B<T>;
58+
}
59+
4660
fn main() {}

trait-variant/src/lib.rs

+16-3
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ mod variant;
1414
/// fn` and/or `-> impl Trait` return types.
1515
///
1616
/// ```
17-
/// #[trait_variant::make(IntFactory: Send)]
18-
/// trait LocalIntFactory {
17+
/// #[trait_variant::make(Send)]
18+
/// trait IntFactory {
1919
/// async fn make(&self) -> i32;
2020
/// fn stream(&self) -> impl Iterator<Item = i32>;
2121
/// fn call(&self) -> u32;
2222
/// }
2323
/// ```
2424
///
25-
/// The above example causes a second trait called `IntFactory` to be created:
25+
/// The above example causes the trait to be rewritten as:
2626
///
2727
/// ```
2828
/// # use core::future::Future;
@@ -35,6 +35,19 @@ mod variant;
3535
///
3636
/// Note that ordinary methods such as `call` are not affected.
3737
///
38+
/// If you want to preserve an original trait untouched, `make` can be used to create a new trait with bounds on `async
39+
/// fn` and/or `-> impl Trait` return types.
40+
///
41+
/// ```
42+
/// #[trait_variant::make(IntFactory: Send)]
43+
/// trait LocalIntFactory {
44+
/// async fn make(&self) -> i32;
45+
/// fn stream(&self) -> impl Iterator<Item = i32>;
46+
/// fn call(&self) -> u32;
47+
/// }
48+
/// ```
49+
///
50+
/// The example causes a second trait called `IntFactory` to be created.
3851
/// Implementers of the trait can choose to implement the variant instead of the
3952
/// original trait. The macro creates a blanket impl which ensures that any type
4053
/// which implements the variant also implements the original trait.

trait-variant/src/variant.rs

+57-38
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,32 @@ impl Parse for Attrs {
3232
}
3333
}
3434

35-
struct MakeVariant {
36-
name: Ident,
37-
#[allow(unused)]
38-
colon: Token![:],
39-
bounds: Punctuated<TraitBound, Plus>,
35+
enum MakeVariant {
36+
Create {
37+
name: Ident,
38+
#[allow(dead_code)]
39+
colon: Token![:],
40+
bounds: Punctuated<TraitBound, Plus>,
41+
},
42+
Rewrite {
43+
bounds: Punctuated<TraitBound, Plus>,
44+
},
4045
}
4146

4247
impl Parse for MakeVariant {
4348
fn parse(input: ParseStream) -> Result<Self> {
44-
Ok(Self {
45-
name: input.parse()?,
46-
colon: input.parse()?,
47-
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
48-
})
49+
let variant = if input.peek(Ident) && input.peek2(Token![:]) {
50+
MakeVariant::Create {
51+
name: input.parse()?,
52+
colon: input.parse()?,
53+
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
54+
}
55+
} else {
56+
MakeVariant::Rewrite {
57+
bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
58+
}
59+
};
60+
Ok(variant)
4961
}
5062
}
5163

@@ -56,43 +68,51 @@ pub fn make(
5668
let attrs = parse_macro_input!(attr as Attrs);
5769
let item = parse_macro_input!(item as ItemTrait);
5870

59-
let maybe_allow_async_lint = if attrs
60-
.variant
61-
.bounds
62-
.iter()
63-
.any(|b| b.path.segments.last().unwrap().ident == "Send")
64-
{
65-
quote! { #[allow(async_fn_in_trait)] }
66-
} else {
67-
quote! {}
68-
};
71+
match attrs.variant {
72+
MakeVariant::Create { name, bounds, .. } => {
73+
let maybe_allow_async_lint = if bounds
74+
.iter()
75+
.any(|b| b.path.segments.last().unwrap().ident == "Send")
76+
{
77+
quote! { #[allow(async_fn_in_trait)] }
78+
} else {
79+
quote! {}
80+
};
6981

70-
let variant = mk_variant(&attrs, &item);
71-
let blanket_impl = mk_blanket_impl(&attrs, &item);
82+
let variant = mk_variant(&name, bounds, &item);
83+
let blanket_impl = mk_blanket_impl(&name, &item);
7284

73-
quote! {
74-
#maybe_allow_async_lint
75-
#item
85+
quote! {
86+
#maybe_allow_async_lint
87+
#item
7688

77-
#variant
89+
#variant
7890

79-
#blanket_impl
91+
#blanket_impl
92+
}
93+
.into()
94+
}
95+
MakeVariant::Rewrite { bounds, .. } => {
96+
let variant = mk_variant(&item.ident, bounds, &item);
97+
quote! {
98+
#variant
99+
}
100+
.into()
101+
}
80102
}
81-
.into()
82103
}
83104

84-
fn mk_variant(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
85-
let MakeVariant {
86-
ref name,
87-
colon: _,
88-
ref bounds,
89-
} = attrs.variant;
90-
let bounds: Vec<_> = bounds
105+
fn mk_variant(
106+
variant: &Ident,
107+
with_bounds: Punctuated<TraitBound, Plus>,
108+
tr: &ItemTrait,
109+
) -> TokenStream {
110+
let bounds: Vec<_> = with_bounds
91111
.into_iter()
92112
.map(|b| TypeParamBound::Trait(b.clone()))
93113
.collect();
94114
let variant = ItemTrait {
95-
ident: name.clone(),
115+
ident: variant.clone(),
96116
supertraits: tr.supertraits.iter().chain(&bounds).cloned().collect(),
97117
items: tr
98118
.items
@@ -160,9 +180,8 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
160180
})
161181
}
162182

163-
fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
183+
fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
164184
let orig = &tr.ident;
165-
let variant = &attrs.variant.name;
166185
let (_impl, orig_ty_generics, _where) = &tr.generics.split_for_impl();
167186
let items = tr
168187
.items

0 commit comments

Comments
 (0)