Ultimate Visualization with Python
Subplots
Up until now, we have only created multiple plots on a single Axes object by calling the plotting functions several times (this also lets us combine different plot types). Now it's time to learn how to create multiple Axes objects and place each plot on its own Axes object.
pyplot has a subplots() function for exactly this purpose. We already used this function to create a canvas in the first section; now we'll take a more detailed look at it.
Rows and Columns
The two most important arguments of this function are nrows and ncols, which specify the number of rows and columns of the subplot grid, respectively. Both default to 1, which results in a single Axes object.
subplots() returns a Figure object and either a single Axes object or an array of Axes objects.
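A quick way to see what subplots() returns is to inspect the types and shapes. This sketch relies on the default squeeze=True, under which a 1x1 grid yields a single Axes object and a grid with only one row or one column yields a 1D array. The non-interactive Agg backend is an assumption made only so the example runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# Defaults nrows=1, ncols=1: a single Axes object
fig, ax = plt.subplots()
print(type(fig).__name__, isinstance(ax, Axes))  # prints: Figure True

# 2x2 grid: a 2D NumPy array of Axes objects
fig, axs = plt.subplots(2, 2)
print(axs.shape)  # prints: (2, 2)

# 3x1 grid: squeeze=True drops the length-1 column dimension, so the array is 1D
fig, axs_col = plt.subplots(3, 1)
print(axs_col.shape)  # prints: (3,)
```

With a single-column grid like the last one, a single index such as axs_col[0] is enough to reach a subplot.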
Let’s have a look at an example:
```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2)
plt.show()
```
We have just created a 2 by 2 subplot grid.
Note
The subplots() function here returns an array of Axes objects, since there is more than one subplot. When an array of Axes is returned, the variable storing it is often called axs (ax is mostly used for a single Axes object).
In our case, axs is a two-dimensional array, which is why we need both a row index and a column index to access a particular Axes object.
Let’s create a few plots:
```python
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

# Creating a 2x2 subplot grid
fig, axs = plt.subplots(2, 2)

# Creating a different plot for each Axes object
axs[0, 0].plot(data_linear)
axs[0, 1].plot(data_squared)
axs[1, 0].scatter(data_linear, data_linear)
axs[1, 1].scatter(data_linear, data_squared)

plt.show()
```
The first row of the subplot grid (row 0) has two line plots, and the second row (row 1) has two scatter plots.
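Remember that we should not use plt.plot() or plt.scatter() here: these functions only draw on the current Axes (after subplots(), that is the last subplot created), whereas we want to place each plot on a separate Axes object (subplot).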
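The sketch below illustrates why the plt.* functions are a poor fit for subplot grids: they target whichever Axes is current, which right after subplots() is the bottom-right subplot of a 2x2 grid. plt.sca() is shown only to demonstrate that the current Axes can be switched; calling the plotting methods on axs directly, as in the lesson, is usually clearer.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)

fig, axs = plt.subplots(2, 2)

# plt.plot() draws on the current Axes: the last subplot created
plt.plot(data_linear)
print(plt.gca() is axs[1, 1])  # prints: True

# plt.sca() makes another subplot current, so the next plt.plot() lands there
plt.sca(axs[0, 0])
plt.plot(data_linear)

# Count the line objects on each subplot, row by row
counts = [[len(ax.lines) for ax in row] for row in axs]
print(counts)  # prints: [[1, 0], [0, 1]]
```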
Converting to 1D Array
It is also possible to use the .ravel() method to convert the 2D array of Axes into a contiguous, flattened 1D array:
```python
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

# Creating a 2x2 subplot grid
fig, axs = plt.subplots(2, 2)

# Converting axs to a 1D array of 4 elements
axs = axs.ravel()

# Creating a different plot for each Axes object
axs[0].plot(data_linear)
axs[1].plot(data_squared)
axs[2].scatter(data_linear, data_linear)
axs[3].scatter(data_linear, data_squared)

plt.show()
```
Since we have a 2x2 array, axs.ravel() returns a 1D array with four elements. The elements are ordered row by row, so axs[2] is the bottom-left subplot.
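A flattened array also lets a single loop handle every subplot. The sketch below checks the row-by-row ordering and then pairs each Axes with a dataset via zip(); the titles are illustrative labels, not part of the lesson.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

fig, axs = plt.subplots(2, 2)
flat = axs.ravel()

# Row-major order: flat[2] is row 1, column 0 (bottom-left)
print(flat[2] is axs[1, 0])  # prints: True

datasets = [data_linear, data_squared, data_linear, data_squared]
titles = ["linear", "squared", "linear", "squared"]  # illustrative labels

# One loop instead of four separate calls
for ax, data, title in zip(flat, datasets, titles):
    ax.plot(data)
    ax.set_title(title)
```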
Sharing an Axis
Two other important parameters of the subplots() function are sharex and sharey, which specify whether the x-axis or y-axis properties (such as axis limits and ticks) are shared among the subplots. Both default to False. Let's change one of them in our example:
```python
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

# Share the x-axis among all the subplots
fig, axs = plt.subplots(2, 2, sharex=True)

# Converting axs to a 1D array of 4 elements
axs = axs.ravel()

# Creating a different plot for each Axes object
axs[0].plot(data_linear)
axs[1].plot(data_squared)
axs[2].scatter(data_linear, data_linear)
axs[3].scatter(data_linear, data_squared)

plt.show()
```
With sharex=True, the x-axis is shared among all the subplots, which makes sense here since all of them use the same x-axis coordinates. As a side effect, only the bottom row of subplots shows x tick labels.
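One way to observe the sharing is to change the x-limits of one subplot and read them back from the others: with sharex=True, every subplot follows the change. The limits 0 to 12 are arbitrary values chosen for the demonstration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

fig, axs = plt.subplots(2, 2, sharex=True)
axs = axs.ravel()
axs[0].plot(data_linear)
axs[3].scatter(data_linear, data_squared)

# Changing the x-limits of one subplot updates every subplot sharing the x-axis
axs[0].set_xlim(0, 12)
shared_limits = [tuple(float(v) for v in ax.get_xlim()) for ax in axs]
print(shared_limits)
```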
Moreover, we can set sharex or sharey to 'row' (subplots in the same row share the respective axis) or 'col' (subplots in the same column share the respective axis).
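Here is a small sketch of the 'row' option: with sharey="row", setting the y-limits on one subplot changes its row neighbour but leaves the other row with its own autoscaled limits. The limits 0 to 12 are arbitrary values chosen for the demonstration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no window needed
import matplotlib.pyplot as plt
import numpy as np

data_linear = np.arange(1, 11)
data_squared = data_linear ** 2

# Each row shares its own y-axis; the two rows are independent
fig, axs = plt.subplots(2, 2, sharey="row")
axs[0, 0].plot(data_linear)
axs[0, 1].plot(data_linear)
axs[1, 0].plot(data_squared)
axs[1, 1].plot(data_squared)

# Change the y-limits of the top-left subplot only
axs[0, 0].set_ylim(0, 12)

same_row = tuple(float(v) for v in axs[0, 1].get_ylim())
other_row = tuple(float(v) for v in axs[1, 0].get_ylim())
print(same_row)   # follows the change: (0.0, 12.0)
print(other_row)  # keeps the limits autoscaled to data_squared
```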
As usual, feel free to explore the documentation for more details.
- Use the correct function to create a subplot grid.
- The grid should have 3 rows and 1 column (specify the first two parameters).
- Specify the rightmost keyword argument so that the x-axis is shared among all the subplots.
- Store the result of the function for creating subplots in the fig and axs variables (from left to right).
- Place the first line plot, for data_linear, on the first row (row 0) of the subplot grid.
- Place the second line plot, for data_squared, on the second row (row 1) of the subplot grid.
- Place the third line plot, for data_exp, on the third row (row 2) of the subplot grid.