Learn Adding Legend | Plots Customization
Ultimate Visualization with Python

Adding Legend

When multiple elements are present in a chart, it is often helpful to label them for clarity. The legend serves this purpose by providing a compact area that explains different components of the chart.

The following are three common ways to create a legend in matplotlib.

First Option

Consider the following example to clarify the concept:

import matplotlib.pyplot as plt
import numpy as np

# Define categories and data
questions = ['question_1', 'question_2', 'question_3']
yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
answers = np.array([yes_answers, no_answers])

# Set positions and bar width
positions = np.arange(len(questions))
width = 0.3

# Create the grouped bar chart
for i in range(len(answers)):
    plt.bar(positions + width * i,
            answers[i],
            width)

# Adjust x-axis ticks to the center of groups
plt.xticks(positions + width * (len(answers) - 1) / 2,
           questions)

# Setting the labels for the legend explicitly
plt.legend(['positive answers', 'negative answers'])
plt.show()

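In the upper left corner, a legend explains the different bars on the chart. This legend is created using the plt.legend() function, with a list of labels passed as the first argument (the labels parameter). The labels are assigned to the plotted elements in the order in which those elements were drawn, so the list must follow the order of the plt.bar() calls.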
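Because the first option relies on drawing order, a label can end up on the wrong bars if the plotting calls are reordered. As a minimal sketch of a more explicit variant, the containers returned by bar() can be passed to legend() together with the labels. The fig/ax object-oriented style, the variable names containers and legend_texts, and the Agg backend line are assumptions made for this illustration and are not part of the lesson's own code.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

questions = ['question_1', 'question_2', 'question_3']
answers = np.array([[500, 240, 726], [432, 618, 101]])
positions = np.arange(len(questions))
width = 0.3

fig, ax = plt.subplots()

# Keep the container returned by each bar() call so it can serve as a legend handle
containers = [ax.bar(positions + width * i, answers[i], width)
              for i in range(len(answers))]
ax.set_xticks(positions + width * (len(answers) - 1) / 2)
ax.set_xticklabels(questions)

# Passing handles and labels together ties each label to a specific set of bars
legend = ax.legend(containers, ['positive answers', 'negative answers'])

# Read the labels back from the legend to confirm the pairing
legend_texts = [text.get_text() for text in legend.get_texts()]
print(legend_texts)
```

Remove the matplotlib.use("Agg") line and call plt.show() at the end to view the chart interactively.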

Second Option

Another option involves specifying the label parameter in each call of the plotting function, such as bar in our example:

import matplotlib.pyplot as plt
import numpy as np

# Define x-axis categories and their positions
questions = ['question_1', 'question_2', 'question_3']
positions = np.arange(len(questions))

# Define answers for each category
yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
answers = np.array([yes_answers, no_answers])
labels = ['positive answers', 'negative answers']

# Set the width for each bar
width = 0.3

# Plot each category with a label
for i in range(len(answers)):
    plt.bar(positions + width * i,
            answers[i],
            width,
            label=labels[i])

# Set x-axis ticks and labels at the center of each group
plt.xticks(positions + width * (len(answers) - 1) / 2, questions)

# Automatically create legend from label parameters
plt.legend()
plt.show()

Here, plt.legend() automatically determines the elements to be added to the legend and their labels; all the elements with the label parameter specified are included.
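To see which elements the automatic legend would pick up, the axes can be asked for its legend handles and labels before the legend is drawn. The sketch below labels only one of two bar groups and assumes the object-oriented fig/ax style and the Agg backend; the variable names are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

positions = np.arange(3)
width = 0.3

fig, ax = plt.subplots()
# Only the first group of bars receives a label
ax.bar(positions, [500, 240, 726], width, label='positive answers')
ax.bar(positions + width, [432, 618, 101], width)

# Ask the axes which artists would be placed in an automatic legend
handles, found_labels = ax.get_legend_handles_labels()
print(found_labels)

# The legend itself contains the same entries
legend = ax.legend()
legend_texts = [text.get_text() for text in legend.get_texts()]
```

The unlabeled bars are still drawn on the chart; they are only left out of the legend.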

Third Option

A third option is to call the set_label() method on the artist (in our example, the object returned by bar):

import matplotlib.pyplot as plt
import numpy as np

questions = ['question_1', 'question_2', 'question_3']
positions = np.arange(len(questions))

yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
answers = np.array([yes_answers, no_answers])

width = 0.3
labels = ['positive answers', 'negative answers']

# Plot bars for each category with labels
for i in range(len(answers)):
    bar = plt.bar(positions + width * i, answers[i], width)
    bar.set_label(labels[i])

# Set x-axis ticks and labels at the center of the grouped bars
center_positions = positions + width * (len(answers) - 1) / 2
plt.xticks(center_positions, questions)

# Place the legend at the top center of the plot area
plt.legend(loc='upper center')

plt.show()
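One consequence of the set_label() approach is that a label can be changed after the artist has been created, as long as this happens before the legend is built. The following is a small sketch of that behavior; the placeholder text 'draft label', the fig/ax style, and the Agg backend are assumptions chosen for the illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

positions = np.arange(3)

fig, ax = plt.subplots()
# Start with a placeholder label (hypothetical text, for illustration only)
bars = ax.bar(positions, [500, 240, 726], 0.3, label='draft label')

# set_label() replaces whatever label the artist had before
bars.set_label('positive answers')

# The legend reads the labels when it is created, so call legend() afterwards
legend = ax.legend()
legend_texts = [text.get_text() for text in legend.get_texts()]
print(bars.get_label(), legend_texts)
```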

Legend Location

Another important keyword argument of the legend() function is loc, which specifies the location of the legend. Its default value is 'best', which tells matplotlib to choose the location automatically so that the legend overlaps the data as little as possible.

import matplotlib.pyplot as plt
import numpy as np

questions = ['question_1', 'question_2', 'question_3']
positions = np.arange(len(questions))

yes_answers = np.array([500, 240, 726])
no_answers = np.array([432, 618, 101])
answers = np.array([yes_answers, no_answers])

width = 0.3
labels = ['positive answers', 'negative answers']

# Plot bars for each category with labels
for i, label in enumerate(labels):
    bars = plt.bar(positions + width * i, answers[i], width)
    bars.set_label(label)

# Set x-axis ticks and labels at the center of the grouped bars
center_positions = positions + width * (len(answers) - 1) / 2
plt.xticks(center_positions, questions)

# Place the legend at the top center of the plot area
plt.legend(loc='upper center')

plt.show()

In this example, the legend is positioned in the upper center of the plot. Other valid values for the loc parameter include:

  • 'upper right', 'upper left', 'lower left';

  • 'lower right', 'right';

  • 'center left', 'center right', 'lower center', 'center'.
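The location strings are easier to compare side by side. As a rough sketch, the code below draws the same small bar chart nine times and places the legend at a different loc value in each subplot. The 3x3 grid, the sample values, the label 'data', and the Agg backend are assumptions made for this illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Nine of the named locations, arranged to match their position in the grid
locations = ['upper left', 'upper center', 'upper right',
             'center left', 'center', 'center right',
             'lower left', 'lower center', 'lower right']

fig, axes = plt.subplots(3, 3, figsize=(9, 7))
x = np.arange(3)

for ax, location in zip(axes.flat, locations):
    ax.bar(x, [3, 1, 2], label='data')  # sample values, not from the lesson
    ax.legend(loc=location)             # place the legend at the named location
    ax.set_title(location)

fig.tight_layout()
```

Remove the matplotlib.use("Agg") line and call plt.show() at the end to view the grid interactively.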

Note
Study More

You can explore more options in the legend() documentation.

Task


  1. Label the lowest bars as 'primary sector' by specifying the appropriate keyword argument.
  2. Label the middle bars as 'secondary sector' by specifying the appropriate keyword argument.
  3. Label the top bars as 'tertiary sector' by specifying the appropriate keyword argument.
  4. Place the legend on the right side, centered vertically.

Solution

import matplotlib.pyplot as plt
import numpy as np

countries = ['USA', 'China', 'Japan']
primary_sector = np.array([1.4, 4.8, 0.4])
secondary_sector = np.array([11.3, 6.2, 0.8])
tertiary_sector = np.array([14.2, 8.4, 3.2])

# Set the label for the lowest bars
plt.bar(countries, primary_sector, label='primary sector')

# Set the label for the middle bars
plt.bar(countries, secondary_sector, bottom=primary_sector, label='secondary sector')

# Set the label for the top bars
plt.bar(countries, tertiary_sector, bottom=primary_sector + secondary_sector, label='tertiary sector')

# Create the legend and specify its location
plt.legend(loc='center right')
plt.show()
Section 3. Chapter 2
import matplotlib.pyplot as plt
import numpy as np

countries = ['USA', 'China', 'Japan']
primary_sector = np.array([1.4, 4.8, 0.4])
secondary_sector = np.array([11.3, 6.2, 0.8])
tertiary_sector = np.array([14.2, 8.4, 3.2])

# Set the label for the lowest bars
plt.bar(countries, primary_sector, ___=___)

# Set the label for the middle bars
plt.bar(countries, secondary_sector, bottom=primary_sector, ___=___)

# Set the label for the top bars
plt.bar(countries, tertiary_sector, bottom=primary_sector + secondary_sector, ___=___)

# Create the legend and specify its location
___.___(___=___)
plt.show()
