Instrumenting PySpark Applications using Spark Accumulators
Apache Spark provides a very convenient abstraction layer for building distributed applications that process massive amounts of data. It is one of the most useful and robust tools in any data engineering team’s tool belt, supporting exploratory analysis of huge datasets, scheduled production batch processing, streaming, and machine learning applications.
Instrumentation, the practice of providing summary statistics of quantitative aspects of a system, is critical to any serious software engineering endeavor. In this post, I want to share how we used Spark’s Accumulator
feature to provide a familiar, intuitive way to instrument Spark applications in production.
What are Spark Accumulators?
Accumulators in Spark are entities used to gather information across different executors. Because a Spark application is spread across many processes, its tasks cannot simply update some global, in-process metric registry, so Spark provides Accumulators as the sanctioned way to share counters across process boundaries.
Consider this example:
from pyspark.sql import SparkSession

def filter_non_42(item):
    return '42' in str(item)

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("accumulator-example") \
        .getOrCreate()
    sc = spark.sparkContext
    print(sc.range(0, 10000).filter(filter_non_42).sum())
What this application does is:
- Instantiates a Spark session and grabs its Spark context
- Creates an RDD with the range of integers from 0 to 9999
- Filters the RDD to only contain numbers that have the digit sequence 42 in them
- Sums the RDD
Let’s see it in action:
➜ spark python accum.py
20/02/25 14:01:33 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
1412358
Now suppose we wanted to instrument this application. For this silly example, let’s say we consider the number of even numbers in the input dataset a meaningful signal of our system’s health; we can use an accumulator to count them:
from pyspark.sql import SparkSession
import pyspark
from functools import partial

def filter_non_42(item, accumulator):
    if item % 2 == 0:
        accumulator += 1
    return '42' in str(item)

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("accumulator-example") \
        .getOrCreate()
    sc = spark.sparkContext
    accumulator = sc.accumulator(0)
    counting_filter = partial(filter_non_42, accumulator=accumulator)
    print(sc.range(0, 10000).filter(counting_filter).sum())
    print('accum', accumulator)
In this program we:
- Add an accumulator parameter to the filter function, and increment it whenever we see an even number
- Initialize an Accumulator using the sparkContext and set it to 0 in the driver
- Use functools.partial to create the counting_filter, which remembers our accumulator variable
- Run our Spark application with the new counting_filter
- Print the sum and the final value of the accumulator
Let’s see it in action:
➜ spark python accum.py
20/02/27 08:19:42 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
1412358
accum 5000
Indeed, as expected, our code encountered an even number 5000 times.
When we invoke sparkContext.accumulator(..) we register a new Accumulator with the driver of our application. Spark treats these entities in a special way that facilitates statistics gathering. This is how the Spark documentation describes them:
A shared variable that can be accumulated, i.e., has a commutative and associative “add” operation. Worker tasks on a Spark cluster can add values to an Accumulator with the += operator, but only the driver program is allowed to access its value, using value. Updates from the workers get propagated automatically to the driver program.
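To make that split concrete, here is a minimal sketch (assuming an existing SparkContext named sc, as in the examples in this post): tasks running on workers may only add to the accumulator, and only the driver may read its value.

acc = sc.accumulator(0)

def touch(item):
    acc.add(1)       # allowed inside a task (equivalent to acc += 1)
    # acc.value      # reading the value inside a task raises an error

sc.range(0, 100).foreach(touch)
print(acc.value)     # reading the value is only allowed in the driver; prints 100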
The documentation further states: “While SparkContext supports accumulators for primitive data types like int and float, users can also define accumulators for custom types by providing a custom AccumulatorParam object”. In the rest of this post, we will see how we can use this capability to create an easy-to-use metric registry for instrumenting Spark applications.
Requirements from a Metric Registry
At the end of our pipeline, we would like to record our measurements in a time-series database such as InfluxDB. These data stores usually revolve around three major concepts (an illustrative record combining all three follows this list):
- Measurements - akin to SQL tables, measurements are named sets of data in a time-series store, for example user_service_latency.
- Tags - each record in a measurement can have a few key-value pairs describing it, which allows for quick aggregate queries over a subset of the measurement. For example: region=us-west-1, env=prod or endpoint=/reset_password. The cardinality (total number of possible values for each key) of tags should be relatively low, so you don’t store things that vary a lot here, such as a user identifier.
- Values - finally, each record in a measurement has one or more values (Influx allows multiple value columns per measurement, Prometheus does not); these are the numbers we later want to aggregate, monitor, and alert on in production. Values are usually one of three kinds: Counters (monotonically increasing numeric values), Histograms (statistical representations of the distribution of values, e.g. latency histograms), and Gauges (values that increase and decrease over time, e.g. current CPU utilization). In this post, I will focus on implementing the simplest of the three: Counters.
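For intuition, here is the illustrative record promised above, written as a plain Python dictionary rather than in any particular client library’s format (the measurement name, tags, and value are made up):

# A hypothetical time-series record: one measurement, a few low-cardinality tags, one value.
record = {
    'measurement': 'user_service_latency',           # the "table" this point belongs to
    'tags': {'region': 'us-west-1', 'env': 'prod'},  # labels we can group and filter by
    'value': 123,                                    # the number we aggregate and alert on
}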
Our goal in this project is to create a MetricRegistry class that will be the container for these measurements, broken down by tags. When our applications finish running, we will retrieve the data from the MetricRegistry and report it to our storage of choice. We are aiming for something like this:
class MetricRegistry:
    def __init__(self, ...):
        # ... initialization

    def inc(self, measurement: str, tags: dict, value):
        # somehow use Spark Accumulators to keep track

    def get(self) -> Accumulator:
        # returns the accumulator
Spark Custom AccumulatorParams
This is the documentation on AccumulatorParam
:
class pyspark.AccumulatorParam
Helper object that defines how to accumulate values of a given type.
addInPlace(value1, value2)
Add two values of the accumulator’s data type, returning a new value;
for efficiency, can also update value1 in place and return it.
zero(value)
Provide a “zero value” for the type, compatible in dimensions with
the provided value (e.g., a zero vector)
The main thing to do is implement the addInPlace function, which must preserve the Accumulator’s commutative (a+b = b+a) and associative (a+(b+c) = (a+b)+c) properties.
Luckily, the Python standard library has a neat collections.Counter, which is basically a dict that supports addition operations like:
>>> from collections import Counter
>>>
>>> c1 = Counter({'a': 1, 'b': 1})
>>> c2 = Counter({'b': 1, 'c': 1})
>>> c1 + c2
Counter({'b': 2, 'a': 1, 'c': 1})
Naturally, this add operation is both commutative:
>>> c1 + c2
Counter({'b': 2, 'a': 1, 'c': 1})
>>> c2 + c1
Counter({'b': 2, 'a': 1, 'c': 1})
And associative:
>>> c3 = Counter({'d': 1})
>>> c1 + (c2 + c3)
Counter({'b': 2, 'a': 1, 'c': 1, 'd': 1})
>>> (c1 + c2) + c3
Counter({'b': 2, 'a': 1, 'c': 1, 'd': 1})
As you can see, it is a perfect fit for the data type backing an AccumulatorParam, which we can define like this:
from collections import Counter
import pyspark

class CounterAccumulator(pyspark.accumulators.AccumulatorParam):
    def zero(self, value: Counter) -> Counter:
        return Counter({})

    def addInPlace(self, value1: Counter, value2: Counter) -> Counter:
        return value1 + value2
Let’s stitch this back into our original example:
from pyspark.sql import SparkSession
import pyspark
from functools import partial
from collections import Counter

class CounterAccumulator(pyspark.accumulators.AccumulatorParam):
    def zero(self, value: Counter) -> Counter:
        return Counter({})

    def addInPlace(self, value1: Counter, value2: Counter) -> Counter:
        return value1 + value2

def filter_non_42(item, accumulator: CounterAccumulator):
    if item % 2 == 0:
        accumulator += Counter({'even_number': 1})
    return '42' in str(item)

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("accumulator-example") \
        .getOrCreate()
    sc = spark.sparkContext
    accumulator = sc.accumulator(Counter(), CounterAccumulator())
    counting_filter = partial(filter_non_42, accumulator=accumulator)
    print(sc.range(0, 10000).filter(counting_filter).sum())
    print('accum', accumulator)
Which prints:
20/02/28 18:45:08 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
1412358
accum Counter({'even_number': 5000})
Putting it all together
Now that we have a Spark Accumulator that can contain counters, we are all set to implement our basic MetricRegistry:
from collections import Counter
from json import dumps, loads
from typing import Dict, List, Tuple

import pyspark

class MetricRegistry:
    def __init__(self, accumulator: pyspark.Accumulator, base_tags: dict):
        self.accumulator = accumulator
        self.base_tags = base_tags

    def inc(self, measurement: str, amount: int = 1, tags: dict = None):
        k = MetricRegistry.serialize_key(measurement, tags=tags)
        self.accumulator += Counter({k: amount})

    def get(self) -> List[Dict]:
        val = self.accumulator.value
        output = []
        for serialized, counter in val.items():
            measurement_name, tags = MetricRegistry.deserialize_key(serialized)
            output.append(dict(measurement_name=measurement_name,
                               tags={**self.base_tags, **tags},
                               value=counter))
        return output

    @classmethod
    def serialize_key(cls, name: str, tags: Dict[str, str]) -> str:
        if tags is None:
            tags = dict()
        v = dict(n=name, t=tags)
        return dumps(v, sort_keys=True)

    @classmethod
    def deserialize_key(cls, serialized: str) -> Tuple[str, Dict[str, str]]:
        v = loads(serialized)
        return v['n'], v['t']
What’s going on here:
- The constructor accepts an Accumulator (backed by our shiny new CounterAccumulator) and a dict of base tags. This is a common pattern in instrumentation libraries, as you usually want to enrich all measurements coming from a specific instance with certain metadata (environment name, deployment region, customer tier, etc.)
- The inc method accepts a measurement name, an optional amount (sometimes we want to increment by more than 1), and a tags dictionary which can be used to provide additional context about what we are measuring (for example, errorname=CopyPasteException). The method serializes the measurement name and tags into a string (using a deterministic dict-to-JSON function) so they can be stored in our Counter-based AccumulatorParam; an example of the serialized key appears right after this list.
- The get method returns a list of dictionaries, each containing a measurement_name, a tags dictionary combining the registry’s base tags with the specific tags for that row, and a counter value.
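For example, here is the serialized key I would expect for the even-number counter we add later in this post; it is simply the output of json.dumps with sort_keys=True:

>>> MetricRegistry.serialize_key('even_numbers', tags={'more_context': 'demo'})
'{"n": "even_numbers", "t": {"more_context": "demo"}}'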
Our final application looks like this:
# ...
# imports and MetricRegistry removed for brevity

def filter_non_42(item, metric_registry: MetricRegistry):
    if item % 2 == 0:
        metric_registry.inc('even_numbers', tags=dict(more_context='demo'))
    return '42' in str(item)

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("accumulator-example") \
        .getOrCreate()
    sc = spark.sparkContext
    accumulator = sc.accumulator(Counter(), CounterAccumulator())
    metric_registry = MetricRegistry(accumulator, base_tags=dict(server_name='prod.server.1'))
    counting_filter = partial(filter_non_42, metric_registry=metric_registry)
    print(sc.range(0, 10000).filter(counting_filter).sum())
    print('accum', metric_registry.get())
Which prints:
20/02/28 19:18:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
1412358
accum [{'measurement_name': 'even_numbers', 'tags': {'server_name': 'prod.server.1', 'more_context': 'demo'}, 'value': 5000}]
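From here, reporting the output of get() to our storage of choice is a small final step. As a rough sketch using the influxdb Python client (the host, database name, and field layout below are assumptions for illustration, not part of the example above), it could look something like this:

from influxdb import InfluxDBClient

def report(metric_registry: MetricRegistry):
    # Convert the registry output into the point format expected by write_points.
    points = [
        {
            'measurement': row['measurement_name'],
            'tags': row['tags'],
            'fields': {'value': row['value']},
        }
        for row in metric_registry.get()
    ]
    # Hypothetical connection details.
    client = InfluxDBClient(host='influxdb.example.com', port=8086, database='spark_metrics')
    client.write_points(points)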
Concluding
Instrumenting Spark applications is critical if you want to provide good service to your data consumers. However, application-level data, which we discussed in this post, is only one layer of the metrics we consider when evaluating the performance of Spark applications. A comprehensive instrumentation stack should probably consist of:
- Resource utilization - how well are we utilizing the hardware we are provisioning?
- Spark internal data - we can find bottlenecks and errors that aren’t strictly applicative by examining the data Spark exposes about jobs, stages, and tasks: are we doing excessive, expensive shuffles of data between stages? Do executors die throughout the execution of our DAG? (A small sketch of querying this data follows this list.)
- Scheduling and total errors - your jobs (unless you are using Spark Streaming) are probably invoked by an external scheduler; as with any scheduled job, you should instrument failed invocations, missed executions, and duration.
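For the “Spark internal data” bullet, one possible starting point is the monitoring REST API that the Spark driver exposes (by default on port 4040 while the application runs). The sketch below is only an outline: the endpoint paths come from the Spark monitoring documentation, but the exact field names can vary between Spark versions, so verify them against your cluster before relying on it.

import requests

BASE = 'http://localhost:4040/api/v1'  # driver UI address; host and port depend on your setup

# Find the running application, then pull per-stage shuffle and failure numbers.
app_id = requests.get(f'{BASE}/applications').json()[0]['id']
for stage in requests.get(f'{BASE}/applications/{app_id}/stages').json():
    print(stage.get('stageId'),
          stage.get('numFailedTasks'),      # tasks that died in this stage
          stage.get('shuffleReadBytes'),    # bytes shuffled into this stage
          stage.get('shuffleWriteBytes'))   # bytes shuffled out of this stage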