How to make subplots using a user-defined function [duplicate] - matplotlib

I haven't been able to find a solution to this. Say I define a plotting function so that I don't have to copy-paste lots of code every time I make similar plots.
What I'd like to do is use this function to create a few different plots individually and then put them together as subplots in one figure. Is this even possible? I've tried the following, but it just returns blank axes:
import numpy as np
import matplotlib.pyplot as plt
# function to make boxplots
def make_boxplots(box_data):
    fig, ax = plt.subplots()
    box = ax.boxplot(box_data)
    #plt.show()
    return ax
# make some data:
data_1 = np.random.normal(0,1,500)
data_2 = np.random.normal(0,1.1,500)
# plot it
box1 = make_boxplots(box_data=data_1)
box2 = make_boxplots(box_data=data_2)
plt.close('all')
fig, ax = plt.subplots(2)
ax[0] = box1
ax[1] = box2
plt.show()

I tend to use the following template
def plot_something(data, ax=None, **kwargs):
    ax = ax or plt.gca()
    # Do some cool data transformations...
    return ax.boxplot(data, **kwargs)
Then you can experiment with your plotting function by simply calling plot_something(my_data) and you can specify which axes to use like so.
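fig, (ax1, ax2) = plt.subplots(2)
plot_something(data_1, ax1, notch=True)
plot_something(data_2, ax2, showfliers=False)
plt.show() # This should NOT be called inside plot_something()
Adding the kwargs allows you to pass arbitrary parameters through to the underlying plotting call. For boxplot these are options such as notch or showfliers (boxplot does not accept a color argument); for a function that wraps ax.plot they could be labels, line styles, or colours.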
The line ax = ax or plt.gca() uses the axes you have specified or gets the current axes from matplotlib (which may be new axes if you haven't created any yet).
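As a minimal sketch of how this template could be applied to the boxplot case from the question: the function name make_boxplot, the seeded random generator, and the Agg backend line are assumptions made for illustration, not part of the original answer.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove to get interactive windows
import matplotlib.pyplot as plt
import numpy as np


def make_boxplot(box_data, ax=None, **kwargs):
    """Draw a boxplot on `ax`, falling back to the current axes if none is given."""
    if ax is None:
        ax = plt.gca()
    ax.boxplot(box_data, **kwargs)  # kwargs are forwarded to Axes.boxplot
    return ax


# Sample data, as in the question (seeded here only for reproducibility)
rng = np.random.default_rng(0)
data_1 = rng.normal(0, 1.0, 500)
data_2 = rng.normal(0, 1.1, 500)

# Create the figure once, then let the function draw into each subplot
fig, (ax_top, ax_bottom) = plt.subplots(2)
make_boxplot(data_1, ax=ax_top)
make_boxplot(data_2, ax=ax_bottom, notch=True)
```

The key difference from the code in the question is that the function never creates its own figure: the caller owns the figure and passes each Axes in, so assigning Axes objects into the array afterwards is unnecessary.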

Related

Changing subplots from 2x2 to 3x3? [duplicate]

I am a little confused about how this code works:
fig, axes = plt.subplots(nrows=2, ncols=2)
plt.show()
How does the fig, axes work in this case? What does it do?
Also why wouldn't this work to do the same thing:
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
There are several ways to do it. The subplots method creates the figure along with the subplots that are then stored in the ax array. For example:
import matplotlib.pyplot as plt
x = range(10)
y = range(10)
fig, ax = plt.subplots(nrows=2, ncols=2)
for row in ax:
    for col in row:
        col.plot(x, y)
plt.show()
However, something like this also works. It is less clean, though, because it creates an empty figure and then adds the subplots to it one at a time:
fig = plt.figure()
plt.subplot(2, 2, 1)
plt.plot(x, y)
plt.subplot(2, 2, 2)
plt.plot(x, y)
plt.subplot(2, 2, 3)
plt.plot(x, y)
plt.subplot(2, 2, 4)
plt.plot(x, y)
plt.show()
import matplotlib.pyplot as plt
fig, ax = plt.subplots(2, 2)
ax[0, 0].plot(range(10), 'r') #row=0, col=0
ax[1, 0].plot(range(10), 'b') #row=1, col=0
ax[0, 1].plot(range(10), 'g') #row=0, col=1
ax[1, 1].plot(range(10), 'k') #row=1, col=1
plt.show()
You can also unpack the axes in the subplots call and set whether you want to share the x and y axes between the subplots, like this:
import matplotlib.pyplot as plt
# fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True)
ax1, ax2, ax3, ax4 = axes.flatten()
ax1.plot(range(10), 'r')
ax2.plot(range(10), 'b')
ax3.plot(range(10), 'g')
ax4.plot(range(10), 'k')
plt.show()
You might be interested in the fact that as of matplotlib version 2.1 the second code from the question works fine as well.
From the change log:
Figure class now has subplots method
The Figure class now has a subplots() method which behaves the same as pyplot.subplots() but on an existing figure.
Example:
import matplotlib.pyplot as plt
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
plt.show()
Read the documentation: matplotlib.pyplot.subplots
pyplot.subplots() returns a tuple fig, ax which is unpacked in two variables using the notation
fig, axes = plt.subplots(nrows=2, ncols=2)
The code:
fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=2)
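did not work in matplotlib versions before 2.1, because subplots() was then only a function in pyplot and not a method of the Figure class. Since version 2.1, Figure.subplots() exists and this code works.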
Iterating through all subplots sequentially:
fig, axes = plt.subplots(nrows, ncols)
for ax in axes.flatten():
    ax.plot(x,y)
Accessing a specific index:
for row in range(nrows):
    for col in range(ncols):
        axes[row,col].plot(x[row], y[col])
Subplots with pandas
This answer is for subplots with pandas, which uses matplotlib as the default plotting backend.
Here are four options to create subplots starting with a pandas.DataFrame
Implementation 1. and 2. are for the data in a wide format, creating subplots for each column.
Implementation 3. and 4. are for data in a long format, creating subplots for each unique value in a column.
Tested in python 3.8.11, pandas 1.3.2, matplotlib 3.4.3, seaborn 0.11.2
Imports and Data
import seaborn as sns # sample data; also used for option 4
import pandas as pd
import matplotlib.pyplot as plt
# wide dataframe
df = sns.load_dataset('planets').iloc[:, 2:5]
orbital_period mass distance
0 269.300 7.10 77.40
1 874.774 2.21 56.95
2 763.000 2.60 19.84
3 326.030 19.40 110.62
4 516.220 10.50 119.47
# long dataframe
dfm = sns.load_dataset('planets').iloc[:, 2:5].melt()
variable value
0 orbital_period 269.300
1 orbital_period 874.774
2 orbital_period 763.000
3 orbital_period 326.030
4 orbital_period 516.220
1. subplots=True and layout, for each column
Use the parameters subplots=True and layout=(rows, cols) in pandas.DataFrame.plot
This example uses kind='density', but there are different options for kind, and this applies to them all. Without specifying kind, a line plot is the default.
axes is an array of AxesSubplot objects returned by pandas.DataFrame.plot
See How to get a Figure object, if needed.
How to save pandas subplots
axes = df.plot(kind='density', subplots=True, layout=(2, 2), sharex=False, figsize=(10, 6))
# extract the figure object; only used for tight_layout in this example
fig = axes[0][0].get_figure()
# set the individual titles
for ax, title in zip(axes.ravel(), df.columns):
    ax.set_title(title)
fig.tight_layout()
plt.show()
2. plt.subplots, for each column
Create an array of Axes with matplotlib.pyplot.subplots and then pass axes[i, j] or axes[n] to the ax parameter.
This option uses pandas.DataFrame.plot, but can use other axes level plot calls as a substitute (e.g. sns.kdeplot, plt.plot, etc.)
It's easiest to collapse the subplot array of Axes into one dimension with .ravel or .flatten. See .ravel vs .flatten.
Any variables that apply to each axes and need to be iterated through are combined with zip (e.g. cols, axes, colors, palette, etc.). Each object must be the same length.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
cols = df.columns # create a list of dataframe columns to use
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for col, color, ax in zip(cols, colors, axes):
    df[col].plot(kind='density', ax=ax, color=color, label=col, title=col)
    ax.legend()
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
Result for 1. and 2.
3. plt.subplots, for each group in .groupby
This is similar to 2., except it zips color and axes to a .groupby object.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) # define the figure and subplots
axes = axes.ravel() # array to 1D
dfg = dfm.groupby('variable') # get data for each unique value in the first column
colors = ['tab:blue', 'tab:orange', 'tab:green'] # list of colors for each subplot, otherwise all subplots will be one color
for (group, data), color, ax in zip(dfg, colors, axes):
    data.plot(kind='density', ax=ax, color=color, title=group, legend=False)
fig.delaxes(axes[3]) # delete the empty subplot
fig.tight_layout()
plt.show()
4. seaborn figure-level plot
Use a seaborn figure-level plot, and use the col or row parameter. seaborn is a high-level API for matplotlib. See seaborn: API reference
p = sns.displot(data=dfm, kind='kde', col='variable', col_wrap=2, x='value', hue='variable',
facet_kws={'sharey': False, 'sharex': False}, height=3.5, aspect=1.75)
sns.move_legend(p, "upper left", bbox_to_anchor=(.55, .45))
Convert the axes array to 1D
Generating subplots with plt.subplots(nrows, ncols), where both nrows and ncols are greater than 1, returns a 2D array of <AxesSubplot:> objects.
It's not necessary to flatten axes when either nrows=1 or ncols=1, because axes will already be 1-dimensional, which is a result of the default parameter squeeze=True.
The easiest way to access the objects is to convert the array to 1 dimension with .ravel(), .flatten(), or .flat.
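The effect of squeeze on the returned shape can be checked directly. The following is a small sketch; the variable names and the Agg backend line are assumptions added for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# 2 x 2 grid: a 2D array of Axes
fig_a, axes_grid = plt.subplots(2, 2)

# 1 x 3 grid with the default squeeze=True: already 1D
fig_b, axes_row = plt.subplots(1, 3)

# Same grid with squeeze=False: always a 2D array
fig_c, axes_kept = plt.subplots(1, 3, squeeze=False)

# A single subplot with squeeze=True is not an array at all
fig_d, single_ax = plt.subplots()

print(axes_grid.shape, axes_row.shape, axes_kept.shape)
print(isinstance(single_ax, Axes))

plt.close("all")
```

Passing squeeze=False can be convenient in helper functions that must handle any grid size, since the result can then always be flattened the same way.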
.ravel vs. .flatten
flatten always returns a copy.
ravel returns a view of the original array whenever possible.
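The view-versus-copy difference can be seen with a plain NumPy array. In this sketch an integer array stands in for the array of Axes (an assumption made to keep the example free of plotting).

```python
import numpy as np

# Stand-in for a 2 x 2 array of Axes objects
grid = np.arange(4).reshape(2, 2)

raveled = grid.ravel()      # a view here, because the array is contiguous
flattened = grid.flatten()  # always a copy

print(np.shares_memory(grid, raveled))    # True
print(np.shares_memory(grid, flattened))  # False

# Writing through the view changes the original; writing to the copy does not
raveled[0] = 99
flattened[1] = -1
print(grid[0, 0], grid[0, 1])  # 99 1
```

For an array of Axes the distinction rarely matters in practice: both results hold references to the same Axes objects, so plotting through either one draws on the same subplots.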
Once the array of axes is converted to 1-d, there are a number of ways to plot.
This answer is relevant to seaborn axes-level plots, which have the ax= parameter (e.g. sns.barplot(…, ax=ax[0])).
seaborn is a high-level API for matplotlib. See Figure-level vs. axes-level functions and seaborn is not plotting within defined subplots
import matplotlib.pyplot as plt
import numpy as np # sample data only
# example of data
rads = np.arange(0, 2*np.pi, 0.01)
y_data = np.array([np.sin(t*rads) for t in range(1, 5)])
x_data = [rads, rads, rads, rads]
# Generate figure and its subplots
fig, axes = plt.subplots(nrows=2, ncols=2)
# axes before
array([[<AxesSubplot:>, <AxesSubplot:>],
[<AxesSubplot:>, <AxesSubplot:>]], dtype=object)
# convert the array to 1 dimension
axes = axes.ravel()
# axes after
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
dtype=object)
Iterate through the flattened array
If there are more subplots than data, this will result in IndexError: list index out of range
Try option 3. instead, or select a subset of the axes (e.g. axes[:-2])
for i, ax in enumerate(axes):
    ax.plot(x_data[i], y_data[i])
Access each axes by index
axes[0].plot(x_data[0], y_data[0])
axes[1].plot(x_data[1], y_data[1])
axes[2].plot(x_data[2], y_data[2])
axes[3].plot(x_data[3], y_data[3])
Index the data and axes
for i in range(len(x_data)):
    axes[i].plot(x_data[i], y_data[i])
Zip the axes and data together and then iterate through the resulting tuples.
for ax, x, y in zip(axes, x_data, y_data):
    ax.plot(x, y)
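When there are more subplots than datasets, zip simply stops at the shorter sequence, and the leftover axes can then be removed. A minimal sketch, assuming three datasets on a 2 x 2 grid (the data and the Agg backend line are illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

rads = np.arange(0, 2 * np.pi, 0.01)
y_data = [np.sin(t * rads) for t in range(1, 4)]  # only three datasets

fig, axes = plt.subplots(nrows=2, ncols=2)
axes = axes.ravel()

# zip stops after the third pair, so no IndexError is raised
for ax, y in zip(axes, y_data):
    ax.plot(rads, y)

# Remove any axes that did not receive data
for ax in axes[len(y_data):]:
    fig.delaxes(ax)
```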
Output
An option is to assign each axes to a variable, fig, (ax1, ax2, ax3) = plt.subplots(1, 3). However, as written, this only works in cases with either nrows=1 or ncols=1. This is based on the shape of the array returned by plt.subplots, and quickly becomes cumbersome.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2) for a 2 x 2 array.
This option is most useful for two subplots (e.g.: fig, (ax1, ax2) = plt.subplots(1, 2) or fig, (ax1, ax2) = plt.subplots(2, 1)). For more subplots, it's more efficient to flatten and iterate through the array of axes.
You could use the following:
import numpy as np
import matplotlib.pyplot as plt
fig, _ = plt.subplots(nrows=2, ncols=2)
for i, ax in enumerate(fig.axes):
ax.plot(np.sin(np.linspace(0,2*np.pi,100) + np.pi/2*i))
Or alternatively, using the second variable that plt.subplots returns:
fig, ax_mat = plt.subplots(nrows=2, ncols=2)
for i, ax in enumerate(ax_mat.flatten()):
    ...
ax_mat is a matrix of the axes. Its shape is nrows x ncols.
Here is a simple solution:
fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=False)
for sp in fig.axes:
    sp.plot(range(10))
Go with the following if you really want to use a loop:
def plot(data):
    fig = plt.figure(figsize=(100, 100))
    for idx, k in enumerate(data.keys(), 1):
        x, y = data[k].keys(), data[k].values
        plt.subplot(63, 10, idx)
        plt.bar(x, y)
    plt.show()
Another concise solution is:
# set up structure of plots
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20,10))
# for plot 1
ax1.set_title('Title A')
ax1.plot(x, y)
# for plot 2
ax2.set_title('Title B')
ax2.plot(x, y)
# for plot 3
ax3.set_title('Title C')
ax3.plot(x,y)

Iterating over a folder and plotting multiple csv files [duplicate]


Draw various plots in one figure

The images below show what I want: three different plots produced in one call to a function.
[two images showing the desired plots]
I used the following code:
def box_hist_plot(data):
    sns.set()
    ax, fig = plt.subplots(1,3, figsize=(20,5))
    sns.boxplot(x=data, linewidth=2.5, ax=fig[0])
    plt.hist(x=data, bins=50, density=True, ax = fig[1])
    sns.violinplot(x = data, ax=fig[2])
and I got this error:
inner() got multiple values for argument 'ax'
Besides the fact that you should not call a Figure object ax and an array of Axes object fig, your problem comes from the line plt.hist(...,ax=...). plt.hist() should not take an ax= parameter, but is meant to act on the "current" axes. If you want to specify which Axes you want to plot, you should use Axes.hist().
def box_hist_plot(data):
    sns.set()
    fig, axs = plt.subplots(1,3, figsize=(20,5))
    sns.boxplot(x=data, linewidth=2.5, ax=axs[0])
    axs[1].hist(x=data, bins=50, density=True)
    sns.violinplot(x = data, ax=axs[2])
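The same idea can be sketched with matplotlib alone, so that it runs without seaborn. Here Axes.boxplot and Axes.violinplot stand in for the seaborn calls; the function name box_hist_violin, the sample data, and the Agg backend line are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np


def box_hist_violin(data, bins=50):
    """Draw a boxplot, a histogram and a violin plot side by side."""
    fig, axs = plt.subplots(1, 3, figsize=(20, 5))
    axs[0].boxplot(data)
    # Axes.hist draws on the axes it is called on; no ax= argument is needed
    axs[1].hist(data, bins=bins, density=True)
    axs[2].violinplot(data)
    return fig, axs


rng = np.random.default_rng(0)
fig, axs = box_hist_violin(rng.normal(size=1000))
```

Returning fig and axs lets the caller adjust titles or save the figure afterwards without the function having to know about it.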

Matplotlib - how to combine a list of AxesSubplot into one figure with multiple subplots? [duplicate]

Looking at the matplotlib documentation, it seems the standard way to add an AxesSubplot to a Figure is to use Figure.add_subplot:
from matplotlib import pyplot
fig = pyplot.figure()
ax = fig.add_subplot(1,1,1)
ax.hist( some params .... )
I would like to be able to create AxesSubPlot-like objects independently of the figure, so I can use them in different figures. Something like
fig = pyplot.figure()
histoA = some_axes_subplot_maker.hist( some params ..... )
histoB = some_axes_subplot_maker.hist( some other params ..... )
# make one figure with both plots
fig.add_subaxes(histoA, 211)
fig.add_subaxes(histoB, 212)
fig2 = pyplot.figure()
# make a figure with the first plot only
fig2.add_subaxes(histoA, 111)
Is this possible in matplotlib and if so, how can I do this?
Update: I have not managed to decouple the creation of Axes and Figures, but following the examples in the answers below, I can easily re-use previously created axes in new or old Figure instances. This can be illustrated with a simple function:
def plot_axes(ax, fig=None, geometry=(1,1,1)):
    if fig is None:
        fig = plt.figure()
    if ax.get_geometry() != geometry :
        ax.change_geometry(*geometry)
    ax = fig.axes.append(ax)
    return fig
Typically, you just pass the axes instance to a function.
For example:
import matplotlib.pyplot as plt
import numpy as np
def main():
    x = np.linspace(0, 6 * np.pi, 100)
    fig1, (ax1, ax2) = plt.subplots(nrows=2)
    plot(x, np.sin(x), ax1)
    plot(x, np.random.random(100), ax2)
    fig2 = plt.figure()
    plot(x, np.cos(x))
    plt.show()
def plot(x, y, ax=None):
    if ax is None:
        ax = plt.gca()
    line, = ax.plot(x, y, 'go')
    ax.set_ylabel('Yabba dabba do!')
    return line
if __name__ == '__main__':
    main()
To respond to your question, you could always do something like this:
def subplot(data, fig=None, index=111):
    if fig is None:
        fig = plt.figure()
    ax = fig.add_subplot(index)
    ax.plot(data)
Also, you can simply add an axes instance to another figure:
import matplotlib.pyplot as plt
fig1, ax = plt.subplots()
ax.plot(range(10))
fig2 = plt.figure()
fig2.axes.append(ax)
plt.show()
Resizing it to match other subplot "shapes" is also possible, but it's going to quickly become more trouble than it's worth. The approach of just passing around a figure or axes instance (or list of instances) is much simpler for complex cases, in my experience...
The following shows how to "move" an axes from one figure to another. This is the intended functionality of @JoeKington's last example, which no longer works in newer matplotlib versions because an axes cannot live in several figures at once.
You would first need to remove the axes from the first figure, then append it to the next figure and give it some position to live in.
import matplotlib.pyplot as plt
fig1, ax = plt.subplots()
ax.plot(range(10))
ax.remove()
fig2 = plt.figure()
ax.figure=fig2
fig2.axes.append(ax)
fig2.add_axes(ax)
dummy = fig2.add_subplot(111)
ax.set_position(dummy.get_position())
dummy.remove()
plt.close(fig1)
plt.show()
For line plots, you can deal with the Line2D objects themselves:
import numpy as np
import pylab
fig1 = pylab.figure()
ax1 = fig1.add_subplot(111)
lines = ax1.plot(np.random.randn(10))  # scipy.randn is no longer available
fig2 = pylab.figure()
ax2 = fig2.add_subplot(111)
ax2.add_line(lines[0])
TL;DR, based partly on Joe's nice answer.
Opt.1: fig.add_subplot()
def fcn_return_plot():
    return plt.plot(np.random.random((10,)))
n = 4
fig = plt.figure(figsize=(n*3,2))
#fig, ax = plt.subplots(1, n, sharey=True, figsize=(n*3,2)) # also works
for index in list(range(n)):
    fig.add_subplot(1, n, index + 1)
    fcn_return_plot()
    plt.title(f"plot: {index}", fontsize=20)
Opt.2: pass ax[index] to a function that returns ax[index].plot()
def fcn_return_plot_input_ax(ax=None):
    if ax is None:
        ax = plt.gca()
    return ax.plot(np.random.random((10,)))
n = 4
fig, ax = plt.subplots(1, n, sharey=True, figsize=(n*3,2))
for index in list(range(n)):
    fcn_return_plot_input_ax(ax[index])
    ax[index].set_title(f"plot: {index}", fontsize=20)
Outputs, respectively.
Note: plt.title() in Opt.1 becomes ax[index].set_title() in Opt.2. More Matplotlib gotchas are described in VanderPlas's book.
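The difference between the two title calls comes down to which axes is "current". A small sketch, with the variable names and the Agg backend line assumed for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt

fig, (ax_left, ax_right) = plt.subplots(1, 2)

# pyplot functions such as plt.title() act on the current axes
plt.sca(ax_left)               # make ax_left the current axes
plt.title("set via pyplot")

# Axes methods act on the axes they are called on, whatever is current
ax_right.set_title("set via Axes")

print(ax_left.get_title())
print(ax_right.get_title())
```

Inside a reusable plotting function, the Axes methods are usually the safer choice, because the result does not depend on what the caller did to the pyplot state beforehand.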
To go deeper in the rabbit hole. Extending my previous answer, one could return a whole ax, and not ax.plot() only. E.g.
If dataframe had 100 tests of 20 types (here id):
dfA = pd.DataFrame(np.random.random((100,3)), columns = ['y1', 'y2', 'y3'])
dfB = pd.DataFrame(np.repeat(list(range(20)),5), columns = ['id'])
dfC = dfA.join(dfB)
And the plot function (this is the key of this whole answer):
def plot_feature_each_id(df, feature, id_range=(), ax=None, legend_bool=False):
    if ax is None:
        ax = plt.gca()  # fall back to the current axes
    feature = df[feature]
    if not len(id_range):
        id_range = set(df['id'])
    legend_arr = []
    for k in id_range:
        mask = (df['id'] == k)
        ax.plot(feature[mask])
        legend_arr.append(f"id: {k}")
    if legend_bool:
        ax.legend(legend_arr)
    return ax
We can achieve:
feature_arr = dfC.drop(columns='id').columns
id_range= np.random.randint(len(set(dfC.id)), size=(10,))
n = len(feature_arr)
fig, ax = plt.subplots(1, n, figsize=(n*6,4))
for i,k in enumerate(feature_arr):
    plot_feature_each_id(dfC, k, np.sort(id_range), ax[i], legend_bool=(i+1==n))
    ax[i].set_title(k, fontsize=20)
    ax[i].set_xlabel("test nr. (id)", fontsize=20)

pyplot - copy an axes content and show it in a new figure

Let's say I have this code:
num_rows = 10
num_cols = 1
fig, axs = plt.subplots(num_rows, num_cols, sharex=True)
for i in range(num_rows):
    ax = axs[i]
    ax.plot(np.arange(10), np.arange(10)**i)
plt.show()
The resulting figure has too much information, and now I want to pick one of the axes and draw it alone in a new figure.
I tried doing something like this:
def on_click(event):
    axes = event.inaxes.get_axes()
    fig2 = plt.figure(15)
    fig2.axes.append(axes)
    fig2.show()
fig.canvas.mpl_connect('button_press_event', on_click)
But it didn't quite work. What would be the correct way to do it? Searching through the docs and through SE gave hardly any useful results.
edit:
I don't mind redrawing the chosen axes, but I'm not sure how can I tell which of the axes was chosen so if that information is available somehow then it is a valid solution for me
edit #2:
so I've managed to do something like this:
def on_click(event):
    fig2 = plt.figure(15)
    fig2.clf()
    for line in event.inaxes.axes.get_lines():
        xydata = line.get_xydata()
        plt.plot(xydata[:, 0], xydata[:, 1])
    fig2.show()
which seems to be "working" (all the other information is lost - labels, lines colors, lines style, lines width, xlim, ylim, etc...)
but I feel like there must be a nicer way to do it
thanks
Copying the axes
The initial answer here does not work; it is kept for future reference and to show why a more sophisticated approach is needed.
#There are some pitfalls on the way with the initial approach.
#Adding an `axes` to a figure can be done via `fig.add_axes(axes)`. However, at this point,
#the axes' figure needs to be the figure the axes should be added to.
#This may sound a bit like running in circles but we can actually set the axes'
#figure as `axes.figure = fig2` and hence break out of this.
#One might then also position the axes in the new figure to take the usual dimensions.
#For this a dummy axes can be added first, the axes can change its position to the position
#of the dummy axes and then the dummy axes is removed again. In total, this would look as follows.
import matplotlib.pyplot as plt
import numpy as np
num_rows = 10
num_cols = 1
fig, axs = plt.subplots(num_rows, num_cols, sharex=True)
for i in range(num_rows):
    ax = axs[i]
    ax.plot(np.arange(10), np.arange(10)**i)
def on_click(event):
    axes = event.inaxes
    if not axes: return
    fig2 = plt.figure()
    axes.figure=fig2
    fig2.axes.append(axes)
    fig2.add_axes(axes)
    dummy = fig2.add_subplot(111)
    axes.set_position(dummy.get_position())
    dummy.remove()
    fig2.show()
fig.canvas.mpl_connect('button_press_event', on_click)
plt.show()
#So far so good, however, be aware that now after a click the axes is somehow
#residing in both figures, which can cause all sorts of problems, e.g. if you
# want to resize or save the initial figure.
Instead, the following will work:
Pickling the figure
The problem is that axes cannot be copied (even deepcopy will fail). Hence to obtain a true copy of an axes, you may need to use pickle. The following will work. It pickles the complete figure and removes all but the one axes to show.
import matplotlib.pyplot as plt
import numpy as np
import pickle
import io
num_rows = 10
num_cols = 1
fig, axs = plt.subplots(num_rows, num_cols, sharex=True)
for i in range(num_rows):
    ax = axs[i]
    ax.plot(np.arange(10), np.arange(10)**i)
def on_click(event):
    if not event.inaxes: return
    inx = list(fig.axes).index(event.inaxes)
    buf = io.BytesIO()
    pickle.dump(fig, buf)
    buf.seek(0)
    fig2 = pickle.load(buf)
    for i, ax in enumerate(fig2.axes):
        if i != inx:
            fig2.delaxes(ax)
        else:
            axes=ax
    axes.change_geometry(1,1,1)
    fig2.show()
fig.canvas.mpl_connect('button_press_event', on_click)
plt.show()
Recreate plots
The alternative to the above is of course to recreate the plot in a new figure each time the axes is clicked. To this end, one may use a function that takes an index and an axes as input and creates the plot on that axes. Using the same function both when the figure is first created and later when replicating the plot in another figure ensures that the plot is the same in all cases.
import matplotlib.pyplot as plt
import numpy as np
num_rows = 10
num_cols = 1
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
labels = ["Label {}".format(i+1) for i in range(num_rows)]
def myplot(i, ax):
    ax.plot(np.arange(10), np.arange(10)**i, color=colors[i])
    ax.set_ylabel(labels[i])
fig, axs = plt.subplots(num_rows, num_cols, sharex=True)
for i in range(num_rows):
    myplot(i, axs[i])
def on_click(event):
    axes = event.inaxes
    if not axes: return
    inx = list(fig.axes).index(axes)
    fig2 = plt.figure()
    ax = fig2.add_subplot(111)
    myplot(inx, ax)
    fig2.show()
fig.canvas.mpl_connect('button_press_event', on_click)
plt.show()
If you have, for example, a plot with three lines generated by the function plot_something, you can do something like this:
fig, axs = plot_something()
ax = axs[2]
l = list(ax.get_lines())[0]
l2 = list(ax.get_lines())[1]
l3 = list(ax.get_lines())[2]
plt.plot(l.get_data()[0], l.get_data()[1])
plt.plot(l2.get_data()[0], l2.get_data()[1])
plt.plot(l3.get_data()[0], l3.get_data()[1])
plt.ylim(0,1)