In this tutorial, we will learn how to make a stacked area plot using Python’s Matplotlib. We can make a stacked area plot using Matplotlib’s stackplot() function.
The basic syntax of using Matplotlib’s stackplot() function is
stackplot(x,y)
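where x is an array-like input of shape (N,) and y is an array-like input of shape (M, N), that is, M series each with one value per element of x.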
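To make the shapes concrete, here is a minimal sketch with small hand-coded arrays. The values and the labels A, B and C are made up for illustration, and the Agg backend is selected only so that the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so no display is needed
import matplotlib.pyplot as plt
import numpy as np

# x has shape (N,): one position per point along the x-axis
x = np.array([1, 2, 3, 4])

# y has shape (M, N): one row per series, one column per x value
y = np.array([
    [1, 2, 3, 4],  # series A (made-up values)
    [2, 2, 2, 2],  # series B
    [0, 1, 0, 1],  # series C
])

fig, ax = plt.subplots()
# stackplot returns one filled area (PolyCollection) per row of y
layers = ax.stackplot(x, y, labels=["A", "B", "C"])
ax.legend(loc="upper left")
```

Each row of y becomes one coloured band, stacked on top of the rows before it.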
Data for making Stacked area plot in Python
In this example, we will make a stacked area plot starting from data stored in a Pandas dataframe instead of simple hand-coded data. We will use the gapminder data available from datavizpyr.com’s GitHub page and load it directly using Pandas.
Let us get started by loading the libraries needed.
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
We will use the gapminder data to make the stacked area plot with year on the x-axis and mean GDP per capita for each continent on the y-axis.
p2data = "https://raw.githubusercontent.com/datavizpyr/data/master/gapminder-FiveYearData.csv"
gapminder = pd.read_csv(p2data)
gapminder.head()

       country  year         pop continent  lifeExp   gdpPercap
0  Afghanistan  1952   8425333.0      Asia   28.801  779.445314
1  Afghanistan  1957   9240934.0      Asia   30.332  820.853030
2  Afghanistan  1962  10267083.0      Asia   31.997  853.100710
3  Afghanistan  1967  11537966.0      Asia   34.020  836.197138
4  Afghanistan  1972  13079460.0      Asia   36.088  739.981106
First, we need to compute the mean GDP per capita for each year and continent. Here we use Pandas’ groupby() function to compute the mean GDP per capita per year and continent.
df = (gapminder
      .groupby(['continent', 'year'])['gdpPercap']
      .mean()
      .reset_index()
     )
Our data with mean GDP looks like this.
df.head()

  continent  year    gdpPercap
0    Africa  1952  1252.572466
1    Africa  1957  1385.236062
2    Africa  1962  1598.078825
3    Africa  1967  2050.363801
4    Africa  1972  2339.615674
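To see what the groupby step computes, here is a small sketch on a hand-made dataframe. The name toy and its values are invented for illustration and are not gapminder data. Note that the result is an unweighted mean: every country row counts equally, regardless of its population.

```python
import pandas as pd

# Hypothetical toy data: two "countries" per continent in a single year
toy = pd.DataFrame({
    "continent": ["Asia", "Asia", "Europe", "Europe"],
    "year": [1952, 1952, 1952, 1952],
    "gdpPercap": [100.0, 300.0, 1000.0, 2000.0],
})

# Same pattern as in the tutorial: group, average, flatten the index
toy_mean = (toy
            .groupby(["continent", "year"])["gdpPercap"]
            .mean()
            .reset_index()
           )
print(toy_mean)
```

The result has one row per continent and year pair, with gdpPercap holding the plain average of the rows in that group.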
We now have the data we need to make the stacked area plot. Let us use it to build the right input format for Matplotlib’s stackplot() function. For the x-axis we need year, and we can get the unique values of year using Pandas’ value_counts() function. Here we get the year values as a list using the tolist() function.
year = df.year.value_counts().index.tolist()
year

[1952, 1957, 1962, 1967, 1972, 1977, 1982, 1987, 1992, 1997, 2002, 2007]
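Since value_counts() orders its result by frequency rather than by value, the chronological order above depends on every year occurring equally often. A possibly safer alternative, sketched here on a made-up column, is to take the unique values and sort them explicitly.

```python
import pandas as pd

# Hypothetical toy column with years out of order and unequal counts
toy = pd.DataFrame({"year": [1957, 1952, 1957, 1952, 1962]})

# unique() keeps order of first appearance; sorted() makes it chronological
years = sorted(toy["year"].unique().tolist())
print(years)
```

With the real data, the same idea would be sorted(df["year"].unique().tolist()).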
For the y-axis we need the mean GDP per capita for each year and continent. Here we create a dictionary with the continents as keys and, for each continent, a list of its mean GDP per capita over the years as the value.
gdp_cont = (df
            .groupby('continent')['gdpPercap']
            .apply(lambda x: x.values.tolist())
            .to_dict()
           )
We can see that the keys of the dictionary are the continents.
gdp_cont.keys()

dict_keys(['Africa', 'Americas', 'Asia', 'Europe', 'Oceania'])
And the values of the dictionary are a list of lists, with the mean GDP per capita for each continent and year.
gdp_cont.values()

dict_values([[1252.5724658211539, 1385.2360622557692, 1598.0788248461538, 2050.3638008576922, 2339.615674198077, 2585.9385084634614, 2481.592959753846, 2282.6689912596153, 2281.8103332442306, 2378.759555101923, 2599.385158890385, 3089.0326047365384],
             [4079.0625522, 4616.04373316, 4901.5418704, 5668.25349604, 6491.334139039999, 7352.007126280001, 7506.73708808, 7793.40026112, 8044.93440552, 8889.30086256, 9287.67710732, 11003.03162536],
             [5195.484004030303, 5787.732939942424, 5729.369624775758, 5971.1733736060605, 8187.468699448485, 7791.3140199303025, 7434.135157484849, 7608.226507730302, 8639.690247730303, 9834.093295248485, 10174.090396578787, 12473.026870133333],
             [5661.05743476, 6963.0128159333335, 8365.4868143, 10143.823756533333, 12479.575246466668, 14283.9791096, 15617.896551233334, 17214.310726633335, 17061.5680842, 19076.781801600002, 21711.73242243333, 25054.481635933334],
             [10298.08565, 11598.522455, 12696.452430000001, 14495.021789999999, 16417.33338, 17283.957605, 18554.70984, 20448.040159999997, 20894.045885, 24024.175170000002, 26938.77804, 29810.188275]])
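As an alternative to building the dictionary by hand, the long-format data can be reshaped with Pandas’ pivot() so that rows are continents and columns are years. This is only a sketch on a made-up dataframe (the name toy and its values are hypothetical), but it gives the x values, the (M, N) array and the labels in one step, and it guarantees that every row is aligned to the same sorted years.

```python
import pandas as pd

# Hypothetical long-format data shaped like df in the tutorial
toy = pd.DataFrame({
    "continent": ["Africa", "Africa", "Asia", "Asia"],
    "year": [1952, 1957, 1952, 1957],
    "gdpPercap": [1.0, 2.0, 10.0, 20.0],
})

# Wide format: one row per continent, one column per year
wide = toy.pivot(index="continent", columns="year", values="gdpPercap")

x = wide.columns.tolist()     # years, sorted
y = wide.to_numpy()           # array of shape (M, N)
labels = wide.index.tolist()  # continent names, one per row of y
```

These three objects could then be passed as plt.stackplot(x, y, labels=labels).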
Matplotlib’s stackplot(): Stacked area plot in Python
Now we can use Matplotlib’s stackplot() function to make the stacked area plot. Here we specify the labels using the labels argument and we also specify the location of the legend on the plot.
plt.stackplot(year, gdp_cont.values(), labels=gdp_cont.keys())
plt.legend(loc='upper left')
plt.title('gapminder data: Mean GDP per capita by continent')
plt.xlabel('Year')
plt.ylabel('Mean GDP per capita')
plt.savefig("Stacked_area_plot_with_Matplotlib.png")
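When reading the resulting plot, it helps to remember how the bands are positioned. With the default baseline, each band sits on top of the running total of the bands below it. The sketch below reproduces that arithmetic with NumPy on made-up numbers. It illustrates the idea and is not Matplotlib’s internal code.

```python
import numpy as np

# Hypothetical values: 3 series (rows) over 3 x positions (columns)
y = np.array([
    [1.0, 2.0, 3.0],
    [2.0, 2.0, 2.0],
    [0.5, 1.0, 0.5],
])

# Upper edge of each band: running total down the rows
tops = np.cumsum(y, axis=0)

# Lower edge of each band: zero for the first, previous top for the rest
bottoms = np.vstack([np.zeros(y.shape[1]), tops[:-1]])

print(tops)
print(bottoms)
```

The upper edge of the topmost band is therefore the sum of all series. In our plot, that is the sum of the five continent means, not a world average, so the individual band heights are the quantities to compare.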