matplotlib change a Patch in PatchCollection - matplotlib

PatchCollection accepts a list of Patches and allows me to transform them / add them to a canvas all at once. But changes made to one of the Patches after the PatchCollection object has been constructed are not reflected in the collection,
for example:
import matplotlib.pyplot as plt
import matplotlib as mpl
# Build a 1x1 rectangle and move it to (1, 1) before the collection exists.
rect = mpl.patches.Rectangle((0,0),1,1)
rect.set_xy((1,1))
# The collection is constructed from the patch list at this point ...
collection = mpl.collections.PatchCollection([rect])
# ... so this later move is the change that is NOT reflected in the drawing.
rect.set_xy((2,2))
ax = plt.figure(None).gca()
ax.set_xlim(0,5)
ax.set_ylim(0,5)
# Only the collection is added to the axes; the patch itself is never added.
ax.add_artist(collection)
plt.show() #shows a rectangle at (1,1), not (2,2)
I'm looking for a matplotlib collection that will group patches just so I can transform them together, but I want to be able to change the individual patches as well.

I don't know of a collection which will do what you want, but you could write one for yourself fairly easily:
import matplotlib.collections as mcollections
import matplotlib.pyplot as plt
import matplotlib as mpl
class UpdatablePatchCollection(mcollections.PatchCollection):
    """PatchCollection that re-reads its source patches on every path request.

    The list of patches passed to the constructor is kept on ``self.patches``
    and fed back through ``set_paths`` each time ``get_paths`` is called, so
    modifications made to the individual patches after construction are
    picked up.
    """

    def __init__(self, patches, *args, **kwargs):
        """Store *patches* and forward all arguments to ``PatchCollection``."""
        # Keep the caller's list object itself (not a copy) so that later
        # changes to the patches in it are visible from get_paths().
        self.patches = patches
        super().__init__(patches, *args, **kwargs)

    def get_paths(self):
        """Rebuild the paths from the stored patches, then return them."""
        self.set_paths(self.patches)
        return self._paths
# Demo: move the rectangle once before and once after the collection is built.
rect = mpl.patches.Rectangle((0, 0), 1, 1)
rect.set_xy((1, 1))
collection = UpdatablePatchCollection([rect])
rect.set_xy((2, 2))

# Create the figure and its axes in two explicit steps.
figure = plt.figure(None)
ax = figure.gca()
ax.set_xlim(0, 5)
ax.set_ylim(0, 5)
ax.add_artist(collection)
plt.show()  # now shows a rectangle at (2,2)

Related

Legend handle to an xarray plot

I cannot modify the legend of a plot of a dataset made with xarray's plotting functions.
The code below returns No handles with labels found to put in legend.
import xarray as xr
import matplotlib.pyplot as plt
# Open the "air_temperature" tutorial dataset and take its `air` variable.
air = xr.tutorial.open_dataset("air_temperature").air
# Line plot over "time" for one lon index and three lat indices; xarray is
# asked to add its own legend.
air.isel(lon=10, lat=[19, 21, 22]).plot.line(x="time", add_legend=True)
# This is the call reported to emit
# "No handles with labels found to put in legend."
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
You can use seaborn's sns.move_legend(), followed by plt.tight_layout(). sns.move_legend() is new in seaborn 0.11.2.
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
# Load the tutorial data and draw the three selected latitude lines.
dataset = xr.tutorial.open_dataset("air_temperature")
air = dataset.air
selection = air.isel(lon=10, lat=[19, 21, 22])
selection.plot.line(x="time", add_legend=True)

# Relocate the legend xarray created on the current axes, then re-fit the layout.
current_axes = plt.gca()
sns.move_legend(current_axes, loc='center left', bbox_to_anchor=(1, 0.5))
plt.tight_layout()
plt.show()
PS: If you don't want to import seaborn, you could copy the function from its source. You'll need to remove a reference to sns.axisgrid.Grid and import matplotlib as mpl; import inspect:
import matplotlib.pyplot as plt
import matplotlib as mpl
import inspect
import xarray as xr
def move_legend(obj, loc, **kwargs):
    """
    Recreate a plot's legend at a new location.

    Extracted from seaborn/utils.py

    Parameters
    ----------
    obj : matplotlib Axes or Figure
        Object whose existing legend is replaced. For a Figure the most
        recently added legend (``obj.legends[-1]``) is used.
    loc : str or int
        Location argument passed to ``obj.legend``.
    **kwargs
        Extra keyword arguments passed to ``obj.legend``. ``title`` replaces
        the title text; keys starting with ``title_`` are applied to the
        title text object with the prefix removed.

    Raises
    ------
    TypeError
        If `obj` is neither an Axes nor a Figure.
    ValueError
        If `obj` has no legend attached.
    """
    if isinstance(obj, mpl.axes.Axes):
        old_legend = obj.legend_
        legend_func = obj.legend
    elif isinstance(obj, mpl.figure.Figure):
        if obj.legends:
            old_legend = obj.legends[-1]
        else:
            old_legend = None
        legend_func = obj.legend
    else:
        err = "`obj` must be a matplotlib Axes or Figure instance."
        raise TypeError(err)

    if old_legend is None:
        err = f"{obj} has no legend attached."
        raise ValueError(err)

    # Extract the components of the legend we need to reuse.
    # `legendHandles` was renamed to `legend_handles` (old name deprecated in
    # matplotlib 3.7 and later removed), so prefer the new attribute and fall
    # back to the old one on older releases.
    try:
        handles = old_legend.legend_handles
    except AttributeError:
        handles = old_legend.legendHandles
    labels = [t.get_text() for t in old_legend.get_texts()]

    # Extract legend properties that can be passed to the recreation method
    # (Vexingly, these don't all round-trip)
    legend_kws = inspect.signature(mpl.legend.Legend).parameters
    props = {k: v for k, v in old_legend.properties().items() if k in legend_kws}

    # Delegate default bbox_to_anchor rules to matplotlib
    props.pop("bbox_to_anchor")

    # Try to propagate the existing title and font properties; respect new ones too
    title = props.pop("title")
    if "title" in kwargs:
        title.set_text(kwargs.pop("title"))
    # Iterate over a separate dict so that popping from kwargs inside the
    # loop does not modify the mapping being iterated.
    title_kwargs = {k: v for k, v in kwargs.items() if k.startswith("title_")}
    for key, val in title_kwargs.items():
        title.set(**{key[6:]: val})  # key[6:] strips the "title_" prefix
        kwargs.pop(key)

    # Try to respect the frame visibility
    kwargs.setdefault("frameon", old_legend.legendPatch.get_visible())

    # Remove the old legend and create the new one
    props.update(kwargs)
    old_legend.remove()
    new_legend = legend_func(handles, labels, loc=loc, **props)
    new_legend.set_title(title.get_text(), title.get_fontproperties())
# Same plot as before, but relocating the legend with the local move_legend().
dataset = xr.tutorial.open_dataset("air_temperature")
air = dataset.air
subset = air.isel(lon=10, lat=[19, 21, 22])
subset.plot.line(x="time", add_legend=True)

target_axes = plt.gca()
move_legend(target_axes, loc='center left', bbox_to_anchor=(1, 0.5))
plt.tight_layout()
plt.show()

Making sure 0 gets white in a RdBu colorbar

I create a heatmap with the following snippet:
import numpy as np
import matplotlib.pyplot as plt
# 10x10 normal samples with mean 0.4 and standard deviation 2, so the data
# are not centered on zero.
d = np.random.normal(.4,2,(10,10))
# No norm is supplied here, which is the setup the question is about.
plt.imshow(d,cmap=plt.cm.RdBu)
plt.colorbar()
plt.show()
The result is plot below:
Now, since the midpoint of the data range is not 0, the cells whose value is 0 are not drawn white, but rather a little reddish.
How do I force the colormap so that max=blue, min=red and 0=white?
Use a DivergingNorm.
Note: From matplotlib 3.2 onwards DivergingNorm is renamed to TwoSlopeNorm.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
d = np.random.normal(.4, 2, (10, 10))
# DivergingNorm was renamed to TwoSlopeNorm in matplotlib 3.2 and the old name
# was later removed; use whichever name this matplotlib version provides.
norm_class = getattr(mcolors, "TwoSlopeNorm", None) or mcolors.DivergingNorm
# vcenter=0 pins the middle of the colormap (white for RdBu) to the value 0.
norm = norm_class(vmin=d.min(), vmax=d.max(), vcenter=0)
plt.imshow(d, cmap=plt.cm.RdBu, norm=norm)
plt.colorbar()
plt.show()
A previous SO post (Change colorbar gradient in matplotlib) wanted a solution for a more complicated situation, but one of the answers talked about the MidpointNormalize subclass in the matplotlib documentation. With that, the solution becomes:
import matplotlib as mpl
import numpy as np
import matplotlib.pyplot as plt
class MidpointNormalize(mpl.colors.Normalize):
    """Normalize that maps vmin -> 0, midpoint -> 0.5 and vmax -> 1.

    The two halves [vmin, midpoint] and [midpoint, vmax] are each mapped
    linearly, so `midpoint` always lands on the center of the colormap.
    """
    ## class from the mpl docs:
    # https://matplotlib.org/users/colormapnorms.html

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        """Store `midpoint` and pass the remaining arguments to Normalize."""
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        """Return `value` mapped piecewise-linearly onto [0, 1].

        `clip` is accepted for signature compatibility but is not used.
        NOTE(review): vmin, midpoint and vmax must all be set by the time
        this is called -- np.interp cannot handle None.
        """
        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        # np.interp returns a plain ndarray, so re-attach the input's mask
        # (nomask for unmasked input) instead of silently dropping it.
        return np.ma.masked_array(np.interp(value, x, y),
                                  mask=np.ma.getmask(value))

    def inverse(self, value):
        """Map normalized values in [0, 1] back to data values.

        Defined so that the inverse matches the piecewise mapping of
        ``__call__`` rather than the purely linear inverse inherited from
        Normalize.
        """
        y, x = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        return np.interp(value, x, y)
# Same random data as in the question, drawn with a norm centered on zero.
d = np.random.normal(.4, 2, (10, 10))
zero_centered = MidpointNormalize(midpoint=0)
plt.imshow(d, cmap=plt.cm.RdBu, norm=zero_centered)
plt.colorbar()
plt.show()
Kudos to Joe Kington for writing the subclass, and to Rutger Kassies for pointing out the answer.

matplotlib pyplot pcolor savefig colorbar transparency

I am trying to export a pcolor figure with a colorbar.
The cmap of the colorbar has a transparent color.
The exported figure has transparent colors in the axes but not in the colorbar. How can I fix this?
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# 10x10 uniform random values to plot.
x = np.random.random((10, 10))
# Two-stop colormap whose endpoints differ only in the 4th (alpha) component:
# 0 at one end, 1 at the other.
colors = [(0,0,0,0), (0,0,0,1)]
cm = LinearSegmentedColormap.from_list('custom', colors, N=256, gamma=0)
plt.pcolor(x,cmap=cm)
plt.colorbar()
# Exported with transparent=True; the question reports that the colorbar still
# comes out non-transparent.
plt.savefig('figure.pdf',transparent=True)
I put the image against a grey background to check. As can be seen, the cmap in the axes is transparent while the one in the colorbar is not.
While the colorbar resides inside an axes, it has an additional background patch associated with it. This is white by default and will not be taken into account when transparent=True is used inside of savefig.
A solution is hence to remove the facecolor of this patch manually,
cb.patch.set_facecolor("none")
A complete example, which shows this without actually saving the figure
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
# Random data and a two-stop colormap running from alpha 0 to alpha 1.
x = np.random.random((10, 10))
transparent_white = (1, 1, 1, 0)
opaque_black = (0, 0, 0, 1)
colors = [transparent_white, opaque_black]
cm = LinearSegmentedColormap.from_list('custom', colors, N=256, gamma=0)

# A grey figure background makes the transparency visible without saving.
fig, ax = plt.subplots(facecolor="grey")
im = ax.pcolor(x, cmap=cm)
cb = fig.colorbar(im, drawedges=False)

# Remove the background fill of both the main axes and the colorbar patch.
ax.set_facecolor("none")
cb.patch.set_facecolor("none")
plt.show()

Understanding plt.show() in Matplotlib

import numpy as np
import os.path
from skimage.io import imread
from skimage import data_dir
# Read the 'checker_bilevel.png' sample image from skimage's data directory.
img = imread(os.path.join(data_dir, 'checker_bilevel.png'))
import matplotlib.pyplot as plt
#plt.imshow(img, cmap='Blues')
#plt.show()
# Figure 1: the transposed image.
imgT = img.T
plt.figure(1)
plt.imshow(imgT,cmap='Greys')
#plt.show()
# Figure 2: the same pixels reshaped to 20x5.
imgR = img.reshape(20,5)
plt.figure(2)
plt.imshow(imgR,cmap='Blues')
# The asker expected this to display only figure 1; both figures appear.
plt.show(1)
I read that plt.figure() will create a figure and assign it a new ID if one is not given explicitly. So here, I have given the two figures the IDs 1 and 2 respectively. Now I wish to see only one of the images.
I tried plt.show(1), expecting that ONLY the first image would be displayed, but both of them are.
What should I write to get only one?
plt.clf() will clear the figure
import matplotlib.pyplot as plt
# Draw a first line, wipe the current figure, then draw the line that is kept.
discarded = range(10)
plt.plot(discarded, 'r')
plt.clf()
kept = range(12)
plt.plot(kept, 'g--')
plt.show()
plt.show() will show all the figures that have been created. The argument you passed only controls whether the figures are shown in a blocking or non-blocking way; it does not select a figure. If you only want to show a particular figure, you can write a wrapper function.
import matplotlib.pyplot as plt
# Five (figure, axes) pairs as returned by plt.subplots().
figures = [plt.subplots() for i in range(5)]


def show(figNum, figures):
    """Show only the figure numbered `figNum`.

    Parameters
    ----------
    figNum : int
        Number of the figure to show.
    figures : list of (Figure, Axes) tuples
        Candidates to search, e.g. a list of ``plt.subplots()`` results.

    Prints 'figure not found' when pyplot does not know the number, or when
    no entry of `figures` carries it (the latter used to raise IndexError).
    """
    if plt.fignum_exists(figNum):
        # Take the first matching figure; None if the number is known to
        # pyplot but absent from `figures`.
        fig = next((f[0] for f in figures if f[0].number == figNum), None)
        if fig is not None:
            fig.show()
            return
    print('figure not found')

Enumerate plots in matplotlib figure

In a matplotlib figure I would like to enumerate all (sub)plots with a), b), c) and so on. Is there a way to do this automatically?
So far I use the individual plots' titles, but that is far from ideal as I want the number to be left aligned, while an optional real title should be centered on the figure.
import string
from itertools import cycle
from six.moves import zip
def label_axes(fig, labels=None, loc=None, **kwargs):
    """
    Walks through axes and labels each.

    kwargs are collected and passed to `annotate`

    Parameters
    ----------
    fig : Figure
         Figure object to work on

    labels : iterable or None
        iterable of strings to use to label the axes.
        If None, lower case letters are used.

    loc : len=2 tuple of floats
        Where to put the label in axes-fraction units.
        If None, (.9, .9) is used.
    """
    if labels is None:
        labels = string.ascii_lowercase

    # re-use labels rather than stop labeling
    labels = cycle(labels)
    if loc is None:
        loc = (.9, .9)
    # zip stops at the end of fig.axes; the cycled labels never run out.
    for ax, lab in zip(fig.axes, labels):
        ax.annotate(lab, xy=loc,
                    xycoords='axes fraction',
                    **kwargs)
example usage:
from matplotlib import pyplot as plt
# Build two 3x3 grids, labelling the first right-aligned and the second
# left-aligned.
for alignment in ('right', 'left'):
    fig, ax_lst = plt.subplots(3, 3)
    label_axes(fig, ha=alignment)
    plt.draw()
This seems useful enough to me that I put this in a gist : https://gist.github.com/tacaswell/9643166
I wrote a function to do this automatically, where the label is introduced as a legend:
import numpy
import matplotlib.pyplot as plt
def setlabel(ax, label, loc=2, borderpad=0.6, **kwargs):
    """Place `label` on `ax` as a frameless, handle-less legend.

    Parameters
    ----------
    ax : Axes
        Axes to label.
    label : str
        Text to display, e.g. '(a)'.
    loc : int or str
        Legend location passed to ``ax.legend`` (default 2).
    borderpad : float
        Legend ``borderpad`` passed to ``ax.legend``.
    **kwargs
        Further keyword arguments passed to ``ax.legend``.
    """
    legend = ax.get_legend()
    if legend:
        # Re-add the existing legend as a plain artist before ax.legend() is
        # called again below -- presumably so it is not lost; verify.
        ax.add_artist(legend)
    # Invisible NaN line that only serves as the carrier of the label text.
    # `numpy.NaN` was removed in NumPy 2.0; `numpy.nan` works on all versions.
    line, = ax.plot(numpy.nan, numpy.nan, color='none', label=label)
    label_legend = ax.legend(
        handles=[line],
        loc=loc,
        handlelength=0,
        handleheight=0,
        handletextpad=0,
        borderaxespad=0,
        borderpad=borderpad,
        frameon=False,
        **kwargs
    )
    # Detach the new legend and add it back as a plain artist, then drop the
    # helper line.
    label_legend.remove()
    ax.add_artist(label_legend)
    line.remove()
# Minimal demo: one diagonal line with the panel label '(a)'.
fig, ax = plt.subplots()
diagonal = [1, 2]
ax.plot(diagonal, diagonal)
setlabel(ax, '(a)')
plt.show()
The location of the label can be controlled with loc argument, the distance to the axis can be controlled with borderpad argument (negative value pushes the label to be outside the figure), and other options available to legend also can be used, such as fontsize. The above script gives such figure:
A super quick way to do this is to take advantage of the fact that chr() casts integers to characters. Since a-z fall in the range 97-122, one can do the following:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(2, 2)
# Count character codes upward from 'a' (97) so chr(i) yields a, b, c, ...
for i, ax in enumerate(axs.flat, start=ord('a')):
    ax.plot([0, 1], [0, 1])
    # transform=ax.transAxes places the text in axes coordinates.
    ax.text(0.05, 0.9, chr(i) + ')', transform=ax.transAxes)
which produces: