Commit fd53f2f

Merge pull request #510 from markhamstra/WithThing

mapWith, flatMapWith and filterWith

2 parents: 4c5efcf + ab33e27
2 files changed: 125 additions, 1 deletion
core/src/main/scala/spark/RDD.scala

Lines changed: 65 additions & 1 deletion
@@ -364,6 +364,62 @@ abstract class RDD[T: ClassManifest](
       preservesPartitioning: Boolean = false): RDD[U] =
     new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
 
+  /**
+   * Maps f over this RDD, where f takes an additional parameter of type A. This
+   * additional parameter is produced by constructA, which is called in each
+   * partition with the index of that partition.
+   */
+  def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
+    (f: (T, A) => U): RDD[U] = {
+    def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
+      val a = constructA(index)
+      iter.map(t => f(t, a))
+    }
+    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+  }
+
+  /**
+   * FlatMaps f over this RDD, where f takes an additional parameter of type A. This
+   * additional parameter is produced by constructA, which is called in each
+   * partition with the index of that partition.
+   */
+  def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
+    (f: (T, A) => Seq[U]): RDD[U] = {
+    def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
+      val a = constructA(index)
+      iter.flatMap(t => f(t, a))
+    }
+    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+  }
+
+  /**
+   * Applies f to each element of this RDD, where f takes an additional parameter of type A.
+   * This additional parameter is produced by constructA, which is called in each
+   * partition with the index of that partition.
+   */
+  def foreachWith[A: ClassManifest](constructA: Int => A)
+    (f: (T, A) => Unit) {
+    def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
+      val a = constructA(index)
+      iter.map(t => { f(t, a); t })
+    }
+    (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
+  }
+
+  /**
+   * Filters this RDD with p, where p takes an additional parameter of type A. This
+   * additional parameter is produced by constructA, which is called in each
+   * partition with the index of that partition.
+   */
+  def filterWith[A: ClassManifest](constructA: Int => A)
+    (p: (T, A) => Boolean): RDD[T] = {
+    def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
+      val a = constructA(index)
+      iter.filter(t => p(t, a))
+    }
+    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
+  }
+
   /**
    * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
    * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
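
A minimal usage sketch of the new operators (not part of the commit; it assumes a live SparkContext named sc, and the seed arithmetic is illustrative). Seeding one PRNG per partition from the partition index is the intended pattern: each task builds its own generator, so runs are reproducible without sharing a single Random across the cluster:

import java.util.Random

val ints = sc.makeRDD(1 to 6, 2)

// One Random per partition, seeded by partition index.
val jittered = ints.mapWith(index => new Random(index + 42)) {
  (x, prng) => x + prng.nextDouble()
}

// filterWith follows the same construct-once-per-partition pattern,
// here as a crude ~1/3 sampler.
val sampled = ints.filterWith(index => new Random(index + 42)) {
  (x, prng) => prng.nextInt(3) == 0
}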
@@ -382,6 +438,14 @@ abstract class RDD[T: ClassManifest](
     sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
   }
 
+  /**
+   * Applies a function f to each partition of this RDD.
+   */
+  def foreachPartition(f: Iterator[T] => Unit) {
+    val cleanF = sc.clean(f)
+    sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+  }
+
   /**
    * Return an array that contains all of the elements in this RDD.
    */
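
foreachPartition complements the element-wise foreach when setup and teardown should happen once per partition rather than once per element. A hedged sketch under that assumption (the temp-file output is purely illustrative):

val lines = sc.makeRDD(Seq("a", "b", "c", "d"), 2)
lines.foreachPartition { iter =>
  // One writer per partition, closed when the partition is exhausted.
  val file = java.io.File.createTempFile("partition-", ".txt")
  val writer = new java.io.PrintWriter(file)
  try iter.foreach(line => writer.println(line))
  finally writer.close()
}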
@@ -404,7 +468,7 @@ abstract class RDD[T: ClassManifest](
 
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
-   * 
+   *
    * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
    * RDD will be <= us.
    */

core/src/test/scala/spark/RDDSuite.scala

Lines changed: 60 additions & 0 deletions
@@ -208,4 +208,64 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     assert(prunedData.size === 1)
     assert(prunedData(0) === 10)
   }
+
+  test("mapWith") {
+    import java.util.Random
+    sc = new SparkContext("local", "test")
+    val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
+    val randoms = ones.mapWith(
+      (index: Int) => new Random(index + 42))
+      {(t: Int, prng: Random) => prng.nextDouble * t}.collect()
+    val prn42_3 = {
+      val prng42 = new Random(42)
+      prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
+    }
+    val prn43_3 = {
+      val prng43 = new Random(43)
+      prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
+    }
+    assert(randoms(2) === prn42_3)
+    assert(randoms(5) === prn43_3)
+  }
+
+  test("flatMapWith") {
+    import java.util.Random
+    sc = new SparkContext("local", "test")
+    val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
+    val randoms = ones.flatMapWith(
+      (index: Int) => new Random(index + 42))
+      {(t: Int, prng: Random) =>
+        val random = prng.nextDouble()
+        Seq(random * t, random * t * 10)}.
+      collect()
+    val prn42_3 = {
+      val prng42 = new Random(42)
+      prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
+    }
+    val prn43_3 = {
+      val prng43 = new Random(43)
+      prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
+    }
+    assert(randoms(5) === prn42_3 * 10)
+    assert(randoms(11) === prn43_3 * 10)
+  }
+
+  test("filterWith") {
+    import java.util.Random
+    sc = new SparkContext("local", "test")
+    val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
+    val sample = ints.filterWith(
+      (index: Int) => new Random(index + 42))
+      {(t: Int, prng: Random) => prng.nextInt(3) == 0}.
+      collect()
+    val checkSample = {
+      val prng42 = new Random(42)
+      val prng43 = new Random(43)
+      Array(1, 2, 3, 4, 5, 6).filter{i =>
+        if (i < 4) 0 == prng42.nextInt(3)
+        else 0 == prng43.nextInt(3)}
+    }
+    assert(sample.size === checkSample.size)
+    for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
+  }
 }
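
The expected values in these tests fall out of java.util.Random being fully deterministic for a fixed seed: with six elements in two partitions, element 2 is the third element of partition 0 (seeded 42) and element 5 the third of partition 1 (seeded 43), so each must equal the third draw from the corresponding generator. A standalone sketch of that bookkeeping (plain Scala, no Spark required):

import java.util.Random

// The third draw from a generator seeded with 42: this is the factor
// applied to partition 0's third element in the mapWith test above.
val prng = new Random(42)
prng.nextDouble(); prng.nextDouble()
val thirdDraw = prng.nextDouble()
println(thirdDraw)  // identical on every run; java.util.Random's algorithm is specified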
