Skip to content

Disallow duplicated enum variant indices #628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from
Closed
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ The derive implementation supports the following attributes:
- `codec(encoded_as = "OtherType")`: Needs to be placed above a field and makes the field being
encoded by using `OtherType`.
- `codec(index = 0)`: Needs to be placed above an enum variant to make the variant use the given
index when encoded. By default the index is determined by counting from `0` beginning wth the
index when encoded. By default the index is determined by counting from `0` beginning with the
first variant.
- `codec(encode_bound)`, `codec(decode_bound)` and `codec(mel_bound)`: All 3 attributes take
in a `where` clause for the `Encode`, `Decode` and `MaxEncodedLen` trait implementation for
Expand Down
23 changes: 16 additions & 7 deletions derive/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,18 @@ pub fn quote(
Ok(variants) => variants,
Err(e) => return e.to_compile_error(),
};

let recurse = variants.iter().enumerate().map(|(i, v)| {
match utils::check_indexes(variants.iter()).map_err(|e| e.to_compile_error()) {
Ok(()) => (),
Err(e) => return e,
};
let mut items = vec![];
for (index, v) in variants.iter().enumerate() {
let name = &v.ident;
let index = utils::variant_index(v, i);
let index = match utils::variant_index(v, index).map_err(|e| e.into_compile_error())
{
Ok(i) => i,
Err(e) => return e,
};

let create = create_instance(
quote! { #type_name #type_generics :: #name },
Expand All @@ -57,7 +65,7 @@ pub fn quote(
crate_path,
);

quote_spanned! { v.span() =>
let item = quote_spanned! { v.span() =>
#[allow(clippy::unnecessary_cast)]
__codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => {
// NOTE: This lambda is necessary to work around an upstream bug
Expand All @@ -68,8 +76,9 @@ pub fn quote(
#create
})();
},
}
});
};
items.push(item);
}

let read_byte_err_msg =
format!("Could not decode `{type_name}`, failed to read variant byte");
Expand All @@ -79,7 +88,7 @@ pub fn quote(
match #input.read_byte()
.map_err(|e| e.chain(#read_byte_err_msg))?
{
#( #recurse )*
#( #items )*
_ => {
#[allow(clippy::redundant_closure_call)]
return (move || {
Expand Down
27 changes: 17 additions & 10 deletions derive/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,19 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS
if variants.is_empty() {
return quote!();
}

let recurse = variants.iter().enumerate().map(|(i, f)| {
match utils::check_indexes(variants.iter()).map_err(|e| e.to_compile_error()) {
Ok(()) => (),
Err(e) => return e,
};
let mut items = vec![];
for (index, f) in variants.iter().enumerate() {
let name = &f.ident;
let index = utils::variant_index(f, i);

match f.fields {
let index = match utils::variant_index(f, index).map_err(|e| e.into_compile_error())
{
Ok(i) => i,
Err(e) => return e,
};
let item = match f.fields {
Fields::Named(ref fields) => {
let fields = &fields.named;
let field_name = |_, ident: &Option<Ident>| quote!(#ident);
Expand Down Expand Up @@ -389,12 +396,12 @@ fn impl_encode(data: &Data, type_name: &Ident, crate_path: &syn::Path) -> TokenS

[hinting, encoding]
},
}
});

let recurse_hinting = recurse.clone().map(|[hinting, _]| hinting);
let recurse_encoding = recurse.clone().map(|[_, encoding]| encoding);
};
items.push(item)
}

let recurse_hinting = items.iter().map(|[hinting, _]| hinting);
let recurse_encoding = items.iter().map(|[_, encoding]| encoding);
let hinting = quote! {
// The variant index uses 1 byte.
1_usize + match *#self_ {
Expand Down
4 changes: 2 additions & 2 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ fn wrap_with_dummy_const(
/// * if variant has attribute: `#[codec(index = "$n")]` then n
/// * else if variant has discriminant (like 3 in `enum T { A = 3 }`) then the discriminant.
/// * else its position in the variant set, excluding skipped variants, but including variant with
/// discriminant or attribute. Warning this position does collision with discriminant or attribute
/// index.
/// discriminant or attribute. Warning this position does collision with discriminant or attribute
/// index.
///
/// variant attributes:
/// * `#[codec(skip)]`: the variant is not encoded.
Expand Down
58 changes: 46 additions & 12 deletions derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
//! NOTE: attributes finder must be checked using check_attribute first,
//! otherwise the macro can panic.

use std::str::FromStr;
use std::{collections::HashMap, str::FromStr};

use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{
parse::Parse, punctuated::Punctuated, spanned::Spanned, token, Attribute, Data, DataEnum,
Expand All @@ -38,11 +38,29 @@ where
})
}

/// check usage of variant indexes with #[scale(index = $int)] attribute or
/// explicit discriminant on the variant
pub fn check_indexes<'a, I: Iterator<Item = &'a &'a Variant>>(values: I) -> syn::Result<()> {
let mut map: HashMap<u8, Span> = HashMap::new();
for (i, v) in values.enumerate() {
let index = variant_index(v, i)?;
if let Some(span) = map.insert(index, v.span()) {
let mut error = syn::Error::new(
v.span(),
"scale codec error: Invalid variant index, the variant index is duplicated.",
);
error.combine(syn::Error::new(span, "Variant index used here."));
return Err(error);
}
}
Ok(())
}

/// Look for a `#[scale(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(v: &Variant, i: usize) -> TokenStream {
pub fn variant_index(v: &Variant, index: usize) -> syn::Result<u8> {
// first look for an attribute
let index = find_meta_item(v.attrs.iter(), |meta| {
let codec_index = find_meta_item(v.attrs.iter(), |meta| {
if let Meta::NameValue(ref nv) = meta {
if nv.path.is_ident("index") {
if let Expr::Lit(ExprLit { lit: Lit::Int(ref v), .. }) = nv.value {
Expand All @@ -56,14 +74,30 @@ pub fn variant_index(v: &Variant, i: usize) -> TokenStream {

None
});

// then fallback to discriminant or just index
index.map(|i| quote! { #i }).unwrap_or_else(|| {
v.discriminant
.as_ref()
.map(|(_, expr)| quote! { #expr })
.unwrap_or_else(|| quote! { #i })
})
if let Some(index) = codec_index {
Ok(index)
} else {
match v.discriminant.as_ref() {
Some((_, syn::Expr::Lit(ExprLit { lit: syn::Lit::Int(v), .. }))) => {
let byte = v.base10_parse::<u8>().expect(
"scale codec error: Invalid variant index, discriminant doesn't fit u8.",
);
Ok(byte)
},
Some((_, expr)) => Err(syn::Error::new(
expr.span(),
"scale codec error: Invalid discriminant, only int literal are accepted, e.g. \
`= 32`.",
)),
None => index.try_into().map_err(|_| {
syn::Error::new(
v.span(),
"scale codec error: Variant index is too large, only 256 variants are \
supported.",
)
}),
}
}
}

/// Look for a `#[codec(encoded_as = "SomeType")]` outer attribute on the given
Expand Down
17 changes: 17 additions & 0 deletions tests/scale_codec_ui/codec_duplicate_index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
enum T {
A = 3,
#[codec(index = 3)]
B,
}

#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
enum T1 {
A,
#[codec(index = 0)]
B,
}

fn main() {}
23 changes: 23 additions & 0 deletions tests/scale_codec_ui/codec_duplicate_index.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
error: scale codec error: Invalid variant index, the variant index is duplicated.
--> tests/scale_codec_ui/codec_duplicate_index.rs:5:2
|
5 | #[codec(index = 3)]
| ^

error: Variant index used here.
--> tests/scale_codec_ui/codec_duplicate_index.rs:4:2
|
4 | A = 3,
| ^

error: scale codec error: Invalid variant index, the variant index is duplicated.
--> tests/scale_codec_ui/codec_duplicate_index.rs:13:2
|
13 | #[codec(index = 0)]
| ^

error: Variant index used here.
--> tests/scale_codec_ui/codec_duplicate_index.rs:12:2
|
12 | A,
| ^
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
enum T {
A = 1,
B,
}

#[derive(::parity_scale_codec::Decode, ::parity_scale_codec::Encode)]
#[codec(crate = ::parity_scale_codec)]
enum T2 {
#[codec(index = 1)]
A,
B,
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
error: scale codec error: Invalid variant index, the variant index is duplicated.
--> tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs:5:2
|
5 | B,
| ^

error: Variant index used here.
--> tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs:4:2
|
4 | A = 1,
| ^

error: scale codec error: Invalid variant index, the variant index is duplicated.
--> tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs:13:2
|
13 | B,
| ^

error: Variant index used here.
--> tests/scale_codec_ui/discriminant_variant_counted_in_default_index.rs:11:2
|
11 | #[codec(index = 1)]
| ^
22 changes: 6 additions & 16 deletions tests/variant_number.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,6 @@
use parity_scale_codec::Encode;
use parity_scale_codec_derive::Encode as DeriveEncode;

#[test]
fn discriminant_variant_counted_in_default_index() {
#[derive(DeriveEncode)]
enum T {
A = 1,
B,
}

assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::B.encode(), vec![1]);
}

#[test]
fn skipped_variant_not_counted_in_default_index() {
#[derive(DeriveEncode)]
Expand All @@ -27,14 +15,16 @@ fn skipped_variant_not_counted_in_default_index() {
}

#[test]
fn index_attr_variant_counted_and_reused_in_default_index() {
fn index_attr_variant_duplicates_indices() {
// Tests codec index overriding and that variant indexes are without duplicates
#[derive(DeriveEncode)]
enum T {
#[codec(index = 0)]
A = 1,
#[codec(index = 1)]
A,
B,
B = 0,
}

assert_eq!(T::A.encode(), vec![1]);
assert_eq!(T::A.encode(), vec![0]);
assert_eq!(T::B.encode(), vec![1]);
}
Loading