How to Make a Heatmap with Matplotlib in Python


In this post, we will learn how to make a heatmap with Matplotlib in Python. In Matplotlib, we can make a heatmap with the imshow() function, which displays a 2D array of values as an image: each element becomes one cell whose color is determined by its value.

We will start by making a simple heatmap with a one-liner using imshow(). Then we will show a couple of simple customizations, such as a title, axis tick labels, a color bar, and a different color palette, to make the heatmap look better.
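To make the image analogy concrete, here is a minimal sketch on a small made-up NumPy array rather than the flights data. It assumes the non-interactive Agg backend so it runs without a display, and it relies on imshow()'s defaults: the color limits are taken from the data's minimum and maximum, and row 0 is drawn at the top (the y-axis is inverted).

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, nothing is shown on screen
import matplotlib.pyplot as plt
import numpy as np

# Illustrative 3 x 4 matrix (made-up values 0..11), not the flights data
data = np.arange(12).reshape(3, 4)

fig, ax = plt.subplots()
im = ax.imshow(data)  # one colored cell per array element

vmin, vmax = im.get_clim()  # color limits autoscaled to the data range
print("color limits:", vmin, vmax)
print("image shape:", im.get_array().shape)
print("y limits:", ax.get_ylim())  # inverted y-axis: row 0 sits at the top
plt.close(fig)
```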

Let us get started by loading the modules needed.

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

We will use the flights dataset available in Seaborn’s built-in datasets. First, we have to pivot the data into matrix form, with months as rows and years as columns.

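flights = sns.load_dataset("flights")
# pivot to wide format: one row per month, one column per year
# (keyword arguments are required by pandas 2.0 and later)
flights = flights.pivot(index="month", columns="year", values="passengers")
flights.head()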


year   1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960
month
Jan     112   115   145   171   196   204   242   284   315   340   360   417
Feb     118   126   150   180   196   188   233   277   301   318   342   391
Mar     132   141   178   193   236   235   267   317   356   362   406   419
Apr     129   135   163   181   235   227   269   313   348   348   396   461
May     121   125   172   183   229   234   270   318   355   363   420   472
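The pivot step can be checked on a tiny hand-made long-format table, without downloading the dataset. This sketch reuses the flights column names and a few passenger counts from the output above. Note that plain string months are sorted alphabetically in the pivoted index; recent seaborn versions appear to load month as a categorical column with the categories in calendar order, which is presumably why the output above starts with Jan, Feb, Mar.

```python
import pandas as pd

# Hand-made long-format sample: one row per (year, month) pair,
# using passenger counts taken from the table above
long_df = pd.DataFrame({
    "year": [1949, 1949, 1950, 1950],
    "month": ["Jan", "Feb", "Jan", "Feb"],
    "passengers": [112, 118, 115, 126],
})

# Wide format: months become rows, years become columns
wide = long_df.pivot(index="month", columns="year", values="passengers")
print(wide)
```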

Simple Heatmap with Matplotlib’s imshow()

Once the data is in matrix form, we can use the imshow() function to make a simple heatmap.

fig, ax = plt.subplots(figsize=(10,10))
im = ax.imshow(flights)
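imshow() only uses the numbers in the DataFrame; the month and year labels are not carried over to the axes. A small sketch to check this, assuming the Agg backend and a 2 x 2 excerpt of the pivoted table built by hand:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# 2 x 2 excerpt of the pivoted flights table (values from the output above)
wide = pd.DataFrame([[112, 115], [118, 126]],
                    index=["Jan", "Feb"], columns=[1949, 1950])

fig, ax = plt.subplots()
im = ax.imshow(wide)  # the DataFrame is converted to a plain 2D array

# The image holds the same numbers as the DataFrame's underlying array
same_values = np.array_equal(im.get_array(), wide.to_numpy())
print("same values:", same_values)
plt.close(fig)
```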

Let us first add a simple annotation to make the heatmap slightly better. Here we add a title to the heatmap using the set_title() method.

fig, ax = plt.subplots(figsize=(10,10))
im = ax.imshow(flights)
ax.set_title("Matplotlib Heatmap with imshow", size=20)
fig.tight_layout()
plt.savefig("heatmap_in_matplotlib_using_imshow.png",
            format='png', dpi=150)
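savefig() also accepts a file-like object. This sketch writes the PNG into an in-memory buffer (io.BytesIO) instead of a file, using a made-up array, to show how figsize and dpi determine the pixel size: a 4 x 3 inch figure saved at dpi=150 should come out as 600 x 450 pixels. Reading the PNG header with struct is only there for the check; it is not part of the post's workflow.

```python
import io
import struct

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

data = np.arange(12).reshape(3, 4)  # illustrative values only

fig, ax = plt.subplots(figsize=(4, 3))  # 4 x 3 inches
ax.imshow(data)
ax.set_title("Matplotlib Heatmap with imshow")
fig.tight_layout()

buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=150)  # pixels = inches * dpi
plt.close(fig)

png = buf.getvalue()
# Width and height are stored big-endian in the PNG IHDR chunk at bytes 16..23
width, height = struct.unpack(">II", png[16:24])
print(width, height)
```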

How to make heatmap with Matplotlib

Note that the x and y axes show only integer cell positions, not the months and years. To add meaningful tick labels, let us get the months and years from the row and column names of the flights dataset.

months = flights.index.values    # row labels (Jan to Dec) for the y-axis
months
years = flights.columns.values   # column labels (1949 to 1960) for the x-axis
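Because imshow() places cell i at tick position i, each label array needs exactly one entry per row or column. A quick sketch of that check on a hand-made 2 x 3 excerpt of the pivoted table (values from the output above):

```python
import pandas as pd

# 2 x 3 excerpt of the pivoted flights table
wide = pd.DataFrame([[112, 115, 145], [118, 126, 150]],
                    index=["Jan", "Feb"], columns=[1949, 1950, 1951])

months = wide.index.values    # row labels -> y-axis
years = wide.columns.values   # column labels -> x-axis

# Tick positions will be 0..n-1, so the label counts must match the data shape
assert (len(months), len(years)) == wide.shape
print(list(months), list(years))
```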

Add axis tick labels to a Heatmap in Matplotlib

Now we can add the years as tick labels on the x-axis and the months as tick labels on the y-axis. In Matplotlib, we can add tick labels using the set_xticks() and set_yticks() methods, passing the tick positions together with the labels.

fig, ax = plt.subplots(figsize=(10,10))
im = ax.imshow(flights)
# Add axis tick labels
ax.set_xticks(np.arange(len(years)), 
              labels=years)
ax.set_yticks(np.arange(len(months)), 
              labels=months)
ax.set_title("Adding axis labels to Matplotlib Heatmap", 
             size=20)
fig.tight_layout()
plt.savefig("axis_labels_in_heatmap_in_matplotlib.png",
            format='png', dpi=150)
Adding axis labels to heatmap with Matplotlib
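The labels= keyword of set_xticks() and set_yticks() was added in Matplotlib 3.5. On older versions, an equivalent two-step approach is to set the tick positions first and then the text with set_xticklabels() and set_yticklabels(). This is a sketch assuming the Agg backend and a small excerpt of the data; the years are converted to strings explicitly.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

data = np.array([[112, 115, 145], [118, 126, 150]])  # excerpt from the table above
months = ["Jan", "Feb"]
years = [1949, 1950, 1951]

fig, ax = plt.subplots()
ax.imshow(data)

# One tick at the center of each cell, then attach the text labels
ax.set_xticks(np.arange(len(years)))
ax.set_xticklabels([str(y) for y in years])
ax.set_yticks(np.arange(len(months)))
ax.set_yticklabels(months)

xlabels = [t.get_text() for t in ax.get_xticklabels()]
ylabels = [t.get_text() for t in ax.get_yticklabels()]
print(xlabels, ylabels)
plt.close(fig)
```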

Add a color bar legend to a Heatmap in Matplotlib

We can add a color bar legend to the heatmap with the figure.colorbar() method; it shows the range of numerical values and how they map to colors. Sometimes the color bar is taller than the heatmap itself, so here we use the shrink argument to reduce it to half its default length.

fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(flights)
cbar = ax.figure.colorbar(im,
                          ax=ax,
                          shrink=0.5)
# add tick labels
ax.set_xticks(np.arange(len(years)),
              labels=years,
              size=12)
ax.set_yticks(np.arange(len(months)),
              labels=months,
              size=12)
# Rotate the tick labels to be more legible
plt.setp(ax.get_xticklabels(),
         rotation=45,
         ha="right",
         rotation_mode="anchor")
ax.set_title("Flights Data Seaborn", size=20)
fig.tight_layout()
plt.savefig("how_to_make_a_heatmap_with_matplotlib_Python.png",
            format='png', dpi=150)
How to make a heatmap with imshow() in Matplotlib
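To check what shrink and the rotation settings actually do, here is a sketch on a small excerpt of the data (Agg backend assumed). It compares the color bar's requested height with the heatmap axes' height, using get_position(original=True) so that the aspect-ratio adjustment applied at draw time is ignored, and then reads back the rotation of the x tick labels.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

data = np.array([[112, 115, 145], [118, 126, 150]])  # excerpt from the table above
years = ["1949", "1950", "1951"]

fig, ax = plt.subplots()
im = ax.imshow(data)
cbar = fig.colorbar(im, ax=ax, shrink=0.5)  # color bar at half its default length

ax.set_xticks(np.arange(len(years)))
ax.set_xticklabels(years)
# Rotate the labels and anchor them at their right end
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# Requested heights, before any aspect-ratio adjustment at draw time
cbar_h = cbar.ax.get_position(original=True).height
ax_h = ax.get_position(original=True).height
ratio = cbar_h / ax_h
print("colorbar / heatmap height:", round(ratio, 2))

rotations = [label.get_rotation() for label in ax.get_xticklabels()]
print("tick label rotations:", rotations)
plt.close(fig)
```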

Change the color palette of a Heatmap in Matplotlib

By default, Matplotlib’s imshow() uses the viridis colormap to make the heatmap. We can change the color palette by passing the cmap argument to imshow(). In this example, we change the default colormap to “YlGn” (yellow to green).

fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(flights, cmap="YlGn")
cbar = ax.figure.colorbar(im,
                          ax=ax,
                          shrink=0.5)
# add tick labels
ax.set_xticks(np.arange(len(years)),
              labels=years,
              size=12)
ax.set_yticks(np.arange(len(months)),
              labels=months,
              size=12)
# Rotate the tick labels to be more legible
plt.setp(ax.get_xticklabels(),
         rotation=45,
         ha="right",
         rotation_mode="anchor")
ax.set_title("Flights Data Seaborn", size=20)
fig.tight_layout()
plt.savefig("change_heatmap_color_palette_matplotlib_Python.png",
            format='png', dpi=150)
Change Heatmap Color Palette in Matplotlib using cmap argument
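The cmap argument accepts any registered colormap name, and appending “_r” to a name gives the reversed version. This sketch (Agg backend and default rcParams assumed) draws the same small excerpt of the data with the default colormap, “YlGn”, and “YlGn_r” side by side, then reads back the name of the colormap each panel used.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

data = np.array([[112, 115, 145], [118, 126, 150]])  # excerpt from the table above

fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for ax, cmap in zip(axes, [None, "YlGn", "YlGn_r"]):
    im = ax.imshow(data, cmap=cmap)   # cmap=None falls back to rcParams["image.cmap"]
    ax.set_title(im.get_cmap().name)  # label each panel with its colormap

names = [ax.images[0].get_cmap().name for ax in axes]
print(names)
plt.close(fig)
```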