
Using Spark's cache for correctness, not just performance

RDDs are immutable. Right? This is one of the first things we learn when we read about Apache Spark™. Here’s a little program that appears to contradict this. This Scala program creates a small RDD, performs a few simple transformations on it, and then calls RDD.count() on the same RDD twice. The results of the two calls to count() are compared with an assert, and at first glance we would expect the assertion to always pass. There are no other operations in between the two calls to count(), and even if there were, RDDs are immutable, so we must get the same value for count(), right? Here’s the program:

/*
 * This file is licensed to You under the Eclipse Public License (EPL);
 *  http://www.opensource.org/licenses/eclipse-1.0.php
 * (C) Copyright IBM Corporation 2015
 */

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import scala.util.Random

object MutableRDD {
 def main(args: Array[String]) {
   val conf = new SparkConf().setAppName("Immutable RDD test")
   val sc = new SparkContext(conf)

   // start with a sequence of 10,000 zeros
   val zeros = Seq.fill(10000)(0)

   // create a RDD from the sequence, and replace all zeros with random values
   val randomRDD = sc.parallelize(zeros).map(x=>Random.nextInt())

   // filter out all non-positive values, roughly half the set
   val filteredRDD = randomRDD.filter(x=>x>0)

   // count the number of elements that remain, twice
   val count1 = filteredRDD.count()
   val count2 = filteredRDD.count()

   // Since filteredRDD is immutable, this should always pass, right? 
   assert(count1 == count2, "\nMismatch!  count1=" + count1 + " count2=" + count2)

   System.out.println("Program completed successfully")
 }
}

Since we’re using a random number generator, it’s possible that this program will indeed complete successfully if the numbers line up properly, but in a large number of test runs, I always get output which looks like this:

Exception in thread "main" java.lang.AssertionError: assertion failed: 
Mismatch!  count1=4984 count2=4973
    at scala.Predef$.assert(Predef.scala:179)
    at MutableRDD$.main(MutableRDD.scala:30)
    at MutableRDD.main(MutableRDD.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:95)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:55)
    at java.lang.reflect.Method.invoke(Method.java:507)
    at org.apache.spark.deploy.SparkSubmit$.org$apache$spark$deploy$SparkSubmit$$runMain(SparkSubmit.scala:664)
    at org.apache.spark.deploy.SparkSubmit$.doRunMain$1(SparkSubmit.scala:169)
    at org.apache.spark.deploy.SparkSubmit$.submit(SparkSubmit.scala:192)
    at org.apache.spark.deploy.SparkSubmit$.main(SparkSubmit.scala:111)
    at org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)

So, what’s going on here? This is a case where the distinction between a Spark transformation and a Spark action is critical. As you know, transformations are lazily built up, but Spark does not perform any actual processing until an action is needed. Looking at the program above, here are the steps, and their types:

map() -> transformation
filter() -> transformation
count() -> action
count() -> action
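
One way to see this laziness is to print the lineage Spark has recorded so far (a small sketch, reusing the filteredRDD defined in the program above; RDD.toDebugString just describes the recorded chain without running anything):

// At this point no Spark job has run: map() and filter() only recorded a lineage.
// toDebugString prints that chain (parallelize -> map -> filter), which is exactly
// what every subsequent action will execute from scratch.
println(filteredRDD.toDebugString)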

Everything before count() is a transformation, so it’s the first call to count() that actually causes the RDD to be computed, starting from our initial sequence of 10,000 zeros. But this still doesn’t fully explain why the second call to count() produces a different value.

Recall that Spark will keep an RDD in memory for reuse, but only if the programmer explicitly calls cache() or persist() on it. Our program above does not cache the RDD, so after the first call to filteredRDD.count() completes, the contents of filteredRDD are discarded! The second call to filteredRDD.count() then recomputes the RDD, again starting from the initial set of 10,000 zeros. Although the steps to recompute it are identical, our map function uses a random number generator and our filter keeps or drops elements based on the values it produces, so the second computation yields a completely different set of values, and a different result for count(). It doesn’t matter that RDDs are immutable, that filteredRDD is a val, or that there are no other calls in between the two calls to count().
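
A quick way to confirm that nothing is being kept around is to check the RDD’s storage level (a small sketch, again using filteredRDD from the program above; RDD.getStorageLevel reports whatever level was requested via cache() or persist()):

import org.apache.spark.storage.StorageLevel

// filteredRDD was never cached or persisted, so its storage level is NONE,
// and every action recomputes it from the original 10,000 zeros.
println(filteredRDD.getStorageLevel == StorageLevel.NONE)   // prints true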

If we want to fix our program to always pass the assertion, we need to cache filteredRDD by replacing randomRDD.filter(x=>x>0) with randomRDD.filter(x=>x>0).cache(). If the resulting RDD is too large to fit in memory, even with the call to cache(), Spark may drop and recompute portions of the RDD. Since our program is small, there’s no danger of running out of memory, but in a larger program it is better to use randomRDD.filter(x=>x>0).persist(StorageLevel.MEMORY_AND_DISK) instead, as this will keep the RDD consistent, spilling to disk if necessary. With this small change, our program completes correctly:

Program completed successfully
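
For reference, here is a sketch of what the changed lines might look like with persist() (the StorageLevel constants come from org.apache.spark.storage):

import org.apache.spark.storage.StorageLevel

// Persisting filteredRDD means the first count() materializes it once; the
// second count() reads the same stored partitions instead of recomputing
// them through the random map().
val filteredRDD = randomRDD.filter(x => x > 0).persist(StorageLevel.MEMORY_AND_DISK)

val count1 = filteredRDD.count()   // computes and stores the RDD
val count2 = filteredRDD.count()   // reuses the stored partitions
assert(count1 == count2)           // now always passes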

Spark’s cache is not just a performance tool that can be left out of simple programs. As seen above, it can also be important for reproducibility in any program that involves sampling, random values, or other sources of variability.
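
As a concrete illustration, the same trap shows up when tagging rows for a random train/test split (a sketch with a hypothetical dataRDD, using the same Random and StorageLevel imports as above):

// Tag each row with a random train/test assignment. Without persist(), every
// action would re-roll the assignments, so the two subsets below could overlap
// or miss rows entirely.
val tagged = dataRDD.map(row => (row, Random.nextDouble() < 0.8))
                    .persist(StorageLevel.MEMORY_AND_DISK)

val trainingSet = tagged.filter(_._2).map(_._1)        // rows tagged true
val testSet     = tagged.filter(x => !x._2).map(_._1)  // rows tagged false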
