Stateful Spark Streaming Using Transform

In an ideal world, streaming applications are stateless. Some data comes in, you do some simple filtering, maybe change date strings from one format to another, and happily pass the results on to the next step. Sometimes, though, we need state. Maybe it’s something simple, such as the largest value seen so far or a running average. Sometimes it’s more complex, like a machine learning model that scores each element of a stream and is also continually updated. Apache Spark’s Streaming API offers a couple of ways to hold and update state within a streaming step. In this blog post, I’ll talk about how to use the transform operation, which is the less well-known method.

If you read the Apache Spark Streaming documentation looking for a way to carry state in a data stream, you’ll quickly find the method updateStateByKey(func). While this is a useful function, it has the limitation that the state it carries is also the data passed on to the next stage of the stream. For example, let’s say that you have a stream of integers and want to output the difference between each value and the one before it. So given the sequence “1, 3, 7, 2, …” and an initial starting value of 0, you want to produce the output “1, 2, 4, -5, …” for the following stages of your DStream. To calculate this, you need to remember the previous value but output the difference. Each micro batch may contain zero or many values, and the only state we need to keep is the last value seen in one batch, so that we can calculate the difference with the first value in the next batch.
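To make the arithmetic concrete, here is a minimal sketch in plain Scala (no Spark involved) of carrying the last value from one micro batch into the next. The names DeltaCarry, batchDeltas and run are hypothetical and exist only for this illustration.

```scala
object DeltaCarry {
  // Differences for one micro batch, given the last value of the previous batch.
  // Returns (value to carry forward, differences for this batch).
  def batchDeltas(previous: Int, batch: Seq[Int]): (Int, Seq[Int]) = {
    // Pair each value with its predecessor; the first value pairs with `previous`.
    val deltas = batch.zip(previous +: batch).map { case (cur, prev) => cur - prev }
    // An empty batch leaves the carried value unchanged.
    (batch.lastOption.getOrElse(previous), deltas)
  }

  // Thread the carried value through a sequence of micro batches.
  def run(initial: Int, batches: Seq[Seq[Int]]): (Int, Seq[Seq[Int]]) =
    batches.foldLeft((initial, Vector.empty[Seq[Int]])) {
      case ((carry, out), batch) =>
        val (next, deltas) = batchDeltas(carry, batch)
        (next, out :+ deltas)
    }
}
```

Splitting the example sequence into the batches [1, 3], [] and [7, 2] with a starting value of 0, DeltaCarry.run returns the differences [1, 2], [] and [4, -5], and the value left to carry forward is 2.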

One possible implementation using updateStateByKey(func) uses multiple steps. We don’t really have data in key+value form, so our state will consist of an imaginary key such as “machine_A”, with a value containing the last value from the previous batch, followed by the incoming integers in the new batch: (“machine_A”, 0, [1,3,7,2]). After computing differences within the array, we get (“machine_A”, 2, [1,2,4,-5]), where “machine_A” is our key, 2 is the last value of the sequence, and [1,2,4,-5] are the differences between values. We then need a simple map operation as the next streaming step to throw away everything other than the array of differences. I’m not providing the full source for this updateStateByKey(func) example due to space constraints, but you can look at the StatefulNetworkWordCount program for a similar fully coded example.

The main drawback with our updateStateByKey example is that the only state information that we really need to carry from one micro batch to the next is the last seen value: 2. But because the state is inseparable from the output of the streaming stage, we need a data structure with both in it: (“machine_A”, 2, [1,2,4,-5]), so that we can both hold our state, and also provide the output to later stages.
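To make the shape of that combined state concrete, here is a minimal sketch of just the update function such an approach might use. It is not the source that was left out above. It assumes the stream has been keyed as (machine, integer) pairs, and the names UpdateStateSketch, DeltaState, updateDeltas and keyedInts are hypothetical.

```scala
object UpdateStateSketch {
  // State carried per key: (last value seen, differences for the latest batch).
  type DeltaState = (Int, Seq[Int])

  // Matches the (Seq[V], Option[S]) => Option[S] shape that updateStateByKey expects.
  def updateDeltas(newValues: Seq[Int], state: Option[DeltaState]): Option[DeltaState] = {
    // Assumption: start from 0 when no state exists yet for this key.
    val previous = state.map(_._1).getOrElse(0)
    val deltas = newValues.zip(previous +: newValues).map { case (cur, prev) => cur - prev }
    Some((newValues.lastOption.getOrElse(previous), deltas))
  }

  // Hypothetical wiring, assuming keyedInts: DStream[(String, Int)] and a
  // checkpoint directory set on the StreamingContext:
  //   keyedInts.updateStateByKey(updateDeltas _)
  //            .map { case (_, (_, deltas)) => deltas } // drop key and last value
}
```

Note how the differences have to live inside the state tuple only so that the following map step can extract them.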

Another approach is to separate the state from the output, using the transform(func) operation. The transform method allows us to perform arbitrary RDD functions on the microbatch; we can drop down from the Streaming APIs to the standard Spark RDD APIs, and use methods such as RDD.join() to merge running state with incoming data. An example of this approach is in the program below.

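import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._

object StatefulStreaming {
  val machineA = "machine_A" // we use an RDD of key+value entries, where the machine name is the key

  def main(args: Array[String]): Unit = {
    if (args.length < 2) {
      System.err.println("Usage: StatefulStreaming <hostname> <port>")
      System.exit(1)
    }

    val sparkConf = new SparkConf().setAppName("StatefulStreaming")
    val ssc = new StreamingContext(sparkConf, Seconds(1))
    // RDD.checkpoint() needs a checkpoint directory. "." is a placeholder (an assumption);
    // use a reliable filesystem path in practice.
    ssc.sparkContext.setCheckpointDir(".")

    // Initial RDD, which has initial data.  This gets updated with each iteration.
    var deltaState: RDD[(String, Int)] = ssc.sparkContext.parallelize(List((machineA, 0))) // name, value
    // we need to checkpoint regularly to limit the RDD dependency chain
    var checkpointCounter = 0

    // Given the last value of the previous batch and the new values (if any),
    // return (new last value, differences between consecutive values)
    def calculateDeltas(previous: Int, newvals: Option[Array[Int]]): (Int, Option[Array[Int]]) = newvals match {
      case Some(values) => {
        var p = previous
        val deltas = new Array[Int](values.length)
        for (i <- 0 until values.length) {
          deltas(i) = values(i) - p
          p = values(i)
        }
        (p, Some(deltas))
      }
      case None => (previous, None)
    }

    // Note: the join emits one row per matching pair, so this assumes at most
    // one line of input per machine in each microbatch.
    val calculateDeltaRDD = (newbatch: RDD[(String, Array[Int])]) => {
      val previousState = deltaState
      val newDeltas = previousState.leftOuterJoin(newbatch).mapValues(x => calculateDeltas(x._1, x._2))
      deltaState = newDeltas.mapValues(x => x._1).persist(StorageLevel.MEMORY_AND_DISK)
      if (checkpointCounter > 10) {
        deltaState.checkpoint() // truncate the RDD dependency chain
        checkpointCounter = 0
      }
      else checkpointCounter = checkpointCounter + 1
      // take() is an action, so it forces deltaState to be computed, cached and (if marked) checkpointed
      val printme = deltaState.take(2)
      System.err.println("New deltaState: (" + printme(0)._1 + ", prev=" + printme(0)._2 + ")")
      previousState.unpersist() // the new state is materialized, so the old one can be released
      newDeltas.mapValues(x => x._2) // output (machine, differences)
    }

    // Create one DStream, associated with one machine
    // new values arrive in an input stream of \n delimited text (eg. generated by 'nc')
    val machineAlines = ssc.socketTextStream(args(0), args(1).toInt)
    val machineADstream = machineAlines.map(_.split(" ").map(_.toInt)).map(x => (machineA, x))

    // do some stuff with these streams.  Chain them together, output deltas, etc
    val deltasDstream = machineADstream.transform[(String, Option[Array[Int]])](calculateDeltaRDD)

    // print out the deltas with each iteration
    deltasDstream.map { case (machine, deltas) =>
      machine + ": " + deltas.map(_.mkString(", ")).getOrElse("no new values")
    }.print()

    ssc.start()
    ssc.awaitTermination()
  }
}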

The state information that we want to keep across microbatches is in the deltaState variable, which holds the last value seen for each machine (this program only tracks a single machine, but multiple machine streams could be combined via the DStream.union(DStream) method). With each microbatch, deltaState is joined with the incoming data, recalculated, and cached for use in the next microbatch.
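As a rough illustration of how that join behaves once several machines are tracked, the sketch below replays one microbatch step on plain Scala Maps, which stand in for the RDDs. This is an analogy rather than Spark code, and the names MultiMachineSketch, deltasFor and step are hypothetical.

```scala
object MultiMachineSketch {
  // Same per-key logic as calculateDeltas in the program, on immutable Seqs.
  def deltasFor(previous: Int, newvals: Option[Seq[Int]]): (Int, Option[Seq[Int]]) =
    newvals match {
      case Some(vals) =>
        val deltas = vals.zip(previous +: vals).map { case (cur, prev) => cur - prev }
        (vals.lastOption.getOrElse(previous), Some(deltas))
      case None => (previous, None) // no data for this machine in this batch
    }

  // One microbatch step: "left outer join" the state with the batch by key,
  // then split the result into (next state, output for later stages).
  def step(state: Map[String, Int],
           batch: Map[String, Seq[Int]]): (Map[String, Int], Map[String, Option[Seq[Int]]]) = {
    val joined = state.map { case (machine, prev) =>
      machine -> deltasFor(prev, batch.get(machine))
    }
    val nextState = joined.map { case (machine, (last, _)) => machine -> last }
    val output = joined.map { case (machine, (_, deltas)) => machine -> deltas }
    (nextState, output)
  }
}
```

Because the join starts from the state, a machine that appears in the batch but not in the state is dropped. Under this design, every machine would therefore need an entry in the initial state.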

There are a few points to note in this program. First, we need to regularly checkpoint our deltaState so that the RDD dependency chain doesn’t grow too long, which would eventually cause a stack overflow. We do this by explicitly checkpointing deltaState roughly every 10 microbatches, which is infrequent enough not to be a drag on performance. Second, with every microbatch we explicitly persist the new deltaState at the MEMORY_AND_DISK storage level, which uses memory if possible but spills to disk if the RDD doesn’t fit, and unpersist the previous deltaState. This ensures we don’t attempt to recalculate previous deltas all the way back to the checkpoint with each microbatch. Finally, we force evaluation of the RDD via the RDD.take() action, to ensure that the RDD is actually cached and checkpointed. Without an action such as RDD.count() or RDD.take() we only have transformations, which are computed lazily. This combination gives us a program that can run for a very long time without performance degradation.