In this post, we will learn how to make a heatmap with Matplotlib in Python. In Matplotlib, we can make a heatmap with the imshow() function, which displays a 2D array of values as an image by mapping each value to a color.
We will start by making a simple heatmap with a one-liner using imshow(). Then we will show a couple of simple customizations, such as adding axis labels, to make the heatmap look better.
Let us get started by loading the modules needed.
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
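We will use the flights dataset, one of Seaborn’s built-in datasets (load_dataset() downloads it from the seaborn-data repository, so it needs an internet connection the first time). First, we have to pivot the data into matrix form, with months as rows and years as columns.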
flights = sns.load_dataset("flights")
# pivot to make the data wide: months as rows, years as columns
# (pandas 2.x requires keyword arguments for DataFrame.pivot)
flights = flights.pivot(index="month", columns="year", values="passengers")
flights.head()

year   1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960
month
Jan     112   115   145   171   196   204   242   284   315   340   360   417
Feb     118   126   150   180   196   188   233   277   301   318   342   391
Mar     132   141   178   193   236   235   267   317   356   362   406   419
Apr     129   135   163   181   235   227   269   313   348   348   396   461
May     121   125   172   183   229   234   270   318   355   363   420   472
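To see what the pivot does without downloading anything, here is a minimal sketch on a tiny, hypothetical long-form table shaped like the flights data (one row per month/year pair). The names long_df and wide_df and all the numbers are made up for illustration.

```python
import pandas as pd

# Hypothetical long-form data: one row per (month, year) pair, values invented
long_df = pd.DataFrame({
    "year": [1949, 1949, 1950, 1950],
    "month": ["Jan", "Feb", "Jan", "Feb"],
    "passengers": [10, 20, 30, 40],
})

# Each distinct "month" becomes a row, each distinct "year" a column,
# and "passengers" fills the cells of the resulting matrix
wide_df = long_df.pivot(index="month", columns="year", values="passengers")
print(wide_df)
```

The wide table has one cell per month/year combination, which is exactly the 2D layout imshow() expects.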
Simple Heatmap with Matplotlib’s imshow()
Once the data is in matrix form, we can make a simple heatmap with the imshow() function.
fig, ax = plt.subplots(figsize=(10,10))
im = ax.imshow(flights)
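As a rough illustration of what imshow() returns, the sketch below plots a hypothetical 12 x 12 NumPy matrix standing in for the pivoted flights table. It assumes a headless setup, so it selects Matplotlib's non-interactive "Agg" backend.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for the 12 months x 12 years matrix
data = np.arange(144).reshape(12, 12)

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(data)  # each matrix cell is drawn as one colored square

# imshow returns an AxesImage that keeps a reference to the plotted values
print(type(im).__name__)
print(im.get_array().shape)
# With the default origin="upper", row 0 is drawn at the top and cell centers
# sit at integer coordinates 0..11
print(im.get_extent())
plt.close(fig)
```

Because cell centers land on integer coordinates, tick positions like np.arange(n) line up with the rows and columns of the matrix, and with origin="upper" the first row of the table appears at the top of the heatmap.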
Let us first add a simple annotation to make the heatmap slightly better. Here we add a title to the heatmap using the set_title() function.
fig, ax = plt.subplots(figsize=(10,10))
im = ax.imshow(flights)
ax.set_title("Matplotlib Heatmap with imshow", size=20)
fig.tight_layout()
plt.savefig("heatmap_in_matplotlib_using_imshow.png", format='png', dpi=150)
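If you want to check the saved output without writing files, a minimal sketch is to save into an in-memory buffer. The data below is hypothetical, and using fig.savefig() instead of plt.savefig() is a stylistic choice: it saves that specific figure rather than whichever figure pyplot considers current.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for headless use
import matplotlib.pyplot as plt
import numpy as np

data = np.arange(144).reshape(12, 12)  # hypothetical stand-in data

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(data)
ax.set_title("Matplotlib Heatmap with imshow", size=20)
fig.tight_layout()

# format and dpi behave the same as when saving to a file path
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=150)
png_bytes = buf.getvalue()
plt.close(fig)

print(ax.get_title())
print(png_bytes[:8] == b"\x89PNG\r\n\x1a\n")  # PNG files start with this signature
```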
Note that the axes show only numeric row and column positions instead of the month and year names. To label the axes, let us get the months and years from the row and column names of the pivoted flights data.
months = flights.index.values
months
years = flights.columns.values
Add axis tick labels to Heatmap in Matplotlib
Now we can use the years as x-axis tick labels and the months as y-axis tick labels. In Matplotlib, set_xticks() and set_yticks() set the tick positions, and their labels argument (available since Matplotlib 3.5) sets the tick labels at the same time.
fig, ax = plt.subplots(figsize=(10,10))
im = ax.imshow(flights)
# Add axis tick labels
ax.set_xticks(np.arange(len(years)), labels=years)
ax.set_yticks(np.arange(len(months)), labels=months)
ax.set_title("Adding axis labels to Matplotlib Heatmap", size=20)
fig.tight_layout()
plt.savefig("axis_labels_in_heatmap_in_matplotlib.png", format='png', dpi=150)
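The same labeling pattern works for any DataFrame: take the index for the y-axis and the columns for the x-axis. The sketch below uses a small hypothetical 3 x 4 table (df and its values are made up) and reads the tick labels back to confirm they were applied.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for headless use
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical wide table: 3 months (rows) x 4 years (columns)
df = pd.DataFrame(
    np.arange(12).reshape(3, 4),
    index=["Jan", "Feb", "Mar"],
    columns=[1949, 1950, 1951, 1952],
)
months = df.index.values
years = df.columns.values

fig, ax = plt.subplots()
ax.imshow(df)
# One tick per column/row at integer positions, which are the cell centers
ax.set_xticks(np.arange(len(years)), labels=years)   # labels= needs Matplotlib >= 3.5
ax.set_yticks(np.arange(len(months)), labels=months)

x_labels = [t.get_text() for t in ax.get_xticklabels()]
y_labels = [t.get_text() for t in ax.get_yticklabels()]
print(x_labels)
print(y_labels)
plt.close(fig)
```

Numeric labels such as the years are converted to strings when they are drawn as tick text.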
Add color bar legend to Heatmap in Matplotlib
We can add a color bar legend to the heatmap with the figure.colorbar() function; it shows the range of numerical values and how they map to colors. Sometimes the color bar can be taller than the heatmap itself, so here we use the shrink argument to reduce its size.
fig, ax = plt.subplots(figsize=(10,10))
im = ax.imshow(flights)
# add a color bar, shrunk to half its default length
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.5)
# add tick labels
ax.set_xticks(np.arange(len(years)), labels=years, size=12)
ax.set_yticks(np.arange(len(months)), labels=months, size=12)
# Rotate the tick labels to be more legible
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
ax.set_title("Flights Data Seaborn", size=20)
fig.tight_layout()
plt.savefig("how_to_make_a_heatmap_with_matplotlib_Python.png", format='png', dpi=150)
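A minimal sketch of how the color bar relates to the image: by default the color limits are the minimum and maximum of the data, and the color bar displays that range. The data values and the "Passengers" label below are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for headless use
import matplotlib.pyplot as plt
import numpy as np

data = np.linspace(100, 600, 144).reshape(12, 12)  # hypothetical values

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(data)
# shrink scales the color bar's length relative to the space next to the axes
cbar = fig.colorbar(im, ax=ax, shrink=0.5)
cbar.set_label("Passengers")  # hypothetical label text

# Color limits default to the data range; the color bar shows this range
vmin, vmax = im.get_clim()
print(vmin, vmax)
plt.close(fig)
```

If you need a fixed scale (for example, to compare several heatmaps), passing vmin and vmax to imshow() sets the color limits explicitly instead.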
Change color palette of Heatmap in Matplotlib
By default, Matplotlib’s imshow() uses the viridis color palette (colormap) to color the heatmap. We can change the color palette with the cmap argument of imshow(). In this example, we change the default color palette to “YlGn”.
fig, ax = plt.subplots(figsize=(10,10))
im = ax.imshow(flights, cmap="YlGn")
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.5)
# add tick labels
ax.set_xticks(np.arange(len(years)), labels=years, size=12)
ax.set_yticks(np.arange(len(months)), labels=months, size=12)
# Rotate the tick labels to be more legible
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
ax.set_title("Flights Data Seaborn", size=20)
fig.tight_layout()
plt.savefig("change_heatmap_color_palette_matplotlib_Python.png", format='png', dpi=150)
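To compare the default palette with "YlGn" side by side, here is a small sketch on hypothetical data; each panel's title is read from the colormap actually attached to its image.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for headless use
import matplotlib.pyplot as plt
import numpy as np

data = np.arange(144).reshape(12, 12)  # hypothetical stand-in data

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
default_im = axes[0].imshow(data)            # uses rcParams["image.cmap"] (viridis by default)
ylgn_im = axes[1].imshow(data, cmap="YlGn")  # sequential yellow-to-green palette

# Title each panel with the name of the colormap it uses
axes[0].set_title(default_im.get_cmap().name)
axes[1].set_title(ylgn_im.get_cmap().name)
plt.close(fig)
```

Built-in colormaps also have reversed variants with an "_r" suffix, so cmap="YlGn_r" would map high values to the lighter end instead.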