Skip to content

Implement get_disjoint_mut (previously get_many_mut) #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,33 @@ impl core::fmt::Display for TryReserveError {
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl std::error::Error for TryReserveError {}

// NOTE: This is copied from the slice module in the std lib.
/// The error type returned by [`get_disjoint_indices_mut`][`IndexMap::get_disjoint_indices_mut`].
///
/// It indicates one of two possible errors:
/// - An index is out-of-bounds.
/// - The same index appeared multiple times in the array.
// (or different but overlapping indices when ranges are provided)
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GetDisjointMutError {
/// An index provided was out-of-bounds for the slice.
IndexOutOfBounds,
/// Two indices provided were overlapping.
OverlappingIndices,
}

impl core::fmt::Display for GetDisjointMutError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
let msg = match self {
GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds",
GetDisjointMutError::OverlappingIndices => "there were overlapping indices",
};

core::fmt::Display::fmt(msg, f)
}
}

#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl std::error::Error for GetDisjointMutError {}
45 changes: 44 additions & 1 deletion src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use std::collections::hash_map::RandomState;

use self::core::IndexMapCore;
use crate::util::{third, try_simplify_range};
use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError};
use crate::{Bucket, Entries, Equivalent, GetDisjointMutError, HashValue, TryReserveError};

/// A hash table where the iteration order of the key-value pairs is independent
/// of the hash values of the keys.
Expand Down Expand Up @@ -790,6 +790,32 @@ where
}
}

/// Return the values for `N` keys. If any key is duplicated, this function will panic.
///
/// # Examples
///
/// ```
/// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
/// assert_eq!(map.get_disjoint_mut([&2, &1]), [Some(&mut 'c'), Some(&mut 'a')]);
/// ```
pub fn get_disjoint_mut<Q, const N: usize>(&mut self, keys: [&Q; N]) -> [Option<&mut V>; N]
where
Q: ?Sized + Hash + Equivalent<K>,
{
let indices = keys.map(|key| self.get_index_of(key));
match self.as_mut_slice().get_disjoint_opt_mut(indices) {
Err(GetDisjointMutError::IndexOutOfBounds) => {
unreachable!(
"Internal error: indices should never be OOB as we got them from get_index_of"
);
}
Err(GetDisjointMutError::OverlappingIndices) => {
panic!("duplicate keys found");
}
Ok(key_values) => key_values.map(|kv_opt| kv_opt.map(|kv| kv.1)),
}
}

/// Remove the key-value pair equivalent to `key` and return
/// its value.
///
Expand Down Expand Up @@ -1196,6 +1222,23 @@ impl<K, V, S> IndexMap<K, V, S> {
Some(IndexedEntry::new(&mut self.core, index))
}

/// Get an array of `N` key-value pairs by `N` indices
///
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
///
/// # Examples
///
/// ```
/// let mut map = indexmap::IndexMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
/// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Ok([(&2, &mut 'c'), (&1, &mut 'a')]));
/// ```
pub fn get_disjoint_indices_mut<const N: usize>(
&mut self,
indices: [usize; N],
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
self.as_mut_slice().get_disjoint_mut(indices)
}

/// Returns a slice of key-value pairs in the given range of indices.
///
/// Valid indices are `0 <= index < self.len()`.
Expand Down
46 changes: 46 additions & 0 deletions src/map/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use super::{
ValuesMut,
};
use crate::util::{slice_eq, try_simplify_range};
use crate::GetDisjointMutError;

use alloc::boxed::Box;
use alloc::vec::Vec;
Expand Down Expand Up @@ -270,6 +271,51 @@ impl<K, V> Slice<K, V> {
self.entries
.partition_point(move |a| pred(&a.key, &a.value))
}

/// Get an array of `N` key-value pairs by `N` indices
///
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
pub fn get_disjoint_mut<const N: usize>(
&mut self,
indices: [usize; N],
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
let indices = indices.map(Some);
let key_values = self.get_disjoint_opt_mut(indices)?;
Ok(key_values.map(Option::unwrap))
}

#[allow(unsafe_code)]
pub(crate) fn get_disjoint_opt_mut<const N: usize>(
&mut self,
indices: [Option<usize>; N],
) -> Result<[Option<(&K, &mut V)>; N], GetDisjointMutError> {
// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data.
let len = self.len();
for i in 0..N {
if let Some(idx) = indices[i] {
if idx >= len {
return Err(GetDisjointMutError::IndexOutOfBounds);
} else if indices[..i].contains(&Some(idx)) {
return Err(GetDisjointMutError::OverlappingIndices);
}
}
}

let entries_ptr = self.entries.as_mut_ptr();
let out = indices.map(|idx_opt| {
match idx_opt {
Some(idx) => {
// SAFETY: The base pointer is valid as it comes from a slice and the reference is always
// in-bounds & unique as we've already checked the indices above.
let kv = unsafe { (*(entries_ptr.add(idx))).ref_mut() };
Some(kv)
}
None => None,
}
});

Ok(out)
}
}

impl<'a, K, V> IntoIterator for &'a Slice<K, V> {
Expand Down
178 changes: 178 additions & 0 deletions src/map/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -828,3 +828,181 @@ move_index_oob!(test_move_index_out_of_bounds_0_10, 0, 10);
move_index_oob!(test_move_index_out_of_bounds_0_max, 0, usize::MAX);
move_index_oob!(test_move_index_out_of_bounds_10_0, 10, 0);
move_index_oob!(test_move_index_out_of_bounds_max_0, usize::MAX, 0);

#[test]
fn disjoint_mut_empty_map() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
assert_eq!(
map.get_disjoint_mut([&0, &1, &2, &3]),
[None, None, None, None]
);
}

#[test]
fn disjoint_mut_empty_param() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
assert_eq!(map.get_disjoint_mut([] as [&u32; 0]), []);
}

#[test]
fn disjoint_mut_single_fail() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
assert_eq!(map.get_disjoint_mut([&0]), [None]);
}

#[test]
fn disjoint_mut_single_success() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
assert_eq!(map.get_disjoint_mut([&1]), [Some(&mut 10)]);
}

#[test]
fn disjoint_mut_multi_success() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 100);
map.insert(2, 200);
map.insert(3, 300);
map.insert(4, 400);
assert_eq!(
map.get_disjoint_mut([&1, &2]),
[Some(&mut 100), Some(&mut 200)]
);
assert_eq!(
map.get_disjoint_mut([&1, &3]),
[Some(&mut 100), Some(&mut 300)]
);
assert_eq!(
map.get_disjoint_mut([&3, &1, &4, &2]),
[
Some(&mut 300),
Some(&mut 100),
Some(&mut 400),
Some(&mut 200)
]
);
}

#[test]
fn disjoint_mut_multi_success_unsized_key() {
let mut map: IndexMap<&'static str, u32> = IndexMap::default();
map.insert("1", 100);
map.insert("2", 200);
map.insert("3", 300);
map.insert("4", 400);

assert_eq!(
map.get_disjoint_mut(["1", "2"]),
[Some(&mut 100), Some(&mut 200)]
);
assert_eq!(
map.get_disjoint_mut(["1", "3"]),
[Some(&mut 100), Some(&mut 300)]
);
assert_eq!(
map.get_disjoint_mut(["3", "1", "4", "2"]),
[
Some(&mut 300),
Some(&mut 100),
Some(&mut 400),
Some(&mut 200)
]
);
}

#[test]
fn disjoint_mut_multi_success_borrow_key() {
let mut map: IndexMap<String, u32> = IndexMap::default();
map.insert("1".into(), 100);
map.insert("2".into(), 200);
map.insert("3".into(), 300);
map.insert("4".into(), 400);

assert_eq!(
map.get_disjoint_mut(["1", "2"]),
[Some(&mut 100), Some(&mut 200)]
);
assert_eq!(
map.get_disjoint_mut(["1", "3"]),
[Some(&mut 100), Some(&mut 300)]
);
assert_eq!(
map.get_disjoint_mut(["3", "1", "4", "2"]),
[
Some(&mut 300),
Some(&mut 100),
Some(&mut 400),
Some(&mut 200)
]
);
}

#[test]
fn disjoint_mut_multi_fail_missing() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 100);
map.insert(2, 200);
map.insert(3, 300);
map.insert(4, 400);

assert_eq!(map.get_disjoint_mut([&1, &5]), [Some(&mut 100), None]);
assert_eq!(map.get_disjoint_mut([&5, &6]), [None, None]);
assert_eq!(
map.get_disjoint_mut([&1, &5, &4]),
[Some(&mut 100), None, Some(&mut 400)]
);
}

#[test]
#[should_panic]
fn disjoint_mut_multi_fail_duplicate_panic() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 100);
map.get_disjoint_mut([&1, &2, &1]);
}

#[test]
fn disjoint_indices_mut_fail_oob() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
map.insert(321, 20);
assert_eq!(
map.get_disjoint_indices_mut([1, 3]),
Err(crate::GetDisjointMutError::IndexOutOfBounds)
);
}

#[test]
fn disjoint_indices_mut_empty() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
map.insert(321, 20);
assert_eq!(map.get_disjoint_indices_mut([]), Ok([]));
}

#[test]
fn disjoint_indices_mut_success() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
map.insert(321, 20);
assert_eq!(map.get_disjoint_indices_mut([0]), Ok([(&1, &mut 10)]));

assert_eq!(map.get_disjoint_indices_mut([1]), Ok([(&321, &mut 20)]));
assert_eq!(
map.get_disjoint_indices_mut([0, 1]),
Ok([(&1, &mut 10), (&321, &mut 20)])
);
}

#[test]
fn disjoint_indices_mut_fail_duplicate() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
map.insert(321, 20);
assert_eq!(
map.get_disjoint_indices_mut([1, 0, 1]),
Err(crate::GetDisjointMutError::OverlappingIndices)
);
}