From f719ebede59873b938d1107efa6900e0d8d56998 Mon Sep 17 00:00:00 2001 From: mendess Date: Wed, 27 Dec 2023 15:20:11 +0000 Subject: [PATCH 1/2] Add support for defaulted methods --- Cargo.lock | 32 ++++----- trait-variant/examples/variant.rs | 16 +++++ trait-variant/src/variant.rs | 116 +++++++++++++++++++++++++----- 3 files changed, 132 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d5b5105..190a817 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -55,30 +55,30 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "libc" -version = "0.2.150" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "memchr" -version = "2.6.4" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", ] [[package]] name = "object" -version = "0.32.1" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] @@ -91,18 +91,18 @@ checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "proc-macro2" -version = "1.0.70" +version = "1.0.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -115,9 +115,9 @@ checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" [[package]] name = "syn" -version = "2.0.39" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", @@ -126,9 +126,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.35.0" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ "backtrace", "pin-project-lite", diff --git a/trait-variant/examples/variant.rs b/trait-variant/examples/variant.rs index dcd7189..c82a9e5 100644 --- a/trait-variant/examples/variant.rs +++ b/trait-variant/examples/variant.rs @@ -17,9 +17,25 @@ pub trait LocalIntFactory { Self: 'a; async fn make(&self, x: u32, y: &str) -> i32; + async fn make_mut(&mut self); fn stream(&self) -> impl Iterator; fn call(&self) -> u32; fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>; + async fn defaulted(&self) -> i32 { + self.make(10, "10").await + } + async fn defaulted_mut(&mut self) -> i32 { + self.make(10, "10").await + } + async fn defaulted_mut_2(&mut self) { + self.make_mut().await + } + async fn defaulted_move(self) -> i32 + where + Self: Sized, + { + self.make(10, "10").await + } } #[allow(dead_code)] diff --git a/trait-variant/src/variant.rs b/trait-variant/src/variant.rs index 0a61f33..0c600c7 100644 --- a/trait-variant/src/variant.rs +++ b/trait-variant/src/variant.rs @@ -8,17 +8,18 @@ use std::iter; -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ parse::{Parse, ParseStream}, - parse_macro_input, parse_quote, + parse_macro_input, punctuated::Punctuated, token::Plus, - Error, FnArg, GenericParam, Ident, ItemTrait, Pat, PatType, Result, ReturnType, Signature, - Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, TypeGenerics, - TypeImplTrait, TypeParam, TypeParamBound, + Error, FnArg, GenericParam, Ident, ItemTrait, Pat, PatType, Receiver, Result, ReturnType, + Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, + TypeGenerics, TypeImplTrait, TypeParam, TypeParamBound, WhereClause, }; +use syn::{parse_quote, TypeReference}; struct Attrs { variant: MakeVariant, @@ -127,10 +128,10 @@ fn mk_variant( // Transforms one item declaration within the definition if it has `async fn` and/or `-> impl Trait` return types by adding new bounds. fn transform_item(item: &TraitItem, bounds: &Vec) -> TraitItem { - let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else { + let TraitItem::Fn(fn_item @ TraitItemFn { sig, default, .. }) = item else { return item.clone(); }; - let (arrow, output) = if sig.asyncness.is_some() { + let (sig, default) = if sig.asyncness.is_some() { let orig = match &sig.output { ReturnType::Default => quote! { () }, ReturnType::Type(_, ty) => quote! { #ty }, @@ -142,7 +143,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec) -> TraitItem { .chain(bounds.iter().cloned()) .collect(), }); - (syn::parse2(quote! { -> }).unwrap(), ty) + let mut sig = sig.clone(); + if default.is_some() { + add_receiver_bounds(&mut sig); + } + + ( + Signature { + asyncness: None, + output: ReturnType::Type(syn::parse2(quote! { -> }).unwrap(), Box::new(ty)), + ..sig.clone() + }, + fn_item + .default + .as_ref() + .map(|b| syn::parse2(quote! { { async move #b } }).unwrap()), + ) } else { match &sig.output { ReturnType::Type(arrow, ty) => match &**ty { @@ -151,7 +167,13 @@ fn transform_item(item: &TraitItem, bounds: &Vec) -> TraitItem { impl_token: it.impl_token, bounds: it.bounds.iter().chain(bounds).cloned().collect(), }); - (*arrow, ty) + ( + Signature { + output: ReturnType::Type(*arrow, Box::new(ty)), + ..sig.clone() + }, + fn_item.default.clone(), + ) } _ => return item.clone(), }, @@ -159,11 +181,8 @@ fn transform_item(item: &TraitItem, bounds: &Vec) -> TraitItem { } }; TraitItem::Fn(TraitItemFn { - sig: Signature { - asyncness: None, - output: ReturnType::Type(arrow, Box::new(output)), - ..sig.clone() - }, + sig, + default, ..fn_item.clone() }) } @@ -182,9 +201,29 @@ fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream { blanket_generics .params .push(GenericParam::Type(blanket_bound)); - let (blanket_impl_generics, _ty, blanket_where_clause) = &blanket_generics.split_for_impl(); + let (blanket_impl_generics, _ty, blanket_where_clause) = &mut blanket_generics.split_for_impl(); + let self_is_sync = tr.items.iter().any(|item| { + matches!( + item, + TraitItem::Fn(TraitItemFn { + default: Some(_), + .. + }) + ) + }); + + let mut blanket_where_clause = blanket_where_clause + .map(|w| w.predicates.clone()) + .unwrap_or_default(); + + if self_is_sync { + blanket_where_clause.push(parse_quote! { for<'s> &'s Self: Send }); + } + quote! { - impl #blanket_impl_generics #orig #orig_ty_generics for #blanket #blanket_where_clause + impl #blanket_impl_generics #orig #orig_ty_generics for #blanket + where + #blanket_where_clause { #(#items)* } @@ -229,6 +268,7 @@ fn blanket_impl_item( } else { quote! {} }; + quote! { #sig { ::#ident(#(#args),*)#maybe_await @@ -246,3 +286,47 @@ fn blanket_impl_item( _ => Error::new_spanned(item, "unsupported item type").into_compile_error(), } } + +fn add_receiver_bounds(sig: &mut Signature) { + let Some(FnArg::Receiver(Receiver { ty, reference, .. })) = sig.inputs.first_mut() else { + return; + }; + let Type::Reference( + recv_ty @ TypeReference { + mutability: None, .. + }, + ) = &mut **ty + else { + return; + }; + let Some((_and, lt)) = reference else { + return; + }; + + let lifetime = syn::Lifetime { + apostrophe: Span::mixed_site(), + ident: Ident::new("the_self_lt", Span::mixed_site()), + }; + sig.generics.params.insert( + 0, + syn::GenericParam::Lifetime(syn::LifetimeParam { + lifetime: lifetime.clone(), + colon_token: None, + bounds: Default::default(), + attrs: Default::default(), + }), + ); + recv_ty.lifetime = Some(lifetime.clone()); + *lt = Some(lifetime); + let predicate = parse_quote! { #recv_ty: Send }; + + if let Some(wh) = &mut sig.generics.where_clause { + wh.predicates.push(predicate); + } else { + let where_clause = WhereClause { + where_token: Token![where](Span::mixed_site()), + predicates: Punctuated::from_iter([predicate]), + }; + sig.generics.where_clause = Some(where_clause); + } +} From 5fb6904c5ba7ef70331aabb04ae82f74deb5db7d Mon Sep 17 00:00:00 2001 From: mendess Date: Wed, 3 Jan 2024 16:38:18 +0000 Subject: [PATCH 2/2] Explicitly move all arguments into the async block --- trait-variant/Cargo.toml | 2 +- trait-variant/examples/variant.rs | 4 ++-- trait-variant/src/variant.rs | 39 +++++++++++++++++++++---------- 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/trait-variant/Cargo.toml b/trait-variant/Cargo.toml index f47aea9..db040f3 100644 --- a/trait-variant/Cargo.toml +++ b/trait-variant/Cargo.toml @@ -25,7 +25,7 @@ proc-macro = true [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = { version = "2.0", features = ["full"] } +syn = { version = "2.0", features = ["full", "visit-mut"] } [dev-dependencies] tokio = { version = "1", features = ["rt"] } diff --git a/trait-variant/examples/variant.rs b/trait-variant/examples/variant.rs index c82a9e5..3a440ae 100644 --- a/trait-variant/examples/variant.rs +++ b/trait-variant/examples/variant.rs @@ -21,8 +21,8 @@ pub trait LocalIntFactory { fn stream(&self) -> impl Iterator; fn call(&self) -> u32; fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>; - async fn defaulted(&self) -> i32 { - self.make(10, "10").await + async fn defaulted(&self, x: u32) -> i32 { + self.make(x, "10").await } async fn defaulted_mut(&mut self) -> i32 { self.make(10, "10").await diff --git a/trait-variant/src/variant.rs b/trait-variant/src/variant.rs index 0c600c7..66f4236 100644 --- a/trait-variant/src/variant.rs +++ b/trait-variant/src/variant.rs @@ -11,15 +11,8 @@ use std::iter; use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ - parse::{Parse, ParseStream}, - parse_macro_input, - punctuated::Punctuated, - token::Plus, - Error, FnArg, GenericParam, Ident, ItemTrait, Pat, PatType, Receiver, Result, ReturnType, - Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, - TypeGenerics, TypeImplTrait, TypeParam, TypeParamBound, WhereClause, + parse::{Parse, ParseStream}, parse_macro_input, parse_quote, punctuated::Punctuated, token::Plus, Error, FnArg, GenericParam, Ident, ItemTrait, Pat, PatIdent, PatType, Receiver, Result, ReturnType, Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, TypeGenerics, TypeImplTrait, TypeParam, TypeParamBound, TypeReference, WhereClause }; -use syn::{parse_quote, TypeReference}; struct Attrs { variant: MakeVariant, @@ -154,10 +147,32 @@ fn transform_item(item: &TraitItem, bounds: &Vec) -> TraitItem { output: ReturnType::Type(syn::parse2(quote! { -> }).unwrap(), Box::new(ty)), ..sig.clone() }, - fn_item - .default - .as_ref() - .map(|b| syn::parse2(quote! { { async move #b } }).unwrap()), + fn_item.default.as_ref().map(|b| { + let items = sig.inputs.iter().map(|i| match i { + FnArg::Receiver(Receiver { self_token, .. }) => { + quote! { let __self = #self_token; } + } + FnArg::Typed(PatType { pat, .. }) => match pat.as_ref() { + Pat::Ident(PatIdent { ident, .. }) => quote! { let #ident = #ident; }, + _ => todo!(), + }, + }); + + struct ReplaceSelfVisitor; + impl syn::visit_mut::VisitMut for ReplaceSelfVisitor { + fn visit_ident_mut(&mut self, ident: &mut syn::Ident) { + if ident == "self" { + *ident = syn::Ident::new("__self", ident.span()); + } + syn::visit_mut::visit_ident_mut(self, ident); + } + } + + let mut block = b.clone(); + syn::visit_mut::visit_block_mut(&mut ReplaceSelfVisitor, &mut block); + + parse_quote! { { async move { #(#items)* #block} } } + }), ) } else { match &sig.output {