Python Plotting With Matplotlib (Guide)

by Brad Solomon, Feb 28, 2018

A picture is worth a thousand words, and with Python’s matplotlib library, it fortunately takes far less than a thousand words of code to create a production-quality graphic.

However, matplotlib is also a massive library, and getting a plot to look just right is often achieved through trial and error. Using one-liners to generate basic plots in matplotlib is fairly simple, but skillfully commanding the remaining 98% of the library can be daunting.

This article is a beginner-to-intermediate-level walkthrough on matplotlib that mixes theory with examples. While learning by example can be tremendously insightful, it helps to have even just a surface-level understanding of the library’s inner workings and layout as well.

Here’s what we’ll cover:

  • Pylab and pyplot: which is which?
  • Key concepts of matplotlib’s design
  • Understanding plt.subplots()
  • Visualizing arrays with matplotlib
  • Plotting with the pandas + matplotlib combination

This article assumes the reader knows a tiny bit of NumPy. We’ll mainly use the numpy.random module to generate “toy” data, drawing samples from different statistical distributions.

If you don’t already have matplotlib installed, see here for a walkthrough before proceeding.

Why Can Matplotlib Be Confusing?

Learning matplotlib can be a frustrating process at times. The problem is not that matplotlib’s documentation is lacking: the documentation is actually extensive. But the following issues can cause some challenges:

  • The library itself is huge, at something like 70,000 total lines of code.
  • Matplotlib is home to several different interfaces (ways of constructing a figure) and capable of interacting with a handful of different backends. (Backends deal with the process of how charts are actually rendered, not just structured internally.)
  • While it is comprehensive, some of matplotlib’s own public documentation is seriously out-of-date. The library is still evolving, and many older examples floating around online may take 70% fewer lines of code in their modern version.

So, before we get to any glitzy examples, it’s useful to grasp the core concepts of matplotlib’s design.

Pylab: What Is It, and Should I Use It?

Let’s start with a bit of history: John D. Hunter, a neurobiologist, began developing matplotlib around 2003, originally inspired to emulate commands from MathWorks’ MATLAB software. John passed away tragically young at age 44, in 2012, and matplotlib is now a full-fledged community effort, developed and maintained by a host of others. (John gave a talk about the evolution of matplotlib at the 2012 SciPy conference, which is worth a watch.)

One relevant feature of MATLAB is its global style. The Python concept of importing is not heavily used in MATLAB, and most of MATLAB’s functions are readily available to the user at the top level.

Knowing that matplotlib has its roots in MATLAB helps to explain why pylab exists. pylab is a module within the matplotlib library that was built to mimic MATLAB’s global style. It exists only to bring a number of functions and classes from both NumPy and matplotlib into the namespace, making for an easy transition for former MATLAB users who were not used to needing import statements.

Ex-MATLAB converts (who are all fine people, I promise!) liked this functionality, because with from pylab import *, they could simply call plot() or array() directly, as they would in MATLAB.

The issue here may be apparent to some Python users: using from pylab import * in a session or script is generally bad practice. Matplotlib now directly advises against this in its own tutorials:

“[pylab] still exists for historical reasons, but it is highly advised not to use. It pollutes namespaces with functions that will shadow Python built-ins and can lead to hard-to-track bugs. To get IPython integration without imports the use of the %matplotlib magic is preferred.” [Source]

Internally, there are a ton of potentially conflicting imports being masked within the short pylab source. In fact, using ipython --pylab (from the terminal/command line) or %pylab (from IPython/Jupyter tools) simply calls from pylab import * under the hood.
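
To make the namespace-pollution point concrete, here is a minimal sketch that lists the public names in pylab that also exist as Python built-ins but refer to a different object. The variable name shadowed is ours, not part of any library, and the exact list depends on the installed NumPy and matplotlib versions, so no output is shown:

```python
import builtins

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pylab as pylab

# Public pylab names that collide with a built-in of the same name
# and point to a different object (i.e. would shadow it after
# `from pylab import *`). `shadowed` is an illustrative name.
shadowed = sorted(
    name
    for name in vars(pylab)
    if not name.startswith("_")
    and hasattr(builtins, name)
    and getattr(pylab, name) is not getattr(builtins, name)
)
print(shadowed)
```

Any name printed here would silently replace the built-in of the same name after a star import, which is exactly the kind of hard-to-track bug the matplotlib docs warn about.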

The bottom line is that matplotlib has abandoned this convenience module and now explicitly recommends against using pylab, bringing things more in line with one of Python’s key notions: explicit is better than implicit.

Without the need for pylab, we can usually get away with just one canonical import:

>>> import matplotlib.pyplot as plt

While we’re at it, let’s also import NumPy, which we’ll use for generating data later on, and call np.random.seed() to make examples with (pseudo)random data reproducible:

>>> import numpy as np
>>> np.random.seed(444)
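
As a quick check of what seeding buys us, the sketch below draws the same five integers twice by resetting the seed in between. The names first and second are just for illustration:

```python
import numpy as np

np.random.seed(444)
first = np.random.randint(0, 10, size=5)

np.random.seed(444)  # reset the global generator to the same state
second = np.random.randint(0, 10, size=5)

# Same seed, same sequence of draws
print(np.array_equal(first, second))  # True
```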

The Matplotlib Object Hierarchy

One important big-picture matplotlib concept is its object hierarchy.

If you’ve worked through any introductory matplotlib tutorial, you’ve probably called something like plt.plot([1, 2, 3]). This one-liner hides the fact that a plot is really a hierarchy of nested Python objects. A “hierarchy” here means that there is a tree-like structure of matplotlib objects underlying each plot.

A Figure object is the outermost container for a matplotlib graphic, which can contain multiple Axes objects. One source of confusion is the name: an Axes actually translates into what we think of as an individual plot or graph (rather than the plural of “axis,” as we might expect).

You can think of the Figure object as a box-like container holding one or more Axes (actual plots). Below the Axes in the hierarchy are smaller objects such as tick marks, individual lines, legends, and text boxes. Almost every “element” of a chart is its own manipulable Python object, all the way down to the ticks and labels:

Chart

Here’s an illustration of this hierarchy in action. Don’t worry if you’re not completely familiar with this notation, which we’ll cover later on:

>>> fig, _ = plt.subplots()
>>> type(fig)
<class 'matplotlib.figure.Figure'>

Above, we created two variables with plt.subplots(). The first is a top-level Figure object. The second is a “throwaway” variable that we don’t need just yet, denoted with an underscore. Using attribute notation, it is easy to traverse down the figure hierarchy and see the first tick of the y axis of the first Axes object:

>>> one_tick = fig.axes[0].yaxis.get_major_ticks()[0]
>>> type(one_tick)
<class 'matplotlib.axis.YTick'>

Above, fig (a Figure class instance) stores its Axes in a list, fig.axes, from which we take the first element. Each Axes has a yaxis and an xaxis, each of which has a collection of “major ticks,” and we grab the first one.
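
The hierarchy can be walked in both directions. The following is a small sketch (the plotted data and the label are arbitrary) showing that the Figure knows its Axes, the Axes knows its Figure, and a plotted line lives under the Axes:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
(line,) = ax.plot([1, 2, 3], label="demo")

print(fig.axes[0] is ax)    # True: the Figure holds a list of its Axes
print(ax.figure is fig)     # True: each Axes points back to its Figure
print(ax.lines[0] is line)  # True: drawn lines are children of the Axes
print(type(ax.xaxis).__name__, type(ax.yaxis).__name__)  # XAxis YAxis
```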

Matplotlib presents this as a figure anatomy, rather than an explicit hierarchy:

Chart: anatomy of a figure

(In true matplotlib style, the figure above is created in the matplotlib docs here.)

Stateful Versus Stateless Approaches

Alright, we need one more chunk of theory before we can get around to the shiny visualizations: the difference between the stateful (state-based, state-machine) and stateless (object-oriented, OO) interfaces.

Above, we used import matplotlib.pyplot as plt to import the pyplot module from matplotlib and name it plt.

Almost all functions from pyplot, such as plt.plot(), are implicitly either referring to an existing current Figure and current Axes, or creating them anew if none exist. Hidden in the matplotlib docs is this helpful snippet:

“[With pyplot], simple functions are used to add plot elements (lines, images, text, etc.) to the current axes in the current figure.” [emphasis added]

Hardcore ex-MATLAB users may choose to word this by saying something like, “plt.plot() is a state-machine interface that implicitly tracks the current figure!” In English, this means that:

  • The stateful interface makes its calls with plt.plot() and other top-level pyplot functions. There is only ever one Figure or Axes that you’re manipulating at a given time, and you don’t need to explicitly refer to it.
  • Modifying the underlying objects directly is the object-oriented approach. We usually do this by calling methods of an Axes object, which is the object that represents a plot itself.

The flow of this process, at a high level, looks like this:

Flow

Tying these together, most of the functions from pyplot also exist as methods of the matplotlib.axes.Axes class.
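
To see the two interfaces side by side, here is a minimal sketch that draws the same toy data both ways. The data and titles are made up for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

data = [1, 4, 9, 16]

# Stateful: pyplot finds (or creates) the current Figure and Axes for us
plt.figure()
plt.plot(data)
plt.title("Squares (stateful)")
stateful_ax = plt.gca()

# Stateless / OO: we hold the Axes ourselves and call its methods directly
fig, ax = plt.subplots()
ax.plot(data)
ax.set_title("Squares (OO)")

# Both routes end up with one line carrying the same y-data
same = list(stateful_ax.lines[0].get_ydata()) == list(ax.lines[0].get_ydata())
print(same)  # True
```

The only real difference is who keeps track of the Axes: pyplot does it for us in the first half, and we do it explicitly in the second.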

This is easier to see by peeking under the hood. plt.plot() can be boiled down to five or so lines of code:

# matplotlib/pyplot.py
>>> def plot(*args, **kwargs):
...     """An abridged version of plt.plot()."""
...     ax = plt.gca()
...     return ax.plot(*args, **kwargs)

>>> def gca(**kwargs):
...     """Get the current Axes of the current Figure."""
...     return plt.gcf().gca(**kwargs)

Calling plt.plot() is just a convenient way to get the current Axes of the current Figure and then call its plot() method. This is what is meant by the assertion that the stateful interface always “implicitly tracks” the plot that it wants to reference.
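
A short sketch of this implicit tracking, using two Figures so that “current” actually matters. Here plt.sca() is used to switch the current Axes by hand:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots()  # the most recently created Figure becomes current

plt.plot([1, 2, 3])         # lands on ax2, the current Axes
print(len(ax1.lines), len(ax2.lines))  # 0 1

plt.sca(ax1)                # make ax1 (and its Figure) current again
plt.plot([3, 2, 1])         # now lands on ax1
print(len(ax1.lines), len(ax2.lines))  # 1 1
print(plt.gcf() is fig1, plt.gca() is ax1)  # True True
```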

pyplot is home to a batch of functions that are really just wrappers around matplotlib’s object-oriented interface. For example, with plt.title(), there are corresponding setter and getter methods within the OO approach, ax.set_title() and ax.get_title(). (Use of getters and setters tends to be more popular in languages such as Java but is a key feature of matplotlib’s OO approach.)

Calling plt.title() gets translated into this one line: gca().set_title(s, *args, **kwargs). Here’s what that is doing:

  • gca() grabs the current Axes and returns it.
  • set_title() is a setter method that sets the title for that Axes object. The “convenience” here is that we didn’t need to specify any Axes object explicitly with plt.title().

Similarly, if you take a few moments to look at the source for top-level functions like plt.grid(), plt.legend(), and plt.ylabel(), you’ll notice that all of them follow the same structure of delegating to the current Axes with gca() and then calling some method of the current Axes. (This is the underlying object-oriented approach!)
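
We can confirm the wrapper relationship from the outside, without reading any source. In this sketch (the title and label strings are arbitrary), values set through pyplot are read back through the Axes getters, and vice versa:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()    # ax is now the current Axes

plt.title("Set through pyplot")
plt.ylabel("y values")

print(ax.get_title())       # Set through pyplot
print(ax.get_ylabel())      # y values

ax.set_title("Set through the Axes")
print(plt.gca().get_title())  # Set through the Axes
```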

Understanding plt.subplots() Notation

Alright, enough theory. Now, we’re ready to tie everything together and do some plotting. From here on out, we’ll mostly rely on the stateless (object-oriented) approach, which is more customizable and comes in handy as graphs become more complex.

The prescribed way to create a Figure with a single Axes under the OO approach is (not too intuitively) with plt.subplots(). This is really the only time that the OO approach uses pyplot, to create a Figure and Axes:

>>> fig, ax = plt.subplots()

Above, we took advantage of iterable unpacking to assign a separate variable to each of the two results of plt.subplots(). Notice that we didn’t pass arguments to subplots() here. The default call is subplots(nrows=1, ncols=1). Consequently, ax is a single AxesSubplot object:

>>> type(ax)
<class 'matplotlib.axes._subplots.AxesSubplot'>
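
The shape of what plt.subplots() hands back depends on nrows and ncols. The sketch below checks this with the default arguments otherwise. Because the concrete class name printed by type(ax) may differ between matplotlib versions, it tests against the Axes base class instead:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# Default 1x1 grid: a single Axes object, not an array
fig, ax = plt.subplots()
print(isinstance(ax, Axes))            # True

# One row, two columns: a 1-D NumPy array of Axes
fig, axs_row = plt.subplots(nrows=1, ncols=2)
print(type(axs_row).__name__, axs_row.shape)  # ndarray (2,)

# Two rows, two columns: a 2-D NumPy array of Axes
fig, axs_grid = plt.subplots(nrows=2, ncols=2)
print(axs_grid.shape)                  # (2, 2)
print(isinstance(axs_grid[0, 1], Axes))  # True
```

This is why a single-plot call unpacks as fig, ax, while multi-plot calls are usually unpacked into several names or indexed like an array.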

We can call its instance methods to manipulate the plot similarly to how we call pyplot’s functions. Let’s illustrate with a stacked area graph of three time series:

>>> rng = np.arange(50)
>>> rnd = np.random.randint(0, 10, size=(3, rng.size))
>>> yrs = 1950 + rng

>>> fig, ax = plt.subplots(figsize=(5, 3))
>>> ax.stackplot(yrs, rng + rnd, labels=['Eastasia', 'Eurasia', 'Oceania'])
>>> ax.set_title('Combined debt growth over time')
>>> ax.legend(loc='upper left')
>>> ax.set_ylabel('Total debt')
>>> ax.set_xlim(xmin=yrs[0], xmax=yrs[-1])
>>> fig.tight_layout()

Here’s what’s going on above:

  • After creating three random time series, we defined one Figure (fig) containing one Axes (a plot, ax).

  • We call methods of ax directly to create a stacked area chart and to add a legend, title, and y-axis label. Under the object-oriented approach, it’s clear that all of these are attributes of ax.

  • tight_layout() applies to the Figure object as a whole to clean up whitespace padding.

Chart: debt growth over time

Let’s look at an example with multiple subplots (Axes) within one Figure, plotting two correlated arrays that are drawn from the discrete uniform distribution:

>>> x = np.random.randint(low=1, high=11, size=50)
>>> y = x + np.random.randint(1, 5, size=x.size)
>>> data = np.column_stack((x, y))

>>> fig, (ax1, ax2) = plt.subplots(