Skip to content

Commit 1d5fac2

Browse files
committed
Updated with DrainFilter struct
The previous implementation did not match vec's drain filter's drop semantics. As per `amanieu`'s suggestion, added a DrainFilter which implements drop so that it removes the items which don't satisfy the predicate.
1 parent 4dd50ab commit 1d5fac2

File tree

1 file changed

+74
-19
lines changed

1 file changed

+74
-19
lines changed

src/map.rs

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,9 @@ where
962962
/// In other words, move all pairs `(k, v)` such that `f(&k,&mut v)` returns `false` out
963963
/// into another iterator.
964964
///
965+
/// When the returned DrainedFilter is dropped, the elements that don't satisfy
966+
/// the predicate are dropped from the table.
967+
///
965968
/// # Examples
966969
///
967970
/// ```
@@ -972,21 +975,14 @@ where
972975
/// assert_eq!(drained.count(), 4);
973976
/// assert_eq!(map.len(), 4);
974977
/// ```
975-
pub fn drain_filter<'a, F>(&'a mut self, mut f: F) -> impl Iterator<Item = (K, V)> + '_
978+
pub fn drain_filter<F>(&mut self, f: F) -> DrainFilter<'_, K, V, F>
976979
where
977-
F: 'a + FnMut(&K, &mut V) -> bool,
980+
F: FnMut(&K, &mut V) -> bool,
978981
{
979-
// Here we only use `iter` as a temporary, preventing use-after-free
980-
unsafe {
981-
self.table.iter().filter_map(move |item| {
982-
let &mut (ref key, ref mut value) = item.as_mut();
983-
if f(key, value) {
984-
None
985-
} else {
986-
self.table.erase_no_drop(&item);
987-
Some(item.read())
988-
}
989-
})
982+
DrainFilter {
983+
f,
984+
iter: unsafe { self.table.iter() },
985+
table: &mut self.table,
990986
}
991987
}
992988
}
@@ -1269,6 +1265,58 @@ impl<K, V> Drain<'_, K, V> {
12691265
}
12701266
}
12711267

1268+
/// A draining iterator over entries of a `HashMap` which don't satisfy the predicate `f`.
1269+
///
1270+
/// This `struct` is created by the [`drain_filter`] method on [`HashMap`]. See its
1271+
/// documentation for more.
1272+
///
1273+
/// [`drain_filter`]: struct.HashMap.html#method.drain_filter
1274+
/// [`HashMap`]: struct.HashMap.html
1275+
pub struct DrainFilter<'a, K, V, F>
1276+
where
1277+
F: FnMut(&K, &mut V) -> bool,
1278+
{
1279+
f: F,
1280+
iter: RawIter<(K, V)>,
1281+
table: &'a mut RawTable<(K, V)>,
1282+
}
1283+
1284+
impl<K, V, F> Drop for DrainFilter<'_, K, V, F>
1285+
where
1286+
F: FnMut(&K, &mut V) -> bool,
1287+
{
1288+
fn drop(&mut self) {
1289+
unsafe {
1290+
while let Some(item) = self.iter.next() {
1291+
let &mut (ref key, ref mut value) = item.as_mut();
1292+
if !(self.f)(key, value) {
1293+
self.table.erase_no_drop(&item);
1294+
item.drop();
1295+
}
1296+
}
1297+
}
1298+
}
1299+
}
1300+
1301+
impl<K, V, F> Iterator for DrainFilter<'_, K, V, F>
1302+
where
1303+
F: FnMut(&K, &mut V) -> bool,
1304+
{
1305+
type Item = (K, V);
1306+
fn next(&mut self) -> Option<Self::Item> {
1307+
unsafe {
1308+
while let Some(item) = self.iter.next() {
1309+
let &mut (ref key, ref mut value) = item.as_mut();
1310+
if !(self.f)(key, value) {
1311+
self.table.erase_no_drop(&item);
1312+
return Some(item.read());
1313+
}
1314+
}
1315+
}
1316+
None
1317+
}
1318+
}
1319+
12721320
/// A mutable iterator over the values of a `HashMap`.
12731321
///
12741322
/// This `struct` is created by the [`values_mut`] method on [`HashMap`]. See its
@@ -3523,12 +3571,19 @@ mod test_map {
35233571

35243572
#[test]
35253573
fn test_drain_filter() {
3526-
let mut map: HashMap<i32, i32> = (0..8).map(|x| (x, x * 10)).collect();
3527-
let drained = map.drain_filter(|&k, _| k % 2 == 0);
3528-
let mut out = drained.collect::<Vec<_>>();
3529-
out.sort_unstable();
3530-
assert_eq!(vec![(1, 10), (3, 30), (5, 50), (7, 70)], out);
3531-
assert_eq!(map.len(), 4);
3574+
{
3575+
let mut map: HashMap<i32, i32> = (0..8).map(|x| (x, x * 10)).collect();
3576+
let drained = map.drain_filter(|&k, _| k % 2 == 0);
3577+
let mut out = drained.collect::<Vec<_>>();
3578+
out.sort_unstable();
3579+
assert_eq!(vec![(1, 10), (3, 30), (5, 50), (7, 70)], out);
3580+
assert_eq!(map.len(), 4);
3581+
}
3582+
{
3583+
let mut map: HashMap<i32, i32> = (0..8).map(|x| (x, x * 10)).collect();
3584+
drop(map.drain_filter(|&k, _| k % 2 == 0));
3585+
assert_eq!(map.len(), 4);
3586+
}
35323587
}
35333588

35343589
#[test]

0 commit comments

Comments
 (0)