diff --git a/src/lib.rs b/src/lib.rs index 360636f5..61ec7e0e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -269,3 +269,33 @@ impl core::fmt::Display for TryReserveError { #[cfg(feature = "std")] #[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl std::error::Error for TryReserveError {} + +// NOTE: This is copied from the slice module in the std lib. +/// The error type returned by [`get_disjoint_indices_mut`][`IndexMap::get_disjoint_indices_mut`]. +/// +/// It indicates one of two possible errors: +/// - An index is out-of-bounds. +/// - The same index appeared multiple times in the array. +// (or different but overlapping indices when ranges are provided) +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum GetDisjointMutError { + /// An index provided was out-of-bounds for the slice. + IndexOutOfBounds, + /// Two indices provided were overlapping. + OverlappingIndices, +} + +impl core::fmt::Display for GetDisjointMutError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let msg = match self { + GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds", + GetDisjointMutError::OverlappingIndices => "there were overlapping indices", + }; + + core::fmt::Display::fmt(msg, f) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl std::error::Error for GetDisjointMutError {} diff --git a/src/map.rs b/src/map.rs index 347649f8..79a45527 100644 --- a/src/map.rs +++ b/src/map.rs @@ -38,7 +38,7 @@ use std::collections::hash_map::RandomState; use self::core::IndexMapCore; use crate::util::{third, try_simplify_range}; -use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError}; +use crate::{Bucket, Entries, Equivalent, GetDisjointMutError, HashValue, TryReserveError}; /// A hash table where the iteration order of the key-value pairs is independent /// of the hash values of the keys. @@ -790,6 +790,32 @@ where } } + /// Return the values for `N` keys. If any key is duplicated, this function will panic. + /// + /// # Examples + /// + /// ``` + /// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]); + /// assert_eq!(map.get_disjoint_mut([&2, &1]), [Some(&mut 'c'), Some(&mut 'a')]); + /// ``` + pub fn get_disjoint_mut(&mut self, keys: [&Q; N]) -> [Option<&mut V>; N] + where + Q: ?Sized + Hash + Equivalent, + { + let indices = keys.map(|key| self.get_index_of(key)); + match self.as_mut_slice().get_disjoint_opt_mut(indices) { + Err(GetDisjointMutError::IndexOutOfBounds) => { + unreachable!( + "Internal error: indices should never be OOB as we got them from get_index_of" + ); + } + Err(GetDisjointMutError::OverlappingIndices) => { + panic!("duplicate keys found"); + } + Ok(key_values) => key_values.map(|kv_opt| kv_opt.map(|kv| kv.1)), + } + } + /// Remove the key-value pair equivalent to `key` and return /// its value. /// @@ -1196,6 +1222,23 @@ impl IndexMap { Some(IndexedEntry::new(&mut self.core, index)) } + /// Get an array of `N` key-value pairs by `N` indices + /// + /// Valid indices are *0 <= index < self.len()* and each index needs to be unique. + /// + /// # Examples + /// + /// ``` + /// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]); + /// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Ok([(&2, &mut 'c'), (&1, &mut 'a')])); + /// ``` + pub fn get_disjoint_indices_mut( + &mut self, + indices: [usize; N], + ) -> Result<[(&K, &mut V); N], GetDisjointMutError> { + self.as_mut_slice().get_disjoint_mut(indices) + } + /// Returns a slice of key-value pairs in the given range of indices. /// /// Valid indices are `0 <= index < self.len()`. diff --git a/src/map/slice.rs b/src/map/slice.rs index 413aed79..035744ef 100644 --- a/src/map/slice.rs +++ b/src/map/slice.rs @@ -3,6 +3,7 @@ use super::{ ValuesMut, }; use crate::util::{slice_eq, try_simplify_range}; +use crate::GetDisjointMutError; use alloc::boxed::Box; use alloc::vec::Vec; @@ -270,6 +271,51 @@ impl Slice { self.entries .partition_point(move |a| pred(&a.key, &a.value)) } + + /// Get an array of `N` key-value pairs by `N` indices + /// + /// Valid indices are *0 <= index < self.len()* and each index needs to be unique. + pub fn get_disjoint_mut( + &mut self, + indices: [usize; N], + ) -> Result<[(&K, &mut V); N], GetDisjointMutError> { + let indices = indices.map(Some); + let key_values = self.get_disjoint_opt_mut(indices)?; + Ok(key_values.map(Option::unwrap)) + } + + #[allow(unsafe_code)] + pub(crate) fn get_disjoint_opt_mut( + &mut self, + indices: [Option; N], + ) -> Result<[Option<(&K, &mut V)>; N], GetDisjointMutError> { + // SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data. + let len = self.len(); + for i in 0..N { + if let Some(idx) = indices[i] { + if idx >= len { + return Err(GetDisjointMutError::IndexOutOfBounds); + } else if indices[..i].contains(&Some(idx)) { + return Err(GetDisjointMutError::OverlappingIndices); + } + } + } + + let entries_ptr = self.entries.as_mut_ptr(); + let out = indices.map(|idx_opt| { + match idx_opt { + Some(idx) => { + // SAFETY: The base pointer is valid as it comes from a slice and the reference is always + // in-bounds & unique as we've already checked the indices above. + let kv = unsafe { (*(entries_ptr.add(idx))).ref_mut() }; + Some(kv) + } + None => None, + } + }); + + Ok(out) + } } impl<'a, K, V> IntoIterator for &'a Slice { diff --git a/src/map/tests.rs b/src/map/tests.rs index 9de9db1b..f97f2f14 100644 --- a/src/map/tests.rs +++ b/src/map/tests.rs @@ -828,3 +828,181 @@ move_index_oob!(test_move_index_out_of_bounds_0_10, 0, 10); move_index_oob!(test_move_index_out_of_bounds_0_max, 0, usize::MAX); move_index_oob!(test_move_index_out_of_bounds_10_0, 10, 0); move_index_oob!(test_move_index_out_of_bounds_max_0, usize::MAX, 0); + +#[test] +fn disjoint_mut_empty_map() { + let mut map: IndexMap = IndexMap::default(); + assert_eq!( + map.get_disjoint_mut([&0, &1, &2, &3]), + [None, None, None, None] + ); +} + +#[test] +fn disjoint_mut_empty_param() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + assert_eq!(map.get_disjoint_mut([] as [&u32; 0]), []); +} + +#[test] +fn disjoint_mut_single_fail() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + assert_eq!(map.get_disjoint_mut([&0]), [None]); +} + +#[test] +fn disjoint_mut_single_success() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + assert_eq!(map.get_disjoint_mut([&1]), [Some(&mut 10)]); +} + +#[test] +fn disjoint_mut_multi_success() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 100); + map.insert(2, 200); + map.insert(3, 300); + map.insert(4, 400); + assert_eq!( + map.get_disjoint_mut([&1, &2]), + [Some(&mut 100), Some(&mut 200)] + ); + assert_eq!( + map.get_disjoint_mut([&1, &3]), + [Some(&mut 100), Some(&mut 300)] + ); + assert_eq!( + map.get_disjoint_mut([&3, &1, &4, &2]), + [ + Some(&mut 300), + Some(&mut 100), + Some(&mut 400), + Some(&mut 200) + ] + ); +} + +#[test] +fn disjoint_mut_multi_success_unsized_key() { + let mut map: IndexMap<&'static str, u32> = IndexMap::default(); + map.insert("1", 100); + map.insert("2", 200); + map.insert("3", 300); + map.insert("4", 400); + + assert_eq!( + map.get_disjoint_mut(["1", "2"]), + [Some(&mut 100), Some(&mut 200)] + ); + assert_eq!( + map.get_disjoint_mut(["1", "3"]), + [Some(&mut 100), Some(&mut 300)] + ); + assert_eq!( + map.get_disjoint_mut(["3", "1", "4", "2"]), + [ + Some(&mut 300), + Some(&mut 100), + Some(&mut 400), + Some(&mut 200) + ] + ); +} + +#[test] +fn disjoint_mut_multi_success_borrow_key() { + let mut map: IndexMap = IndexMap::default(); + map.insert("1".into(), 100); + map.insert("2".into(), 200); + map.insert("3".into(), 300); + map.insert("4".into(), 400); + + assert_eq!( + map.get_disjoint_mut(["1", "2"]), + [Some(&mut 100), Some(&mut 200)] + ); + assert_eq!( + map.get_disjoint_mut(["1", "3"]), + [Some(&mut 100), Some(&mut 300)] + ); + assert_eq!( + map.get_disjoint_mut(["3", "1", "4", "2"]), + [ + Some(&mut 300), + Some(&mut 100), + Some(&mut 400), + Some(&mut 200) + ] + ); +} + +#[test] +fn disjoint_mut_multi_fail_missing() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 100); + map.insert(2, 200); + map.insert(3, 300); + map.insert(4, 400); + + assert_eq!(map.get_disjoint_mut([&1, &5]), [Some(&mut 100), None]); + assert_eq!(map.get_disjoint_mut([&5, &6]), [None, None]); + assert_eq!( + map.get_disjoint_mut([&1, &5, &4]), + [Some(&mut 100), None, Some(&mut 400)] + ); +} + +#[test] +#[should_panic] +fn disjoint_mut_multi_fail_duplicate_panic() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 100); + map.get_disjoint_mut([&1, &2, &1]); +} + +#[test] +fn disjoint_indices_mut_fail_oob() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(321, 20); + assert_eq!( + map.get_disjoint_indices_mut([1, 3]), + Err(crate::GetDisjointMutError::IndexOutOfBounds) + ); +} + +#[test] +fn disjoint_indices_mut_empty() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(321, 20); + assert_eq!(map.get_disjoint_indices_mut([]), Ok([])); +} + +#[test] +fn disjoint_indices_mut_success() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(321, 20); + assert_eq!(map.get_disjoint_indices_mut([0]), Ok([(&1, &mut 10)])); + + assert_eq!(map.get_disjoint_indices_mut([1]), Ok([(&321, &mut 20)])); + assert_eq!( + map.get_disjoint_indices_mut([0, 1]), + Ok([(&1, &mut 10), (&321, &mut 20)]) + ); +} + +#[test] +fn disjoint_indices_mut_fail_duplicate() { + let mut map: IndexMap = IndexMap::default(); + map.insert(1, 10); + map.insert(321, 20); + assert_eq!( + map.get_disjoint_indices_mut([1, 0, 1]), + Err(crate::GetDisjointMutError::OverlappingIndices) + ); +}