Skip to content

Commit fc55ac7

Browse files
authored
Merge pull request #2567 from Mingun/fix-2565
Correctly process flatten fields in enum variants
2 parents 9b868ef + 2afe5b4 commit fc55ac7

File tree

7 files changed

+269
-31
lines changed

7 files changed

+269
-31
lines changed

serde_derive/src/de.rs

Lines changed: 59 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -281,11 +281,21 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment {
281281
} else if let attr::Identifier::No = cont.attrs.identifier() {
282282
match &cont.data {
283283
Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs),
284-
Data::Struct(Style::Struct, fields) => {
285-
deserialize_struct(params, fields, &cont.attrs, StructForm::Struct)
286-
}
284+
Data::Struct(Style::Struct, fields) => deserialize_struct(
285+
params,
286+
fields,
287+
&cont.attrs,
288+
cont.attrs.has_flatten(),
289+
StructForm::Struct,
290+
),
287291
Data::Struct(Style::Tuple, fields) | Data::Struct(Style::Newtype, fields) => {
288-
deserialize_tuple(params, fields, &cont.attrs, TupleForm::Tuple)
292+
deserialize_tuple(
293+
params,
294+
fields,
295+
&cont.attrs,
296+
cont.attrs.has_flatten(),
297+
TupleForm::Tuple,
298+
)
289299
}
290300
Data::Struct(Style::Unit, _) => deserialize_unit_struct(params, &cont.attrs),
291301
}
@@ -459,9 +469,13 @@ fn deserialize_tuple(
459469
params: &Parameters,
460470
fields: &[Field],
461471
cattrs: &attr::Container,
472+
has_flatten: bool,
462473
form: TupleForm,
463474
) -> Fragment {
464-
assert!(!cattrs.has_flatten());
475+
assert!(
476+
!has_flatten,
477+
"tuples and tuple variants cannot have flatten fields"
478+
);
465479

466480
let field_count = fields
467481
.iter()
@@ -579,7 +593,10 @@ fn deserialize_tuple_in_place(
579593
fields: &[Field],
580594
cattrs: &attr::Container,
581595
) -> Fragment {
582-
assert!(!cattrs.has_flatten());
596+
assert!(
597+
!cattrs.has_flatten(),
598+
"tuples and tuple variants cannot have flatten fields"
599+
);
583600

584601
let field_count = fields
585602
.iter()
@@ -910,6 +927,7 @@ fn deserialize_struct(
910927
params: &Parameters,
911928
fields: &[Field],
912929
cattrs: &attr::Container,
930+
has_flatten: bool,
913931
form: StructForm,
914932
) -> Fragment {
915933
let this_type = &params.this_type;
@@ -958,13 +976,13 @@ fn deserialize_struct(
958976
)
959977
})
960978
.collect();
961-
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs);
979+
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, has_flatten);
962980

963981
// untagged struct variants do not get a visit_seq method. The same applies to
964982
// structs that only have a map representation.
965983
let visit_seq = match form {
966984
StructForm::Untagged(..) => None,
967-
_ if cattrs.has_flatten() => None,
985+
_ if has_flatten => None,
968986
_ => {
969987
let mut_seq = if field_names_idents.is_empty() {
970988
quote!(_)
@@ -987,10 +1005,16 @@ fn deserialize_struct(
9871005
})
9881006
}
9891007
};
990-
let visit_map = Stmts(deserialize_map(&type_path, params, fields, cattrs));
1008+
let visit_map = Stmts(deserialize_map(
1009+
&type_path,
1010+
params,
1011+
fields,
1012+
cattrs,
1013+
has_flatten,
1014+
));
9911015

9921016
let visitor_seed = match form {
993-
StructForm::ExternallyTagged(..) if cattrs.has_flatten() => Some(quote! {
1017+
StructForm::ExternallyTagged(..) if has_flatten => Some(quote! {
9941018
impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Visitor #de_ty_generics #where_clause {
9951019
type Value = #this_type #ty_generics;
9961020

@@ -1005,7 +1029,7 @@ fn deserialize_struct(
10051029
_ => None,
10061030
};
10071031

1008-
let fields_stmt = if cattrs.has_flatten() {
1032+
let fields_stmt = if has_flatten {
10091033
None
10101034
} else {
10111035
let field_names = field_names_idents
@@ -1025,7 +1049,7 @@ fn deserialize_struct(
10251049
}
10261050
};
10271051
let dispatch = match form {
1028-
StructForm::Struct if cattrs.has_flatten() => quote! {
1052+
StructForm::Struct if has_flatten => quote! {
10291053
_serde::Deserializer::deserialize_map(__deserializer, #visitor_expr)
10301054
},
10311055
StructForm::Struct => {
@@ -1034,7 +1058,7 @@ fn deserialize_struct(
10341058
_serde::Deserializer::deserialize_struct(__deserializer, #type_name, FIELDS, #visitor_expr)
10351059
}
10361060
}
1037-
StructForm::ExternallyTagged(_) if cattrs.has_flatten() => quote! {
1061+
StructForm::ExternallyTagged(_) if has_flatten => quote! {
10381062
_serde::de::VariantAccess::newtype_variant_seed(__variant, #visitor_expr)
10391063
},
10401064
StructForm::ExternallyTagged(_) => quote! {
@@ -1116,7 +1140,7 @@ fn deserialize_struct_in_place(
11161140
})
11171141
.collect();
11181142

1119-
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs);
1143+
let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, false);
11201144

11211145
let mut_seq = if field_names_idents.is_empty() {
11221146
quote!(_)
@@ -1210,10 +1234,7 @@ fn deserialize_homogeneous_enum(
12101234
}
12111235
}
12121236

1213-
fn prepare_enum_variant_enum(
1214-
variants: &[Variant],
1215-
cattrs: &attr::Container,
1216-
) -> (TokenStream, Stmts) {
1237+
fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) {
12171238
let mut deserialized_variants = variants
12181239
.iter()
12191240
.enumerate()
@@ -1247,7 +1268,7 @@ fn prepare_enum_variant_enum(
12471268

12481269
let variant_visitor = Stmts(deserialize_generated_identifier(
12491270
&variant_names_idents,
1250-
cattrs,
1271+
false, // variant identifiers does not depend on the presence of flatten fields
12511272
true,
12521273
None,
12531274
fallthrough,
@@ -1270,7 +1291,7 @@ fn deserialize_externally_tagged_enum(
12701291
let expecting = format!("enum {}", params.type_name());
12711292
let expecting = cattrs.expecting().unwrap_or(&expecting);
12721293

1273-
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs);
1294+
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);
12741295

12751296
// Match arms to extract a variant from a string
12761297
let variant_arms = variants
@@ -1355,7 +1376,7 @@ fn deserialize_internally_tagged_enum(
13551376
cattrs: &attr::Container,
13561377
tag: &str,
13571378
) -> Fragment {
1358-
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs);
1379+
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);
13591380

13601381
// Match arms to extract a variant from a string
13611382
let variant_arms = variants
@@ -1409,7 +1430,7 @@ fn deserialize_adjacently_tagged_enum(
14091430
split_with_de_lifetime(params);
14101431
let delife = params.borrowed.de_lifetime();
14111432

1412-
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs);
1433+
let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants);
14131434

14141435
let variant_arms: &Vec<_> = &variants
14151436
.iter()
@@ -1810,12 +1831,14 @@ fn deserialize_externally_tagged_variant(
18101831
params,
18111832
&variant.fields,
18121833
cattrs,
1834+
variant.attrs.has_flatten(),
18131835
TupleForm::ExternallyTagged(variant_ident),
18141836
),
18151837
Style::Struct => deserialize_struct(
18161838
params,
18171839
&variant.fields,
18181840
cattrs,
1841+
variant.attrs.has_flatten(),
18191842
StructForm::ExternallyTagged(variant_ident),
18201843
),
18211844
}
@@ -1859,6 +1882,7 @@ fn deserialize_internally_tagged_variant(
18591882
params,
18601883
&variant.fields,
18611884
cattrs,
1885+
variant.attrs.has_flatten(),
18621886
StructForm::InternallyTagged(variant_ident, deserializer),
18631887
),
18641888
Style::Tuple => unreachable!("checked in serde_derive_internals"),
@@ -1909,12 +1933,14 @@ fn deserialize_untagged_variant(
19091933
params,
19101934
&variant.fields,
19111935
cattrs,
1936+
variant.attrs.has_flatten(),
19121937
TupleForm::Untagged(variant_ident, deserializer),
19131938
),
19141939
Style::Struct => deserialize_struct(
19151940
params,
19161941
&variant.fields,
19171942
cattrs,
1943+
variant.attrs.has_flatten(),
19181944
StructForm::Untagged(variant_ident, deserializer),
19191945
),
19201946
}
@@ -1985,7 +2011,7 @@ fn deserialize_untagged_newtype_variant(
19852011

19862012
fn deserialize_generated_identifier(
19872013
fields: &[(&str, Ident, &BTreeSet<String>)],
1988-
cattrs: &attr::Container,
2014+
has_flatten: bool,
19892015
is_variant: bool,
19902016
ignore_variant: Option<TokenStream>,
19912017
fallthrough: Option<TokenStream>,
@@ -1999,11 +2025,11 @@ fn deserialize_generated_identifier(
19992025
is_variant,
20002026
fallthrough,
20012027
None,
2002-
!is_variant && cattrs.has_flatten(),
2028+
!is_variant && has_flatten,
20032029
None,
20042030
));
20052031

2006-
let lifetime = if !is_variant && cattrs.has_flatten() {
2032+
let lifetime = if !is_variant && has_flatten {
20072033
Some(quote!(<'de>))
20082034
} else {
20092035
None
@@ -2043,8 +2069,9 @@ fn deserialize_generated_identifier(
20432069
fn deserialize_field_identifier(
20442070
fields: &[(&str, Ident, &BTreeSet<String>)],
20452071
cattrs: &attr::Container,
2072+
has_flatten: bool,
20462073
) -> Stmts {
2047-
let (ignore_variant, fallthrough) = if cattrs.has_flatten() {
2074+
let (ignore_variant, fallthrough) = if has_flatten {
20482075
let ignore_variant = quote!(__other(_serde::__private::de::Content<'de>),);
20492076
let fallthrough = quote!(_serde::__private::Ok(__Field::__other(__value)));
20502077
(Some(ignore_variant), Some(fallthrough))
@@ -2058,7 +2085,7 @@ fn deserialize_field_identifier(
20582085

20592086
Stmts(deserialize_generated_identifier(
20602087
fields,
2061-
cattrs,
2088+
has_flatten,
20622089
false,
20632090
ignore_variant,
20642091
fallthrough,
@@ -2460,6 +2487,7 @@ fn deserialize_map(
24602487
params: &Parameters,
24612488
fields: &[Field],
24622489
cattrs: &attr::Container,
2490+
has_flatten: bool,
24632491
) -> Fragment {
24642492
// Create the field names for the fields.
24652493
let fields_names: Vec<_> = fields
@@ -2480,9 +2508,6 @@ fn deserialize_map(
24802508
});
24812509

24822510
// Collect contents for flatten fields into a buffer
2483-
let has_flatten = fields
2484-
.iter()
2485-
.any(|field| field.attrs.flatten() && !field.attrs.skip_deserializing());
24862511
let let_collect = if has_flatten {
24872512
Some(quote! {
24882513
let mut __collect = _serde::__private::Vec::<_serde::__private::Option<(
@@ -2681,7 +2706,10 @@ fn deserialize_map_in_place(
26812706
fields: &[Field],
26822707
cattrs: &attr::Container,
26832708
) -> Fragment {
2684-
assert!(!cattrs.has_flatten());
2709+
assert!(
2710+
!cattrs.has_flatten(),
2711+
"inplace deserialization of maps doesn't support flatten fields"
2712+
);
26852713

26862714
// Create the field names for the fields.
26872715
let fields_names: Vec<_> = fields

serde_derive/src/internals/ast.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ impl<'a> Container<'a> {
8585
for field in &mut variant.fields {
8686
if field.attrs.flatten() {
8787
has_flatten = true;
88+
variant.attrs.mark_has_flatten();
8889
}
8990
field.attrs.rename_by_rules(
9091
variant

serde_derive/src/internals/attr.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,22 @@ pub struct Container {
216216
type_into: Option<syn::Type>,
217217
remote: Option<syn::Path>,
218218
identifier: Identifier,
219+
/// `true` if container is a `struct` and it has a field with `#[serde(flatten)]`
220+
/// attribute or it is an `enum` with a struct variant which has a field with
221+
/// `#[serde(flatten)]` attribute. Examples:
222+
///
223+
/// ```ignore
224+
/// struct Container {
225+
/// #[serde(flatten)]
226+
/// some_field: (),
227+
/// }
228+
/// enum Container {
229+
/// Variant {
230+
/// #[serde(flatten)]
231+
/// some_field: (),
232+
/// },
233+
/// }
234+
/// ```
219235
has_flatten: bool,
220236
serde_path: Option<syn::Path>,
221237
is_packed: bool,
@@ -794,6 +810,18 @@ pub struct Variant {
794810
rename_all_rules: RenameAllRules,
795811
ser_bound: Option<Vec<syn::WherePredicate>>,
796812
de_bound: Option<Vec<syn::WherePredicate>>,
813+
/// `true` if variant is a struct variant which contains a field with `#[serde(flatten)]`
814+
/// attribute. Examples:
815+
///
816+
/// ```ignore
817+
/// enum Enum {
818+
/// Variant {
819+
/// #[serde(flatten)]
820+
/// some_field: (),
821+
/// },
822+
/// }
823+
/// ```
824+
has_flatten: bool,
797825
skip_deserializing: bool,
798826
skip_serializing: bool,
799827
other: bool,
@@ -963,6 +991,7 @@ impl Variant {
963991
},
964992
ser_bound: ser_bound.get(),
965993
de_bound: de_bound.get(),
994+
has_flatten: false,
966995
skip_deserializing: skip_deserializing.get(),
967996
skip_serializing: skip_serializing.get(),
968997
other: other.get(),
@@ -1005,6 +1034,14 @@ impl Variant {
10051034
self.de_bound.as_ref().map(|vec| &vec[..])
10061035
}
10071036

1037+
pub fn has_flatten(&self) -> bool {
1038+
self.has_flatten
1039+
}
1040+
1041+
pub fn mark_has_flatten(&mut self) {
1042+
self.has_flatten = true;
1043+
}
1044+
10081045
pub fn skip_deserializing(&self) -> bool {
10091046
self.skip_deserializing
10101047
}

0 commit comments

Comments
 (0)