Skip to content

Commit f8d2b78

Browse files
committed
Add HashSet::drain_filter method
Fixes #178.
1 parent fd03e12 commit f8d2b78

File tree

3 files changed

+124
-19
lines changed

3 files changed

+124
-19
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
77

88
## [Unreleased]
99

10+
### Added
11+
- Added a `drain_filter` function to `HashSet`. (#179)
12+
1013
## [v0.8.0] - 2020-06-18
1114

1215
### Fixed

src/map.rs

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -609,8 +609,10 @@ impl<K, V, S> HashMap<K, V, S> {
609609
{
610610
DrainFilter {
611611
f,
612-
iter: unsafe { self.table.iter() },
613-
table: &mut self.table,
612+
inner: DrainFilterInner {
613+
iter: unsafe { self.table.iter() },
614+
table: &mut self.table,
615+
},
614616
}
615617
}
616618

@@ -1331,45 +1333,55 @@ where
13311333
F: FnMut(&K, &mut V) -> bool,
13321334
{
13331335
f: F,
1334-
iter: RawIter<(K, V)>,
1335-
table: &'a mut RawTable<(K, V)>,
1336+
inner: DrainFilterInner<'a, K, V>,
13361337
}
13371338

13381339
impl<'a, K, V, F> Drop for DrainFilter<'a, K, V, F>
13391340
where
13401341
F: FnMut(&K, &mut V) -> bool,
13411342
{
13421343
fn drop(&mut self) {
1343-
struct DropGuard<'r, 'a, K, V, F>(&'r mut DrainFilter<'a, K, V, F>)
1344-
where
1345-
F: FnMut(&K, &mut V) -> bool;
1346-
1347-
impl<'r, 'a, K, V, F> Drop for DropGuard<'r, 'a, K, V, F>
1348-
where
1349-
F: FnMut(&K, &mut V) -> bool,
1350-
{
1351-
fn drop(&mut self) {
1352-
while let Some(_) = self.0.next() {}
1353-
}
1354-
}
13551344
while let Some(item) = self.next() {
1356-
let guard = DropGuard(self);
1345+
let guard = ConsumeAllOnDrop(self);
13571346
drop(item);
13581347
mem::forget(guard);
13591348
}
13601349
}
13611350
}
13621351

1352+
pub(super) struct ConsumeAllOnDrop<'a, T: Iterator>(pub &'a mut T);
1353+
1354+
impl<T: Iterator> Drop for ConsumeAllOnDrop<'_, T> {
1355+
fn drop(&mut self) {
1356+
self.0.for_each(drop)
1357+
}
1358+
}
1359+
13631360
impl<K, V, F> Iterator for DrainFilter<'_, K, V, F>
13641361
where
13651362
F: FnMut(&K, &mut V) -> bool,
13661363
{
13671364
type Item = (K, V);
13681365
fn next(&mut self) -> Option<Self::Item> {
1366+
self.inner.next(&mut self.f)
1367+
}
1368+
}
1369+
1370+
/// Portions of `DrainFilter` shared with `set::DrainFilter`
1371+
pub(super) struct DrainFilterInner<'a, K, V> {
1372+
pub iter: RawIter<(K, V)>,
1373+
pub table: &'a mut RawTable<(K, V)>,
1374+
}
1375+
1376+
impl<K, V> DrainFilterInner<'_, K, V> {
1377+
pub(super) fn next<F>(&mut self, f: &mut F) -> Option<(K, V)>
1378+
where
1379+
F: FnMut(&K, &mut V) -> bool,
1380+
{
13691381
unsafe {
13701382
while let Some(item) = self.iter.next() {
13711383
let &mut (ref key, ref mut value) = item.as_mut();
1372-
if !(self.f)(key, value) {
1384+
if !f(key, value) {
13731385
return Some(self.table.remove(item));
13741386
}
13751387
}

src/set.rs

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ use core::borrow::Borrow;
44
use core::fmt;
55
use core::hash::{BuildHasher, Hash};
66
use core::iter::{Chain, FromIterator, FusedIterator};
7+
use core::mem;
78
use core::ops::{BitAnd, BitOr, BitXor, Sub};
89

9-
use super::map::{self, DefaultHashBuilder, HashMap, Keys};
10+
use super::map::{self, ConsumeAllOnDrop, DefaultHashBuilder, DrainFilterInner, HashMap, Keys};
1011

1112
// Future Optimization (FIXME!)
1213
// =============================
@@ -285,6 +286,38 @@ impl<T, S> HashSet<T, S> {
285286
self.map.retain(|k, _| f(k));
286287
}
287288

289+
/// Drains elements which are false under the given predicate,
290+
/// and returns an iterator over the removed items.
291+
///
292+
/// In other words, move all elements `e` such that `f(&e)` returns `false` out
293+
/// into another iterator.
294+
///
295+
/// When the returned DrainedFilter is dropped, the elements that don't satisfy
296+
/// the predicate are dropped from the set.
297+
///
298+
/// # Examples
299+
///
300+
/// ```
301+
/// use hashbrown::HashSet;
302+
///
303+
/// let mut set: HashSet<i32> = (0..8).collect();
304+
/// let drained = set.drain_filter(|&k| k % 2 == 0);
305+
/// assert_eq!(drained.count(), 4);
306+
/// assert_eq!(set.len(), 4);
307+
/// ```
308+
pub fn drain_filter<F>(&mut self, f: F) -> DrainFilter<'_, T, F>
309+
where
310+
F: FnMut(&T) -> bool,
311+
{
312+
DrainFilter {
313+
f,
314+
inner: DrainFilterInner {
315+
iter: unsafe { self.map.table.iter() },
316+
table: &mut self.map.table,
317+
},
318+
}
319+
}
320+
288321
/// Clears the set, removing all values.
289322
///
290323
/// # Examples
@@ -1185,6 +1218,21 @@ pub struct Drain<'a, K> {
11851218
iter: map::Drain<'a, K, ()>,
11861219
}
11871220

1221+
/// A draining iterator over entries of a `HashSet` which don't satisfy the predicate `f`.
1222+
///
1223+
/// This `struct` is created by the [`drain_filter`] method on [`HashSet`]. See its
1224+
/// documentation for more.
1225+
///
1226+
/// [`drain_filter`]: struct.HashSet.html#method.drain_filter
1227+
/// [`HashSet`]: struct.HashSet.html
1228+
pub struct DrainFilter<'a, K, F>
1229+
where
1230+
F: FnMut(&K) -> bool,
1231+
{
1232+
f: F,
1233+
inner: DrainFilterInner<'a, K, ()>,
1234+
}
1235+
11881236
/// A lazy iterator producing elements in the intersection of `HashSet`s.
11891237
///
11901238
/// This `struct` is created by the [`intersection`] method on [`HashSet`].
@@ -1365,6 +1413,31 @@ impl<K: fmt::Debug> fmt::Debug for Drain<'_, K> {
13651413
}
13661414
}
13671415

1416+
impl<'a, K, F> Drop for DrainFilter<'a, K, F>
1417+
where
1418+
F: FnMut(&K) -> bool,
1419+
{
1420+
fn drop(&mut self) {
1421+
while let Some(item) = self.next() {
1422+
let guard = ConsumeAllOnDrop(self);
1423+
drop(item);
1424+
mem::forget(guard);
1425+
}
1426+
}
1427+
}
1428+
1429+
impl<K, F> Iterator for DrainFilter<'_, K, F>
1430+
where
1431+
F: FnMut(&K) -> bool,
1432+
{
1433+
type Item = K;
1434+
fn next(&mut self) -> Option<Self::Item> {
1435+
let f = &mut self.f;
1436+
let (k, _) = self.inner.next(&mut |k, _| f(k))?;
1437+
Some(k)
1438+
}
1439+
}
1440+
13681441
impl<T, S> Clone for Intersection<'_, T, S> {
13691442
#[cfg_attr(feature = "inline-more", inline)]
13701443
fn clone(&self) -> Self {
@@ -1973,4 +2046,21 @@ mod test_set {
19732046
assert!(set.contains(&4));
19742047
assert!(set.contains(&6));
19752048
}
2049+
2050+
#[test]
2051+
fn test_drain_filter() {
2052+
{
2053+
let mut set: HashSet<i32> = (0..8).collect();
2054+
let drained = set.drain_filter(|&k| k % 2 == 0);
2055+
let mut out = drained.collect::<Vec<_>>();
2056+
out.sort_unstable();
2057+
assert_eq!(vec![1, 3, 5, 7], out);
2058+
assert_eq!(set.len(), 4);
2059+
}
2060+
{
2061+
let mut set: HashSet<i32> = (0..8).collect();
2062+
drop(set.drain_filter(|&k| k % 2 == 0));
2063+
assert_eq!(set.len(), 4, "Removes non-matching items on drop");
2064+
}
2065+
}
19762066
}

0 commit comments

Comments
 (0)