diff --git a/CHANGELOG.md b/CHANGELOG.md index 24e1455..9e1ec24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Allow lifetime `for<'a, ...>` bounds in non-bounded generic parameters. + ### Changed - Use the `Copy` implementation for `Clone` and the `Ord` implementation for `PartialOrd` when custom bounds are present. diff --git a/src/attr/item.rs b/src/attr/item.rs index e0b6d03..c021be0 100644 --- a/src/attr/item.rs +++ b/src/attr/item.rs @@ -7,7 +7,7 @@ use syn::{ parse::{discouraged::Speculative, Parse, ParseStream}, punctuated::Punctuated, spanned::Spanned, - Attribute, Data, Ident, Meta, Path, PredicateType, Result, Token, TraitBound, + Attribute, BoundLifetimes, Data, Ident, Meta, Path, PredicateType, Result, Token, TraitBound, TraitBoundModifier, Type, TypeParamBound, TypePath, WhereClause, WherePredicate, }; @@ -241,7 +241,10 @@ impl DeriveWhere { /// Returns `true` if the given generic type parameter if present. pub fn has_type_param(&self, type_param: &Ident) -> bool { self.generics.iter().any(|generic| match generic { - Generic::NoBound(Type::Path(TypePath { qself: None, path })) => { + Generic::NoBound(GenericNoBound { + lifetimes: _, + ty: Type::Path(TypePath { qself: None, path }), + }) => { if let Some(ident) = path.get_ident() { ident == type_param } else { @@ -281,9 +284,12 @@ impl DeriveWhere { .predicates .push(WherePredicate::Type(match generic { Generic::CustomBound(type_bound) => type_bound.clone(), - Generic::NoBound(path) => PredicateType { - lifetimes: None, - bounded_ty: path.clone(), + Generic::NoBound(GenericNoBound { + lifetimes: bound_lifetimes, + ty, + }) => PredicateType { + lifetimes: bound_lifetimes.clone(), + bounded_ty: ty.clone(), colon_token: ::default(), bounds: trait_.where_bounds(item), }, @@ -293,13 +299,34 @@ impl DeriveWhere { } } -/// Holds a single generic [type](Type) or [type with bound](PredicateType). +/// Holds the first part of a [`PredicateType`] prior to the `:`. Optionally +/// contains lifetime `for` bindings. +#[derive(Eq, PartialEq)] +pub struct GenericNoBound { + /// Any `for<'a, 'b, 'etc>` bindings for the type. + lifetimes: Option, + /// The type bound to the [`DeriveTrait`]. + ty: Type, +} + +impl Parse for GenericNoBound { + fn parse(input: ParseStream) -> Result { + Ok(Self { + lifetimes: input.parse()?, + ty: input.parse()?, + }) + } +} + +/// Holds a single generic [type](GenericNoBound) with optional lifetime bounds +/// or [type with bound](PredicateType). #[derive(Eq, PartialEq)] pub enum Generic { /// Generic type with custom [specified bounds](PredicateType). CustomBound(PredicateType), - /// Generic [type](Type) which will be bound to the [`DeriveTrait`]. - NoBound(Type), + /// Generic [type](GenericNoBound) which will be bound to the + /// [`DeriveTrait`]. + NoBound(GenericNoBound), } impl Parse for Generic { @@ -307,8 +334,8 @@ impl Parse for Generic { let fork = input.fork(); // Try to parse input as a `WherePredicate`. The problem is, both expressions - // start with a Type, so starting with the `WherePredicate` is the easiest way - // of differentiating them. + // start with an optional lifetime for bound and then Type, so starting with the + // `WherePredicate` is the easiest way of differentiating them. if let Ok(where_predicate) = WherePredicate::parse(&fork) { input.advance_to(&fork); @@ -319,8 +346,8 @@ impl Parse for Generic { Err(Error::generic(where_predicate.span())) } } else { - match Type::parse(input) { - Ok(type_) => Ok(Generic::NoBound(type_)), + match GenericNoBound::parse(input) { + Ok(no_bound) => Ok(Generic::NoBound(no_bound)), Err(error) => Err(Error::generic_syntax(error.span(), error)), } } diff --git a/src/test/bound.rs b/src/test/bound.rs index 07ba600..a84d0c2 100644 --- a/src/test/bound.rs +++ b/src/test/bound.rs @@ -111,6 +111,31 @@ fn where_() -> Result<()> { ) } +#[test] +fn for_lifetime() -> Result<()> { + test_derive( + quote! { + #[derive_where(Clone; for<'a> T)] + struct Test(T, std::marker::PhantomData) where T: std::fmt::Debug; + }, + quote! { + #[automatically_derived] + impl ::core::clone::Clone for Test + where + T: std::fmt::Debug, + for<'a> T: ::core::clone::Clone + { + #[inline] + fn clone(&self) -> Self { + match self { + Test(ref __field_0, ref __field_1) => Test(::core::clone::Clone::clone(__field_0), ::core::clone::Clone::clone(__field_1)), + } + } + } + }, + ) +} + #[test] fn associated_type() -> Result<()> { test_derive(