Skip to content

Conversation

@juliancoffee
Copy link

@juliancoffee juliancoffee commented Mar 25, 2025

Fixes #424

As I said, it's possible if slightly complex. I'm not an expert in writing iterators, though, so maybe it's possible to cut some rough edges; I just tried to make it correct.
I tried to produce a slim diff, but DoubleEndedIterator implementation went into pieces.

Also, you can see in tests, Color::simple_iter() gives much simpler implementation, but maybe a bit slower to run and/or compile? I didn't bench it.

UPD: I think I know how to simplify this a little (without going through implementation I pointed out above), so if you're interested I'll try to refactor it a bit

@juliancoffee juliancoffee force-pushed the juliancoffee/enum-iter-flatten branch from ff6eab2 to 7b81796 Compare March 25, 2025 22:46
@juliancoffee
Copy link
Author

juliancoffee commented Mar 25, 2025

This is the manual implementation of what these macros generate. It has some dbg!() here and there, which I used while developing the algorithm. Of course none are present of them in MR code :P

#[derive(Debug, Eq, PartialEq)]
enum Vibe {
    Weak,
    Average,
    Strong,
}

impl Vibe {
    fn iter() -> <Self as IntoIterator>::IntoIter {
        let vibe = Vibe::Weak;
        vibe.into_iter()
    }
}

impl IntoIterator for Vibe {
    type Item = Vibe;
    type IntoIter = std::vec::IntoIter<Vibe>;
    fn into_iter(self) -> Self::IntoIter {
        vec![Vibe::Weak, Vibe::Average, Vibe::Strong].into_iter()
    }
}

const SHADE_NUM: usize = 5;
#[derive(Debug, Eq, PartialEq)]
enum Shade {
    Light,
    Med1(Vibe),
    Med2(Vibe),
    Med3(Vibe),
    Dark,
}

impl Shade {
    fn iter() -> ShadeIter {
        ShadeIter {
            idx: 0,
            med1_iter: Some(Vibe::iter()),
            med2_iter: Some(Vibe::iter()),
            med3_iter: Some(Vibe::iter()),
            back_idx: 0,
        }
    }
}

impl Shade {
    fn simple_iter() -> impl DoubleEndedIterator<Item = Shade> {
        vec![Shade::Light]
            .into_iter()
            .chain(Vibe::iter().map(Shade::Med1))
            .chain(Vibe::iter().map(Shade::Med2))
            .chain(Vibe::iter().map(Shade::Med3))
            .chain(vec![Shade::Dark])
    }
}

struct ShadeIter {
    idx: usize,
    med1_iter: Option<<Vibe as IntoIterator>::IntoIter>,
    med2_iter: Option<<Vibe as IntoIterator>::IntoIter>,
    med3_iter: Option<<Vibe as IntoIterator>::IntoIter>,
    back_idx: usize,
}

#[derive(Debug)]
enum Res {
    Done(Shade),
    DoneStep(Shade),
    EndStep,
    End,
}

impl ShadeIter {
    fn nested_get(
        nested_iter: &mut Option<<Vibe as IntoIterator>::IntoIter>,
        wrap: fn(<Vibe as IntoIterator>::Item) -> Shade,
        forward: bool,
    ) -> Res {
        let next_inner = if forward {
            nested_iter.as_mut().and_then(|t| t.next())
        } else {
            nested_iter.as_mut().and_then(|t| t.next_back())
        };
        if let Some(it) = next_inner {
            Res::DoneStep(wrap(it))
        } else {
            nested_iter.take();
            Res::EndStep
        }
    }

    fn get(&mut self, idx: usize, forward: bool) -> Res {
        let res = match dbg!(idx) {
            0 => Res::Done(Shade::Light),
            1 => Self::nested_get(&mut self.med1_iter, Shade::Med1, forward),
            2 => Self::nested_get(&mut self.med2_iter, Shade::Med2, forward),
            3 => Self::nested_get(&mut self.med3_iter, Shade::Med3, forward),
            4 => Res::Done(Shade::Dark),
            _ => Res::End,
        };
        dbg!(res)
    }
}

impl Iterator for ShadeIter {
    type Item = Shade;

    fn next(&mut self) -> Option<Self::Item> {
        self.nth(0)
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        if self.back_idx + self.idx >= SHADE_NUM {
            return None;
        }
        match ShadeIter::get(self, dbg!(self.idx) + dbg!(n), true) {
            Res::Done(x) => {
                // move to requested, and past it
                self.idx += n + 1;
                Some(x)
            }
            Res::DoneStep(x) => {
                // move to requested, but not past it
                self.idx += n;
                Some(x)
            }
            Res::EndStep => {
                // ok, this one failed, move past it and request again
                self.idx += 1;
                let res = self.nth(0);
                res
            }
            Res::End => None,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        /*
        let min = if self.idx + self.back_idx >= SHADE_NUM {
            0
        } else {
            SHADE_NUM - self.idx - self.back_idx
        };
        */

        let med1_size = self.med1_iter.as_ref().map_or(0, |t| {
            t.len()
        });
        let med2_size = self.med2_iter.as_ref().map_or(0, |t| {
            t.len()
        });
        let med3_size = self.med3_iter.as_ref().map_or(0, |t| {
            t.len()
        });
        let t = SHADE_NUM
            + dbg!(med1_size) - self.med1_iter.as_ref().map_or(0, |_| 1)
            + dbg!(med2_size) - self.med2_iter.as_ref().map_or(0, |_| 1)
            + dbg!(med3_size) - self.med3_iter.as_ref().map_or(0, |_| 1)
            - dbg!(self.idx)
            - dbg!(self.back_idx);

        (t, Some(t))
    }
}

impl ShadeIter {
    fn nth_back(&mut self, back_n: usize) -> Option<Shade> {
        if self.back_idx + self.idx >= SHADE_NUM {
            return None;
        }

        let res = match ShadeIter::get(
            self,
            SHADE_NUM - dbg!(self.back_idx) - back_n - 1,
            false,
        ) {
            Res::Done(x) => {
                // move to requested, and past it
                self.back_idx += 1;
                Some(x)
            }
            Res::DoneStep(x) => {
                // move to requested, but not past it
                Some(x)
            }
            Res::EndStep => {
                // ok, this one failed, try the next one
                self.back_idx += 1;
                self.nth_back(0)
            }
            Res::End => None,
        };
        res
    }
}

impl DoubleEndedIterator for ShadeIter {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.nth_back(0)
    }
}

impl ExactSizeIterator for ShadeIter {
    fn len(&self) -> usize {
        self.size_hint().0
    }
}

fn main() {
    println!("Hello, world!");
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn flatten() {
        let result = Shade::iter().collect::<Vec<_>>();
        let expected = vec![
            Shade::Light,
            Shade::Med1(Vibe::Weak),
            Shade::Med1(Vibe::Average),
            Shade::Med1(Vibe::Strong),
            Shade::Med2(Vibe::Weak),
            Shade::Med2(Vibe::Average),
            Shade::Med2(Vibe::Strong),
            Shade::Med3(Vibe::Weak),
            Shade::Med3(Vibe::Average),
            Shade::Med3(Vibe::Strong),
            Shade::Dark,
        ];
        assert_eq!(result, expected);
    }

    #[test]
    fn flatten_back() {
        let result = Shade::iter().rev().collect::<Vec<_>>();
        let expected = vec![
            Shade::Dark,
            Shade::Med3(Vibe::Strong),
            Shade::Med3(Vibe::Average),
            Shade::Med3(Vibe::Weak),
            Shade::Med2(Vibe::Strong),
            Shade::Med2(Vibe::Average),
            Shade::Med2(Vibe::Weak),
            Shade::Med1(Vibe::Strong),
            Shade::Med1(Vibe::Average),
            Shade::Med1(Vibe::Weak),
            Shade::Light,
        ];
        assert_eq!(result, expected);
    }

    #[test]
    fn iter_mixed_next_and_next_back() {
        let mut iter = Shade::iter();

        assert_eq!(iter.next(), Some(Shade::Light));
        assert_eq!(iter.next_back(), Some(Shade::Dark));

        assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Weak)));
        assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Strong)));

        assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Average)));
        assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Average)));

        assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Strong)));
        assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Weak)));

        assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Weak)));
        assert_eq!(iter.next_back(), Some(Shade::Med2(Vibe::Strong)));

        assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Average)));
        assert_eq!(iter.next_back(), None);
    }

    #[test]
    fn iter_quickheck() {
        use rand::Rng;

        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut iter = Shade::iter();
            let mut simple_iter = Shade::simple_iter();

            let mut results = vec![];
            let mut expected = vec![];
            for _ in 0..500 {
                if rng.random_bool(0.5) {
                    results.push(iter.next());
                    expected.push(simple_iter.next());
                } else {
                    results.push(iter.next_back());
                    expected.push(simple_iter.next_back());
                }
            }
            assert_eq!(results, expected);
        }
    }

    #[test]
    fn iter_quickheck_sizehint() {
        use rand::Rng;

        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut iter = Shade::iter();
            let mut simple_iter = Shade::simple_iter();

            assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
            for _ in 0..500 {
                if rng.random_bool(0.5) {
                    dbg!("next");
                    _ = iter.next();
                    _ = simple_iter.next();
                    assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
                } else {
                    dbg!("next_back");
                    _ = iter.next_back();
                    _ = simple_iter.next_back();
                    assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
                }
            }
        }
    }

    #[test]
    fn iter_quickheck_len() {
        use rand::Rng;

        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut iter = Shade::iter();
            const MAX: usize = 11;

            assert_eq!(dbg!(iter.len()), MAX);
            for i in 1..=MAX {
                if rng.random_bool(0.5) {
                    dbg!("next");
                    _ = iter.next();
                } else {
                    dbg!("next_back");
                    _ = iter.next_back();
                }
                assert_eq!(dbg!(iter.len()), MAX - i);
            }
        }
    }
}

Open to your comments 🙌

Copy link

@vic1707 vic1707 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, thx for this PR I also need this but didn't get the time to look into it much, you're saving me big time !
One concern I have is that the generated code isn't compatible with no_std anymore due to the vec![].
Quick testing shows that a simple array also does the trick

Updated example, nothing much changes, every vec![] is replaced by [] and
type IntoIter = std::vec::IntoIter<Vibe>; becomes type IntoIter = <[Self; 3] as core::iter::IntoIterator>::IntoIter; which could get tricky, maybe generate an associated constant containing the number of variants (<[Self; 4 + 3] as core::iter::IntoIterator>::IntoIter works) ?

#![no_std]

#[derive(Debug, Eq, PartialEq)]
enum Vibe {
    Weak,
    Average,
    Strong,
}

impl Vibe {
    fn iter() -> <Self as IntoIterator>::IntoIter {
        let vibe = Vibe::Weak;
        vibe.into_iter()
    }
}

impl IntoIterator for Vibe {
    type Item = Vibe;
    type IntoIter = <[Self; 3] as core::iter::IntoIterator>::IntoIter;
    fn into_iter(self) -> Self::IntoIter {
        [Vibe::Weak, Vibe::Average, Vibe::Strong].into_iter()
    }
}

const SHADE_NUM: usize = 5;
#[derive(Debug, Eq, PartialEq)]
enum Shade {
    Light,
    Med1(Vibe),
    Med2(Vibe),
    Med3(Vibe),
    Dark,
}

impl Shade {
    fn iter() -> ShadeIter {
        ShadeIter {
            idx: 0,
            med1_iter: Some(Vibe::iter()),
            med2_iter: Some(Vibe::iter()),
            med3_iter: Some(Vibe::iter()),
            back_idx: 0,
        }
    }
}

impl Shade {
    fn simple_iter() -> impl DoubleEndedIterator<Item = Shade> {
        [Shade::Light]
            .into_iter()
            .chain(Vibe::iter().map(Shade::Med1))
            .chain(Vibe::iter().map(Shade::Med2))
            .chain(Vibe::iter().map(Shade::Med3))
            .chain([Shade::Dark])
    }
}

struct ShadeIter {
    idx: usize,
    med1_iter: Option<<Vibe as IntoIterator>::IntoIter>,
    med2_iter: Option<<Vibe as IntoIterator>::IntoIter>,
    med3_iter: Option<<Vibe as IntoIterator>::IntoIter>,
    back_idx: usize,
}

#[derive(Debug)]
enum Res {
    Done(Shade),
    DoneStep(Shade),
    EndStep,
    End,
}

impl ShadeIter {
    fn nested_get(
        nested_iter: &mut Option<<Vibe as IntoIterator>::IntoIter>,
        wrap: fn(<Vibe as IntoIterator>::Item) -> Shade,
        forward: bool,
    ) -> Res {
        let next_inner = if forward {
            nested_iter.as_mut().and_then(|t| t.next())
        } else {
            nested_iter.as_mut().and_then(|t| t.next_back())
        };
        if let Some(it) = next_inner {
            Res::DoneStep(wrap(it))
        } else {
            nested_iter.take();
            Res::EndStep
        }
    }

    fn get(&mut self, idx: usize, forward: bool) -> Res {
        match idx {
            0 => Res::Done(Shade::Light),
            1 => Self::nested_get(&mut self.med1_iter, Shade::Med1, forward),
            2 => Self::nested_get(&mut self.med2_iter, Shade::Med2, forward),
            3 => Self::nested_get(&mut self.med3_iter, Shade::Med3, forward),
            4 => Res::Done(Shade::Dark),
            _ => Res::End,
        }
    }
}

impl Iterator for ShadeIter {
    type Item = Shade;

    fn next(&mut self) -> Option<Self::Item> {
        self.nth(0)
    }

    fn nth(&mut self, n: usize) -> Option<Self::Item> {
        if self.back_idx + self.idx >= SHADE_NUM {
            return None;
        }
        match ShadeIter::get(self, self.idx + n, true) {
            Res::Done(x) => {
                // move to requested, and past it
                self.idx += n + 1;
                Some(x)
            }
            Res::DoneStep(x) => {
                // move to requested, but not past it
                self.idx += n;
                Some(x)
            }
            Res::EndStep => {
                // ok, this one failed, move past it and request again
                self.idx += 1;
                let res = self.nth(0);
                res
            }
            Res::End => None,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        /*
        let min = if self.idx + self.back_idx >= SHADE_NUM {
            0
        } else {
            SHADE_NUM - self.idx - self.back_idx
        };
        */

        let med1_size = self.med1_iter.as_ref().map_or(0, |t| {
            t.len()
        });
        let med2_size = self.med2_iter.as_ref().map_or(0, |t| {
            t.len()
        });
        let med3_size = self.med3_iter.as_ref().map_or(0, |t| {
            t.len()
        });
        let t = SHADE_NUM
            + (med1_size) - self.med1_iter.as_ref().map_or(0, |_| 1)
            + (med2_size) - self.med2_iter.as_ref().map_or(0, |_| 1)
            + (med3_size) - self.med3_iter.as_ref().map_or(0, |_| 1)
            - (self.idx)
            - (self.back_idx);

        (t, Some(t))
    }
}

impl ShadeIter {
    fn nth_back(&mut self, back_n: usize) -> Option<Shade> {
        if self.back_idx + self.idx >= SHADE_NUM {
            return None;
        }

        let res = match ShadeIter::get(
            self,
            SHADE_NUM - self.back_idx - back_n - 1,
            false,
        ) {
            Res::Done(x) => {
                // move to requested, and past it
                self.back_idx += 1;
                Some(x)
            }
            Res::DoneStep(x) => {
                // move to requested, but not past it
                Some(x)
            }
            Res::EndStep => {
                // ok, this one failed, try the next one
                self.back_idx += 1;
                self.nth_back(0)
            }
            Res::End => None,
        };
        res
    }
}

impl DoubleEndedIterator for ShadeIter {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.nth_back(0)
    }
}

impl ExactSizeIterator for ShadeIter {
    fn len(&self) -> usize {
        self.size_hint().0
    }
}

const fn main() {
    
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn flatten() {
        let result = Shade::iter().collect::<Vec<_>>();
        let expected = vec![
            Shade::Light,
            Shade::Med1(Vibe::Weak),
            Shade::Med1(Vibe::Average),
            Shade::Med1(Vibe::Strong),
            Shade::Med2(Vibe::Weak),
            Shade::Med2(Vibe::Average),
            Shade::Med2(Vibe::Strong),
            Shade::Med3(Vibe::Weak),
            Shade::Med3(Vibe::Average),
            Shade::Med3(Vibe::Strong),
            Shade::Dark,
        ];
        assert_eq!(result, expected);
    }

    #[test]
    fn flatten_back() {
        let result = Shade::iter().rev().collect::<Vec<_>>();
        let expected = vec![
            Shade::Dark,
            Shade::Med3(Vibe::Strong),
            Shade::Med3(Vibe::Average),
            Shade::Med3(Vibe::Weak),
            Shade::Med2(Vibe::Strong),
            Shade::Med2(Vibe::Average),
            Shade::Med2(Vibe::Weak),
            Shade::Med1(Vibe::Strong),
            Shade::Med1(Vibe::Average),
            Shade::Med1(Vibe::Weak),
            Shade::Light,
        ];
        assert_eq!(result, expected);
    }

    #[test]
    fn iter_mixed_next_and_next_back() {
        let mut iter = Shade::iter();

        assert_eq!(iter.next(), Some(Shade::Light));
        assert_eq!(iter.next_back(), Some(Shade::Dark));

        assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Weak)));
        assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Strong)));

        assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Average)));
        assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Average)));

        assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Strong)));
        assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Weak)));

        assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Weak)));
        assert_eq!(iter.next_back(), Some(Shade::Med2(Vibe::Strong)));

        assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Average)));
        assert_eq!(iter.next_back(), None);
    }

    #[test]
    fn iter_quickheck() {
        use rand::Rng;

        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut iter = Shade::iter();
            let mut simple_iter = Shade::simple_iter();

            let mut results = vec![];
            let mut expected = vec![];
            for _ in 0..500 {
                if rng.random_bool(0.5) {
                    results.push(iter.next());
                    expected.push(simple_iter.next());
                } else {
                    results.push(iter.next_back());
                    expected.push(simple_iter.next_back());
                }
            }
            assert_eq!(results, expected);
        }
    }

    #[test]
    fn iter_quickheck_sizehint() {
        use rand::Rng;

        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut iter = Shade::iter();
            let mut simple_iter = Shade::simple_iter();

            assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
            for _ in 0..500 {
                if rng.random_bool(0.5) {
                    dbg!("next");
                    _ = iter.next();
                    _ = simple_iter.next();
                    assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
                } else {
                    dbg!("next_back");
                    _ = iter.next_back();
                    _ = simple_iter.next_back();
                    assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
                }
            }
        }
    }

    #[test]
    fn iter_quickheck_len() {
        use rand::Rng;

        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut iter = Shade::iter();
            const MAX: usize = 11;

            assert_eq!(dbg!(iter.len()), MAX);
            for i in 1..=MAX {
                if rng.random_bool(0.5) {
                    dbg!("next");
                    _ = iter.next();
                } else {
                    dbg!("next_back");
                    _ = iter.next_back();
                }
                assert_eq!(dbg!(iter.len()), MAX - i);
            }
        }
    }
}

@juliancoffee
Copy link
Author

juliancoffee commented Mar 31, 2025

@vic1707
Oh, don't worry about that.
Both Vibe::iter() and Shade::simple_iter() are functions I added for prototyping.

Vibe::iter() was added because I needed a nested iterator, and yeah, I didn't care much about its implementation, because it wouldn't be present in "real" code.
In a real-world use case, Vibe would get the same implementation as Shade, if you place derive(EnumIter) on both.

Shade::simple_iter() is there so that I have something to compare results to without writing too many tests, so it wouldn't be present in generated code as well.

Thanks for noting that, though. I guess the drawback of simple_iter() approach is that it would be harder to get working with #[no_std], yet as you mentioned, it could work if you just replace vec with an array.

@vic1707
Copy link

vic1707 commented Mar 31, 2025

Sorry I for that misunderstanding on my part, good job, can't wait to see it land if the devs are ok 👍

custom_keyword!(default_with);
custom_keyword!(props);
custom_keyword!(ascii_case_insensitive);
custom_keyword!(flatten);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

honestly, the biggest concern I have here is how should #[strum(flatten)] interact with other derives

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#369 (comment)

basically this thing

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add #[strum(flatten)] to EnumIter

2 participants