How to make Stacked area plot with Matplotlib

In this tutorial, we will learn how to make a stacked area plot in Python using Matplotlib’s stackplot() function.

The basic syntax of using Matplotlib’s stackplot() function is

stackplot(x,y)

where x is an array-like of shape (N,) and y is an array-like of shape (M, N), that is, M series with N values each.
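As a quick illustration of those shapes, here is a minimal sketch with made-up numbers. The values and the variable names are assumptions for illustration only: three series of four values each, so y has shape (3, 4) and x has shape (4,).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Made-up data: N = 4 x positions and M = 3 series
x = np.array([1, 2, 3, 4])       # shape (4,)
y = np.array([[1, 2, 3, 4],      # shape (3, 4): one row per series
              [2, 2, 2, 2],
              [1, 3, 1, 3]])

fig, ax = plt.subplots()
# stackplot draws one filled layer per row of y and returns them as a list
layers = ax.stackplot(x, y)
```

Each row of y is stacked on top of the rows before it, so the top edge of the last layer is the column-wise sum of all series.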

Data for making Stacked area plot in Python

In this example, we will make a stacked area plot starting from data stored in a Pandas dataframe instead of simple hand-coded data. We will use the gapminder data available from datavizpyr.com’s GitHub page and load it directly using Pandas.

Let us get started by loading the libraries needed.

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

We will use gapminder data to make the stacked area plot with year on x-axis and mean GDP per capita for each continent on y-axis.

p2data = "https://raw.githubusercontent.com/datavizpyr/data/master/gapminder-FiveYearData.csv"
gapminder = pd.read_csv(p2data)
gapminder.head()


country	year	pop	continent	lifeExp	gdpPercap
0	Afghanistan	1952	8425333.0	Asia	28.801	779.445314
1	Afghanistan	1957	9240934.0	Asia	30.332	820.853030
2	Afghanistan	1962	10267083.0	Asia	31.997	853.100710
3	Afghanistan	1967	11537966.0	Asia	34.020	836.197138
4	Afghanistan	1972	13079460.0	Asia	36.088	739.981106

First, we need to compute the mean GDP per capita for each year and continent. Here we use Pandas’ groupby() function to group the rows by continent and year and then take the mean of gdpPercap within each group.

df = (gapminder
      .groupby(['continent', 'year'])['gdpPercap']
      .mean()
      .reset_index()
     )

Our data with mean GDP looks like this.

df.head()


continent	year	gdpPercap
0	Africa	1952	1252.572466
1	Africa	1957	1385.236062
2	Africa	1962	1598.078825
3	Africa	1967	2050.363801
4	Africa	1972	2339.615674
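To make the groupby step concrete, here is a small sketch on a made-up table in the same long format. The frame toy and its numbers are hypothetical and only meant to show what groupby(), mean() and reset_index() do together.

```python
import pandas as pd

# Tiny made-up table in the same long format as gapminder
toy = pd.DataFrame({
    "continent": ["A", "A", "A", "A", "B", "B"],
    "year":      [2000, 2000, 2005, 2005, 2000, 2005],
    "gdpPercap": [10.0, 30.0, 20.0, 40.0, 5.0, 7.0],
})

# One mean per (continent, year) pair;
# reset_index turns the group keys back into ordinary columns
toy_mean = (toy
            .groupby(["continent", "year"])["gdpPercap"]
            .mean()
            .reset_index()
           )
print(toy_mean)
```

The two rows for continent A in 2000 (10.0 and 30.0) collapse into a single row with mean 20.0, which is the same reduction applied to the countries within each continent in the gapminder data.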

We have the data we need to make the stacked area plot. Let us reshape it into the input format that Matplotlib’s stackplot() function expects. For the x-axis we need the years, which we can get as the unique values of the year column using Pandas’ unique() function. Here we convert them to a list with tolist() and sort them so that they are guaranteed to be in chronological order.

year = sorted(df.year.unique().tolist())
year

[1952, 1957, 1962, 1967, 1972, 1977, 1982, 1987, 1992, 1997, 2002, 2007]

And for the y-axis we need the mean GDP per capita of each continent over the years. Here we create a dictionary with the continents as keys and, for each continent, a list of its mean GDP per capita values over the years as the value.

gdp_cont = (df
            .groupby('continent')['gdpPercap']
            .apply(lambda x: x.values.tolist())
            .to_dict()
           )

We can see that the keys of the dictionary are the continents.

gdp_cont.keys()

dict_keys(['Africa', 'Americas', 'Asia', 'Europe', 'Oceania'])

And the values of the dictionary are lists of mean GDP per capita over the years, one list per continent.

gdp_cont.values()

dict_values([[1252.5724658211539, 1385.2360622557692, 1598.0788248461538, 2050.3638008576922, 2339.615674198077, 2585.9385084634614, 2481.592959753846, 2282.6689912596153, 2281.8103332442306, 2378.759555101923, 2599.385158890385, 3089.0326047365384], [4079.0625522, 4616.04373316, 4901.5418704, 5668.25349604, 6491.334139039999, 7352.007126280001, 7506.73708808, 7793.40026112, 8044.93440552, 8889.30086256, 9287.67710732, 11003.03162536], [5195.484004030303, 5787.732939942424, 5729.369624775758, 5971.1733736060605, 8187.468699448485, 7791.3140199303025, 7434.135157484849, 7608.226507730302, 8639.690247730303, 9834.093295248485, 10174.090396578787, 12473.026870133333], [5661.05743476, 6963.0128159333335, 8365.4868143, 10143.823756533333, 12479.575246466668, 14283.9791096, 15617.896551233334, 17214.310726633335, 17061.5680842, 19076.781801600002, 21711.73242243333, 25054.481635933334], [10298.08565, 11598.522455, 12696.452430000001, 14495.021789999999, 16417.33338, 17283.957605, 18554.70984, 20448.040159999997, 20894.045885, 24024.175170000002, 26938.77804, 29810.188275]])
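An alternative way to build the same kind of input, sketched here as an assumption rather than as the approach used in this tutorial, is to pivot the long table into a wide one with one column per continent. The toy data below is made up for illustration; with the real data, the df with the mean values would take its place.

```python
import pandas as pd

# Made-up long-format data standing in for the per-continent means
toy = pd.DataFrame({
    "continent": ["A", "A", "B", "B"],
    "year":      [2000, 2005, 2000, 2005],
    "gdpPercap": [20.0, 30.0, 5.0, 7.0],
})

# Rows become years, columns become continents
wide = toy.pivot(index="year", columns="continent", values="gdpPercap")

years = wide.index.tolist()       # x values, shape (N,)
labels = wide.columns.tolist()    # one label per series
y = wide.to_numpy().T             # shape (M, N): one row per continent
```

These could then be passed along as plt.stackplot(years, y, labels=labels). One reason to consider the pivot is that a year missing for some continent shows up as NaN in the wide table instead of silently shortening that continent’s list.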

Matplotlib’s stackplot(): Stacked area plot in Python

Now we can use Matplotlib’s stackplot() function to make the stacked area plot. Here we pass the continent names through the labels argument so that they appear in the legend, and we set the location of the legend on the plot with plt.legend().

plt.stackplot(year,
            gdp_cont.values(), 
            labels=gdp_cont.keys())

plt.legend(loc='upper left')
plt.title('gapminder data: Mean GDP per capita by continent')
plt.xlabel('Year')
plt.ylabel('Mean GDP per capita')
plt.savefig("Stacked_area_plot_with_Matplotlib.png")