Take Full Control Over the Subplots in Matplotlib


Putting multiple plots in one figure with subplots is a very useful way to summarize a lot of information in a small space, which makes subplots helpful in reports and presentations. This article focuses on how to use subplots efficiently and take fine control over the grid.

We will start with the basic subplot function to make equal size plots first. Let’s do the necessary imports:

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

Here is the basic ‘subplots’ function in Matplotlib, creating a grid of two rows and three columns of equal-sized plots:

fig, ax = plt.subplots(2, 3, sharex = 'col', sharey = 'row', figsize = (9, 6))
fig.tight_layout(pad =3.0)

Setting ‘sharex’ to ‘col’ makes the plots in the same column share one x-axis, and setting ‘sharey’ to ‘row’ makes the plots in the same row share one y-axis. That’s why tick labels appear only on the outer plots: the bottom row for x and the left column for y. Sharing axes has both advantages and disadvantages; we will come back to this later.
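To see what sharing actually does, here is a minimal sketch with made-up NumPy data (not the article's dataset). It puts data with very different x ranges into the two plots of the first column and compares their limits. The ‘Agg’ backend line is only there so the snippet also runs outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(2, 3, sharex='col', sharey='row', figsize=(9, 6))

# Top-left plot gets small x values, bottom-left plot gets large x values.
ax[0, 0].plot(np.arange(10), np.arange(10))
ax[1, 0].plot(np.arange(100, 200), np.arange(100))

# The first column shares one x-axis, so both plots report the same limits,
# wide enough to cover the data of both.
top_xlim = ax[0, 0].get_xlim()
bottom_xlim = ax[1, 0].get_xlim()
print(top_xlim, bottom_xlim)

# Plots in a different column keep their own x-axis.
print(ax[0, 1].get_xlim())
```

This is exactly why sharing helps when plots compare similar quantities, and hurts when one plot's range squashes another's data.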

How to access a plot in this array of plots?

These plots (Axes objects) are stored in a two-dimensional NumPy array. Let’s display ‘ax’:

ax

array([[<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
[<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>]], dtype=object)

The output confirms that ‘ax’ is a two-dimensional array. A single plot is accessed by indexing, but to visit every plot we need a nested for loop: the outer loop runs over the rows and the inner loop over the plots in each row. Let’s start by putting some text in each of the rectangles:

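fig, ax = plt.subplots(2, 3, sharex = 'col', sharey = 'row', figsize = (9, 6))
fig.tight_layout(pad = 2)
# The outer loop walks the rows of the 2-D array, the inner loop the plots in each row
for a in ax:
    for b in a:
        text = 'I am a plot'
        # (0.3, 0.45) is a position in data coordinates; empty axes span 0 to 1
        b.annotate(text, (0.3, 0.45), fontsize = 12)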
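A nested loop is not the only way to visit every plot. In this minimal sketch, ‘ax.flat’ walks the NumPy array row by row, so a single loop works, and ‘divmod’ recovers each plot's row and column. It also shows a detail that often trips people up: with a single row or column, ‘plt.subplots’ returns a 1-D array unless you pass ‘squeeze=False’.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 3, sharex='col', sharey='row', figsize=(9, 6))
fig.tight_layout(pad=2)

# ax.flat yields the Axes in row-major order, so one loop covers the whole grid.
for i, b in enumerate(ax.flat):
    row, col = divmod(i, ax.shape[1])  # convert the flat position back to [row, col]
    b.annotate(f'plot [{row}, {col}]', (0.3, 0.45), fontsize=12)

# A single row comes back as a 1-D array unless squeeze=False keeps it 2-D.
_, one_row = plt.subplots(1, 3)
_, always_2d = plt.subplots(1, 3, squeeze=False)
print(one_row.shape, always_2d.shape)  # (3,) (1, 3)
```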

The main purpose, of course, is to put real graphs and plots in those rectangles, and I will do exactly that next. For that, we need a dataset. I will use data from the NHANES 2015-2016 survey.

Please feel free to download the dataset and follow along.

Here I import the dataset using pandas:

df = pd.read_csv('nhanes_2015_2016.csv')



I will make a 2×3 array of plots again and draw plots in some of the ‘ax’ elements. We access each element simply by indexing ‘ax’ like a two-dimensional array, for example ax[0, 1] for the first row and second column.

So, here is how to access the ‘ax’ elements and draw plots in them. I left two slots, ax[1, 0] and ax[1, 2], empty for you; please feel free to fill them. The idea is only to demonstrate how it is done.

fig, ax = plt.subplots(2, 3, figsize = (15, 10))
fig.tight_layout(pad = 2)
# Matplotlib methods called directly on one Axes
ax[0, 0].scatter(df['BMXWT'], df['BMXHT'])
ax[1, 1].plot(df['BMXBMI'])
# pandas plotting methods draw into whichever Axes is passed as ax=
df['DMDHHSIZ'].hist(ax = ax[0, 2])
df.groupby('DMDEDUC2')['BPXSY1'].mean().plot(ax = ax[0, 1], kind='pie', colors = ['lightgreen', 'coral', 'pink', 'violet', 'skyblue'])

Notice that I didn’t set ‘sharex’ to ‘col’ or ‘sharey’ to ‘row’ here. Shared axes wouldn’t help with these plots, because the pie chart does not have traditional x and y axes, and the histogram and scatter plot have very different ranges. Please try it yourself.
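If some plots do compare like with like, you can still share axes selectively. Here is a sketch on hypothetical stand-in data (random numbers in a small DataFrame with made-up column names, not the NHANES file). It uses pandas’ ax= argument, links only two histograms with ‘Axes.sharex’ (available in recent Matplotlib versions), and hides the slots that stay empty with ‘set_axis_off’.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in data with made-up column names (not the NHANES file).
rng = np.random.default_rng(0)
demo = pd.DataFrame({
    'systolic_1': rng.normal(125, 18, 300),
    'systolic_2': rng.normal(123, 17, 300),
    'weight': rng.normal(80, 15, 300),
})

fig, ax = plt.subplots(2, 3, figsize=(15, 10))  # no global sharex/sharey
fig.tight_layout(pad=2)

# pandas draws into whichever Axes is passed through ax=
demo['systolic_1'].hist(ax=ax[0, 0])
demo['systolic_2'].hist(ax=ax[1, 0])
ax[0, 1].scatter(demo['weight'], demo['systolic_1'], alpha=0.3)

# Share the x-axis between just the two comparable histograms.
ax[1, 0].sharex(ax[0, 0])

# Turn off the frames of the slots that stay empty.
for unused in (ax[0, 2], ax[1, 1], ax[1, 2]):
    unused.set_axis_off()
```

The scatter plot keeps its own axes, so its range does not interfere with the histograms.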

All the plots above are the same size, meaning the same height and the same width. It does not have to be that way: plots in one figure can have different heights and widths. One way to do that is the figure’s ‘add_gridspec’ method, whose ‘width_ratios’ and ‘height_ratios’ set the relative column widths and row heights:

fig = plt.figure(constrained_layout=True, figsize=(8, 8))
s = fig.add_gridspec(3, 3, width_ratios = [2, 3, 4], height_ratios = [3, 3, 2])
for row in range(3):
    for col in range(3):
        ax = fig.add_subplot(s[row, col])
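The ratios are relative, not absolute: [2, 3, 4] makes the middle column 1.5 times and the last column 2 times as wide as the first. The sketch below checks that by reading each plot’s position back. I left out constrained_layout here as an assumption that keeps the positions exactly proportional; with constrained_layout=True the ratios still apply, but Matplotlib adjusts the margins at draw time, so the measured sizes may not match the ratios exactly.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch runs without a display
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 8))
s = fig.add_gridspec(3, 3, width_ratios=[2, 3, 4], height_ratios=[3, 3, 2])

# Keep the Axes in a dict keyed by (row, col) so we can inspect them afterwards.
axes = {(row, col): fig.add_subplot(s[row, col]) for row in range(3) for col in range(3)}

# Widths and heights in figure coordinates, measured along the first row and column.
widths = [axes[0, col].get_position().width for col in range(3)]
heights = [axes[row, 0].get_position().height for row in range(3)]
print([round(w / widths[0], 2) for w in widths])    # [1.0, 1.5, 2.0]
print([round(h / heights[2], 2) for h in heights])  # [1.5, 1.5, 1.0]
```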

Taking Even More Fine Control Over the Grids

This can be done using ‘plt.subplot’ and ‘GridSpec’. Here is an example; I will explain it right after the code:

plt.figure(figsize = (15, 12))
grid = plt.GridSpec(3, 4, wspace =0.3, hspace = 0.3)
plt.subplot(grid[0, :3])
plt.subplot(grid[0, 3])
plt.subplot(grid[1:, :2])
plt.subplot(grid[1, 2:])
plt.subplot(grid[2, 2:])

Let me explain what happened here. First, we made a 3×4 grid, that is, three rows and four columns. Next, we indexed the grid to create plots of custom sizes, some of which span several cells.

plt.subplot(grid[0, :3])

Using this code, we index the grid to make a custom shape. ‘grid[0, :3]’ takes the first three cells of the first row and merges them into one bigger plot. The row index comes first: because it is the first row, the row index is 0. The column index is the slice 0:3, which covers columns 0, 1, and 2 (the end of a Python slice is excluded). Since the slice starts at 0, it can also be written as :3.

The next one, ‘grid[0, 3]’, is simpler: row index 0 and column index 3, a single cell.

‘grid[1:, :2]’ makes the big square-shaped plot. The row slice 1: starts at row 1 and goes to the end, so it covers rows 1 and 2. The column slice starts at 0 and takes two columns, indexes 0 and 1, so it is 0:2, which can be written as :2.

‘grid[1, 2:]’ refers to row index 1, with the column index starting at 2 and going to the end (columns 2 and 3).

‘grid[2, 2:]’ means row index 2 and columns 2 to the end.
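To check a slice before drawing anything, you can inspect the SubplotSpec object that indexing a GridSpec returns. This minimal sketch assumes a Matplotlib version where SubplotSpec has the ‘rowspan’ and ‘colspan’ properties (3.2 or later); it prints which rows and columns each of the five slices above covers.

```python
from matplotlib.gridspec import GridSpec

grid = GridSpec(3, 4)  # 3 rows, 4 columns; no figure is needed just to inspect slices

# The five slices used above; each one becomes a SubplotSpec covering a block of cells.
specs = {
    'grid[0, :3]': grid[0, :3],
    'grid[0, 3]': grid[0, 3],
    'grid[1:, :2]': grid[1:, :2],
    'grid[1, 2:]': grid[1, 2:],
    'grid[2, 2:]': grid[2, 2:],
}
for name, spec in specs.items():
    # rowspan and colspan are Python ranges, so the end is excluded just like a slice.
    print(name, 'rows', list(spec.rowspan), 'cols', list(spec.colspan))
```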

Now that you know how to index the grid and make custom-shaped plots, let’s make another one and put some real plots in it.

plt.figure(figsize = (12, 12))
grid = plt.GridSpec(4, 4, wspace =0.3, hspace = 0.8)
g1 = plt.subplot(grid[0:2, :3])
g2 = plt.subplot(grid[2:, 0:2])
g3 = plt.subplot(grid[:2, 3])
g4 = plt.subplot(grid[2:, 2:])
df.groupby('DMDMARTL')['BPXSY1'].mean().plot(kind = 'bar', ax = g1)
g1.set_title("Bar plot of Systolic blood pressure for different marital status")
df['BPXSY1'].hist(ax = g3, orientation = 'horizontal', color='gray')
g3.set_title("Distribution of systolic blood pressure")
df.plot('BPXSY1', 'BPXDI1', kind='scatter', ax = g2, alpha=0.3)
g2.set_title("Systolic vs Diastolic blood pressure")
df['BMXHT'].plot(ax = g4, color='gray')
g4.set_title('Line plot of the height of the population')

Here it is! The complete plots.
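Newer Matplotlib versions also offer ‘plt.subplot_mosaic’, which describes the same kind of layout as a picture of labels instead of slices. This is an alternative, not the method used above; the sketch below only rebuilds the 4×4 layout of the last figure (without the data) so you can compare the two approaches. The label names are my own choice.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch runs without a display
import matplotlib.pyplot as plt

# Each inner list is one row of the 4x4 grid; repeating a label makes that plot span those cells.
layout = [
    ['bar', 'bar', 'bar', 'hist'],
    ['bar', 'bar', 'bar', 'hist'],
    ['scatter', 'scatter', 'line', 'line'],
    ['scatter', 'scatter', 'line', 'line'],
]
fig, axd = plt.subplot_mosaic(layout, figsize=(12, 12),
                              gridspec_kw={'wspace': 0.3, 'hspace': 0.8})

# Plots are looked up by label instead of by grid slices (g1 to g4 in the code above).
axd['bar'].set_title('Bar plot goes here (g1)')
axd['hist'].set_title('Histogram goes here (g3)')
axd['scatter'].set_title('Scatter plot goes here (g2)')
axd['line'].set_title('Line plot goes here (g4)')
```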

If you ran all the code above and understood it, using subplots should now be easy, and you should have full control over them. A video version of this content is also available.

Feel free to follow me on Twitter and like my Facebook page.

