Skip to content

Commit 5bc9c34

Browse files
committed
Add groupMap and groupMapReduce extensions
1 parent 83c9cef commit 5bc9c34

File tree

4 files changed

+60
-1
lines changed

4 files changed

+60
-1
lines changed

compat/src/main/scala-2.11/scala/collection/compat/package.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,11 @@
1212

1313
package scala.collection
1414

15-
package object compat extends compat.PackageShared
15+
import scala.collection.generic.IsTraversableLike
16+
17+
package object compat extends compat.PackageShared {
18+
implicit def toTraversableLikeExtensionMethods[Repr](self: Repr)(
19+
implicit traversable: IsTraversableLike[Repr])
20+
: TraversableLikeExtensionMethods[traversable.A, Repr] =
21+
new TraversableLikeExtensionMethods[traversable.A, Repr](traversable.conversion(self))
22+
}

compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ private[compat] trait PackageShared {
3030
*/
3131
type Factory[-A, +C] = CanBuildFrom[Nothing, A, C]
3232

33+
type IsTraversableLikeAux[A1, Repr] = IsTraversableLike[Repr] { type A = A1 }
34+
3335
implicit class FactoryOps[-A, +C](private val factory: Factory[A, C]) {
3436

3537
/**
@@ -243,6 +245,36 @@ class TraversableExtensionMethods[A](private val self: c.Traversable[A]) extends
243245
def iterableFactory: GenericCompanion[Traversable] = self.companion
244246
}
245247

248+
class TraversableLikeExtensionMethods[A, Repr](private val self: c.GenTraversableLike[A, Repr])
249+
extends AnyVal {
250+
251+
def groupMap[K, B, That](key: A => K)(f: A => B)(
252+
implicit bf: CanBuildFrom[Repr, B, That]): Map[K, That] = {
253+
val map = m.Map.empty[K, m.Builder[B, That]]
254+
for (elem <- self) {
255+
val k = key(elem)
256+
val bldr = map.getOrElseUpdate(k, bf(self.repr))
257+
bldr += f(elem)
258+
}
259+
val res = Map.newBuilder[K, That]
260+
for ((k, bldr) <- map) res += ((k, bldr.result()))
261+
res.result()
262+
}
263+
264+
def groupMapReduce[K, B](key: A => K)(f: A => B)(reduce: (B, B) => B): Map[K, B] = {
265+
val map = m.Map.empty[K, B]
266+
for (elem <- self) {
267+
val k = key(elem)
268+
val v = map.get(k) match {
269+
case Some(b) => reduce(b, f(elem))
270+
case None => f(elem)
271+
}
272+
map.put(k, v)
273+
}
274+
map.toMap
275+
}
276+
}
277+
246278
class MapViewExtensionMethods[K, V, C <: scala.collection.Map[K, V]](
247279
private val self: IterableView[(K, V), C])
248280
extends AnyVal {

compat/src/main/scala-2.12/scala/collection/compat/package.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
package scala.collection
1414

15+
import scala.collection.generic.IsTraversableLike
1516
import scala.collection.{mutable => m}
1617

1718
package object compat extends compat.PackageShared {
@@ -24,4 +25,9 @@ package object compat extends compat.PackageShared {
2425
def from[K: Ordering, V](source: TraversableOnce[(K, V)]): m.SortedMap[K, V] =
2526
build(m.SortedMap.newBuilder[K, V], source)
2627
}
28+
29+
implicit def toTraversableLikeExtensionMethods[Repr](self: Repr)(
30+
implicit traversable: IsTraversableLike[Repr])
31+
: TraversableLikeExtensionMethods[traversable.A, Repr] =
32+
new TraversableLikeExtensionMethods[traversable.A, Repr](traversable.conversion(self))
2733
}

compat/src/test/scala/test/scala/collection/CollectionTest.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,18 @@ class CollectionTest {
8181
assertFalse(it1.iterator.sameElements(it2))
8282
assertTrue(it2.iterator.sameElements(it3))
8383
}
84+
85+
@Test
86+
def groupMap(): Unit = {
87+
val res = Seq("foo", "test", "bar", "baz")
88+
.groupMap(_.length)(_.toUpperCase())
89+
assertEquals(Map(3 -> Seq("FOO", "BAR", "BAZ"), 4 -> Seq("TEST")), res)
90+
}
91+
92+
@Test
93+
def groupMapReduce(): Unit = {
94+
val res = Seq("foo", "test", "bar", "baz")
95+
.groupMapReduce(_.length)(_ => 1)(_ + _)
96+
assertEquals(Map(3 -> 3, 4 -> 1), res)
97+
}
8498
}

0 commit comments

Comments
 (0)