thoughts and learnings in software engineering by Rotem Tamir

Instrumenting PySpark Applications using Spark Accumulators

Apache Spark provides a very convenient abstraction layer for building distributed applications that process massive amounts of data. It is one of the most useful and robust tools in any data engineering team’s tool belt, facilitating exploratory analysis of huge datasets, scheduled production batch processing, and streaming and machine learning applications.

Instrumentation, the practice of providing summary statistics of quantitative aspects of a system, is critical to any serious software engineering endeavor. In this post, I want to share how we used Spark’s Accumulator feature to provide a familiar, intuitive way to instrument Spark applications in production.

What are Spark Accumulators?

Accumulators in Spark are entities used to gather information across different executors. The distributed nature of Spark applications prevents executors from updating some global metric registry directly, so Spark provides Accumulators as the idiomatic way to share counters across process boundaries.

Consider this example:

from pyspark.sql import SparkSession

def filter_non_42(item):
    return '42' in str(item)

if __name__ == '__main__':
    spark = SparkSession \
            .builder \
            .appName("accumulator-example") \
            .getOrCreate()

    sc = spark.sparkContext

    print(sc.range(0, 10000).filter(filter_non_42).sum())

What this application does is:

  • Instantiates a Spark context
  • Creates an RDD with the range of integers from 0 to 9,999
  • Filters the RDD to only contain numbers that have the digit sequence 42 in them
  • Sums the RDD

Let’s see it in action:

➜  spark python accum.py
20/02/25 14:01:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
1412358

Now suppose we want to instrument this application. For this silly example, let’s say the number of even numbers in the input dataset is a meaningful signal for determining our system’s health; we can use an accumulator to count them:

from pyspark.sql import SparkSession
import pyspark
from functools import partial

def filter_non_42(item, accumulator):
    if item % 2 == 0:
        accumulator += 1
    return '42' in str(item)

if __name__ == '__main__':
    spark = SparkSession \
            .builder \
            .appName("accumulator-example") \
            .getOrCreate()

    sc = spark.sparkContext

    accumulator = sc.accumulator(0)
    counting_filter = partial(filter_non_42, accumulator=accumulator)

    print(sc.range(0, 10000).filter(counting_filter).sum())

    print('accum', accumulator)

In this program we:

  • Add an accumulator parameter to the filter function, and increment it whenever we see an even number.
  • Initialize an Accumulator using the sparkContext and set it to 0 in the driver.
  • Use functools.partial to create the counting_filter, which remembers our accumulator variable (a short illustration of partial follows this list).
  • Run our Spark application with the new counting_filter
  • Print the sum and the final value of the accumulator
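
If functools.partial is unfamiliar, here is a tiny, standalone illustration (the names are made up) of how it "remembers" an argument:

from functools import partial

def add(a, b):
    return a + b

add_five = partial(add, b=5)  # a new callable with b pre-bound to 5
print(add_five(3))            # prints 8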

Let’s see it in action:

➜  spark python accum.py

20/02/27 08:19:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
1412358
accum 5000

Indeed, as expected, our code encountered an even number 5000 times.

When we invoke sparkContext.accumulator(..) we register a new Accumulator with the driver of our application. Spark treats these entities in a special way that facilitates statistics gathering. This is how the Spark documentation describes them:

A shared variable that can be accumulated, i.e., has a commutative and associative “add” operation. Worker tasks on a Spark cluster can add values to an Accumulator with the += operator, but only the driver program is allowed to access its value, using value. Updates from the workers get propagated automatically to the driver program.
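
In other words, worker tasks only ever add to the accumulator, and the driver reads the accumulated result. A minimal sketch, building on the example above:

# inside tasks (e.g. our filter function), workers can only add:
#     accumulator += 1
# back in the driver, we read the result via .value:
print(accumulator.value)  # e.g. 5000 in the run above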

The documentation further states, “While SparkContext supports accumulators for primitive data types like int and float, users can also define accumulators for custom types by providing a custom AccumulatorParam object”. In the rest of this post, we will see how we can use this capability to create an easy-to-use metric registry for instrumenting Spark applications.

Requirements from a Metric Registry

At the end of our pipeline, we would like to record our measurements in a time-series database such as InfluxDB. These data stores usually revolve around three major concepts:

  1. Measurements - akin to SQL tables, measurements are named sets of data in a time-series store, for example user_service_latency.
  2. Tags - each record in the measurement table can have a few key-value pairs describing it, which allows for quick aggregate queries over a subset of the measurement. For example: region=us-west-1, env=prod or endpoint=/reset_password. The cardinality (total number of options for each key) of tags should be relatively low, so you shouldn’t store values that vary a lot here, such as a user identifier.
  3. Values - finally, each record in the measurement table has one or more values (Influx allows multiple value columns per measurement, Prometheus does not); these are the numbers we later want to aggregate, monitor and alert on in production. Values usually come in one of three flavors: Counters (monotonically increasing numeric values), Histograms (statistical representations of the distribution of values, e.g. latency histograms) and Gauges (values that can increase and decrease over time, e.g. current CPU utilization). In this post, I will focus on implementing the simplest of the three: Counters (a sketch of such a data point follows this list).
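
To make these concepts concrete, here is a hypothetical data point combining the three (the shape loosely follows the point dictionaries the InfluxDB Python client accepts; all names are made up for illustration):

point = {
    "measurement": "user_service_latency",           # 1. a named set of data
    "tags": {"region": "us-west-1", "env": "prod"},  # 2. low-cardinality key-value pairs
    "fields": {"p99_ms": 230},                       # 3. the numeric value(s) we aggregate on
}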

Our goal in this project is to create a MetricRegistry class which will be the container for these measurements, broken down by tags. When our applications finish running, we will retrieve the data from the MetricRegistry and report it to our storage of choice. We are aiming for something like this:

class MetricRegistry:
    def __init__(self, ...):
        # ... initialization
    def inc(self, measurement: str, amount: int = 1, tags: dict = None):
        # somehow use Spark Accumulators to keep track
    def get(self) -> List[Dict]:
        # returns the collected measurements

Spark Custom AccumulatorParams

This is the documentation on AccumulatorParam:

class pyspark.AccumulatorParam
    Helper object that defines how to accumulate values of a given type.

    addInPlace(value1, value2)
        Add two values of the accumulator’s data type, returning a new value; 
        for efficiency, can also update value1 in place and return it.

    zero(value)
        Provide a “zero value” for the type, compatible in dimensions with
            the provided value (e.g., a zero vector)

The main thing to do is to implement the addInPlace function, which must guarantee that the addition is commutative (a+b = b+a) and associative (a+(b+c) = (a+b)+c).

Luckily, the Python standard library has a neat collections.Counter, which is essentially a dict that supports addition:

>>> from collections import Counter
>>>
>>> c1 = Counter({'a': 1, 'b': 1})
>>> c2 = Counter({'b': 1, 'c': 1})
>>> c1 + c2
Counter({'b': 2, 'a': 1, 'c': 1})

Naturally, this add operation is both commutative:

>>> c1 + c2
Counter({'b': 2, 'a': 1, 'c': 1})
>>> c2 + c1
Counter({'b': 2, 'a': 1, 'c': 1})

And associative:

>>> c3 = Counter({'d': 1})
>>> c1 + (c2 + c3)
Counter({'b': 2, 'a': 1, 'c': 1, 'd': 1})
>>> (c1 + c2) + c3
Counter({'b': 2, 'a': 1, 'c': 1, 'd': 1})

As you can see, it is a perfect fit for backing an AccumulatorParam, like this:

from collections import Counter

import pyspark

class CounterAccumulator(pyspark.accumulators.AccumulatorParam):
    def zero(self, value: Counter) -> Counter:
        return Counter({})

    def addInPlace(self, value1: Counter, value2: Counter) -> Counter:
        return value1 + value2

Let’s stitch this back into our original example:

from pyspark.sql import SparkSession
import pyspark
from functools import partial
from collections import Counter

class CounterAccumulator(pyspark.accumulators.AccumulatorParam):
    def zero(self, value: Counter) -> Counter:
        return Counter({})

    def addInPlace(self, value1: Counter, value2: Counter) -> Counter:
        return value1 + value2


def filter_non_42(item, accumulator: CounterAccumulator):
    if item % 2 == 0:
        accumulator += Counter({'even_number': 1})
    return '42' in str(item)


if __name__ == '__main__':
    spark = SparkSession \
            .builder \
            .appName("accumulator-example") \
            .getOrCreate()

    sc = spark.sparkContext

    accumulator = sc.accumulator(Counter(), CounterAccumulator())
    counting_filter = partial(filter_non_42, accumulator=accumulator)

    print(sc.range(0, 10000).filter(counting_filter).sum())

    print('accum', accumulator)

Which prints:

20/02/28 18:45:08 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
1412358
accum Counter({'even_number': 5000})

Putting it all together

Now that we have a Spark Accumulator that can contain counters, we are all set to implement our basic MetricRegistry:

from collections import Counter
from json import dumps, loads
from typing import Dict, List, Tuple

import pyspark


class MetricRegistry:
    def __init__(self, accumulator: pyspark.Accumulator, base_tags: dict):
        self.accumulator = accumulator
        self.base_tags = base_tags

    def inc(self, measurement: str, amount: int = 1, tags: dict = None):
        # fold (measurement, tags) into a single deterministic string key
        k = MetricRegistry.serialize_key(measurement, tags=tags)
        self.accumulator += Counter({k: amount})

    def get(self) -> List[Dict]:
        # only call this in the driver: accumulator values cannot be read inside tasks
        val = self.accumulator.value
        output = []
        for serialized, counter in val.items():
            measurement_name, tags = MetricRegistry.deserialize_key(serialized)
            output.append(dict(measurement_name=measurement_name, tags={**self.base_tags, **tags}, value=counter))
        return output

    @classmethod
    def serialize_key(cls, name: str, tags: Dict[str, str]) -> str:
        if tags is None:
            tags = dict()
        v = dict(n=name, t=tags)
        return dumps(v, sort_keys=True)  # sort_keys makes the key deterministic

    @classmethod
    def deserialize_key(cls, serialized: str) -> Tuple[str, Dict[str, str]]:
        v = loads(serialized)
        return v['n'], v['t']

What’s going on here:

  • The constructor accepts an Accumulator, backed by our shiny new CounterAccumulator, and a dict of base tags. This is a common pattern in instrumentation libraries, as you usually want to enrich all measurements coming from a specific instance with certain metadata (environment name, deployment region, customer tier, etc.)
  • The inc method accepts a measurement name, an optional amount (sometimes we want to increment by more than 1), and a tags dictionary which can be used to provide additional context about what we are measuring (for example, errorname=CopyPasteException). This method serializes the measurement name and tags into a single string key (using a deterministic dict-to-JSON function, i.e. json.dumps with sort_keys=True) so they can be stored in our Counter-based AccumulatorParam (a quick example of this round trip appears after the list).
  • The get method returns a list of dictionaries, each containing the measurement_name, a tags dictionary that merges the registry's base tags with the row-specific tags, and a counter value.
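
For example, here is a quick round trip of the key (de)serialization helpers, using the measurement and tag that the final application below will record:

>>> MetricRegistry.serialize_key('even_numbers', tags={'more_context': 'demo'})
'{"n": "even_numbers", "t": {"more_context": "demo"}}'
>>> MetricRegistry.deserialize_key('{"n": "even_numbers", "t": {"more_context": "demo"}}')
('even_numbers', {'more_context': 'demo'})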

Our final application looks like this:

# ... 
# imports and MetricRegistry removed for brevity
def filter_non_42(item, metric_registry: MetricRegistry):
    if item % 2 == 0:
        metric_registry.inc('even_numbers', tags=dict(more_context='demo'))
    return '42' in str(item)


if __name__ == '__main__':
    spark = SparkSession \
            .builder \
            .appName("accumulator-example") \
            .getOrCreate()

    sc = spark.sparkContext

    accumulator = sc.accumulator(Counter(), CounterAccumulator())
    metric_registry = MetricRegistry(accumulator, base_tags=dict(server_name='prod.server.1'))
    counting_filter = partial(filter_non_42, metric_registry=metric_registry)

    print(sc.range(0, 10000).filter(counting_filter).sum())

    print('accum', metric_registry.get())

Which prints:

20/02/28 19:18:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
1412358
accum [{'measurement_name': 'even_numbers', 'tags': {'server_name': 'prod.server.1', 'more_context': 'demo'}, 'value': 5000}]
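
To close the loop from the requirements section, the output of get() just needs to be shipped to our time-series store. Here is a hedged sketch of that last step, assuming the influxdb-python (v1) client and a hypothetical report helper; any other storage client would slot in the same way:

from influxdb import InfluxDBClient  # assumption: using the influxdb-python v1 client


def report(metric_registry: MetricRegistry, client: InfluxDBClient):
    # map the registry's output onto the client's point dictionaries
    points = [{
        "measurement": record["measurement_name"],
        "tags": record["tags"],
        "fields": {"value": record["value"]},
    } for record in metric_registry.get()]
    client.write_points(points)


# e.g. report(metric_registry, InfluxDBClient(host="localhost", port=8086, database="metrics"))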

Concluding

Instrumenting Spark applications is critical if you want to provide good service to your data consumers. However, application-level data, which we discussed in this post, is only one layer of the metrics we consider when evaluating the performance of Spark applications. A comprehensive instrumentation stack should probably consist of:

  • Resource utilization - how well are we utilizing the hardware we are provisioning?
  • Spark internal data - we can find bottlenecks and errors that aren't strictly application-level by examining the data Spark exposes about jobs, stages, and tasks: are we doing excessive, expensive shuffles of data between stages? Do executors die throughout the execution of our DAG?
  • Scheduling and total errors - your jobs (unless you are using Spark Streaming) are probably invoked by an external scheduler; as with any scheduled job, you should instrument failed invocations, missed executions and duration.