How to create a scalebar using cartopy and matplotlib? - matplotlib

in respect to the previous examples in stackoverflow, I searched for other alternatives in order to create a scalebar.
In my research, I verified that the Basemap class from mpl_toolkits.basemap see here. It has the "drawmapscale" method. This method has the option barstyle = 'fancy' for a more interesting scalebar drawing.
Therefore, I attempted to convert the "drawmapscale" from the Basemap into a cartopy version.
Nevertheless, the results were not positive, and I got error messages from the figure. I believe that the error is in the Transform of the data.
Here is the script:
import numpy as np
import matplotlib.pyplot as plt
import pyproj
import cartopy.crs as ccrs
from matplotlib import is_interactive
from cartopy.crs import (WGS84_SEMIMAJOR_AXIS, WGS84_SEMIMINOR_AXIS)
from pyproj import Transformer
class Scalebar():
def __init__(self,
ax,
suppress_ticks=True,
geographical_crs = 4326,
planar_crs = 5880,
fix_aspect=True,
anchor='C',
celestial=False,
round=False,
noticks=False,
metric_ccrs=ccrs.TransverseMercator()):
self.ax = ax
self.fix_aspect = fix_aspect
# setting metric ccrs for reference in the plotting
self.metric_ccrs = metric_ccrs
self.anchor = anchor
# geographic or celestial coords?
self.celestial = celestial
# map projection.
self.projection = ax.projection
self.geographical_crs = geographical_crs
self.planar_crs = planar_crs
self._initialized_axes = set()
self.round = round
# map boundary not yet drawn.
self._mapboundarydrawn = False
self.rmajor = np.float(ax.projection.globe.semimajor_axis or WGS84_SEMIMAJOR_AXIS)
self.rminor = np.float(ax.projection.globe.semiminor_axis or WGS84_SEMIMINOR_AXIS)
# set instance variables defining map region.
self.xmin = self.projection.boundary.bounds[0]
self.xmax = self.projection.boundary.bounds[2]
self.ymin = self.projection.boundary.bounds[1]
self.ymax = self.projection.boundary.bounds[3]
self._width = self.xmax - self.xmin
self._height = self.ymax - self.ymin
self.noticks = noticks
def __call__(self,x,y,
inverse=False
):
"""
Calling the class instance with the arguments lon, lat will
convert lon/lat (in degrees) to x/y map projection
coordinates (in meters).
If optional keyword ``inverse`` is True (default is False),
the inverse transformation from x/y to lon/lat is performed.
Input arguments:
lon, lat can be either scalar floats, sequences, or numpy arrays.
"""
if not inverse:
transformer = Transformer.from_crs("epsg:{0}".format(self.geographical_crs),
"epsg:{0}".format(self.planar_crs))
else:
transformer = Transformer.from_crs("epsg:{0}".format(self.planar_crs),
"epsg:{0}".format(self.geographical_crs))
return transformer.transform(x, y)
def drawmapscale(self,
lon,
lat,
length,
lon0=None,
lat0=None,
barstyle='simple',\
units='km',
fontsize=9,
yoffset=None,
labelstyle='simple',\
fontcolor='k',
fillcolor1='w',
fillcolor2='k',\
format='%d',
zorder=None,
linecolor=None,
linewidth=None):
"""
Draw a map scale at ``lon,lat`` of length ``length``
representing distance in the map
projection coordinates at ``lon0,lat0``.
.. tabularcolumns:: |l|L|
============== ====================================================
Keywords Description
============== ====================================================
units the units of the length argument (Default km).
barstyle ``simple`` or ``fancy`` (roughly corresponding
to the styles provided by Generic Mapping Tools).
Default ``simple``.
fontsize for map scale annotations, default 9.
fontcolor for map scale annotations, default black.
labelstyle ``simple`` (default) or ``fancy``. For
``fancy`` the map scale factor (ratio betwee
the actual distance and map projection distance
at lon0,lat0) and the value of lon0,lat0 are also
displayed on the top of the scale bar. For
``simple``, just the units are display on top
and the distance below the scale bar.
If equal to False, plot an empty label.
format a string formatter to format numeric values
yoffset yoffset controls how tall the scale bar is,
and how far the annotations are offset from the
scale bar. Default is 0.02 times the height of
the map (0.02*(self.ymax-self.ymin)).
fillcolor1(2) colors of the alternating filled regions
(default white and black). Only relevant for
'fancy' barstyle.
zorder sets the zorder for the map scale.
linecolor sets the color of the scale, by default, fontcolor
is used
linewidth linewidth for scale and ticks
============== ====================================================
Extra keyword ``ax`` can be used to override the default axis instance.
"""
# get current axes instance (if none specified).
ax = self.ax
# convert length to meters
lenlab = length
if units == 'km':
length = length*1000
elif units == 'mi':
length = length*1609.344
elif units == 'nmi':
length = length*1852
elif units == 'ft':
length = length*0.3048
elif units != 'm':
msg = "units must be 'm' (meters), 'km' (kilometers), "\
"'mi' (miles), 'nmi' (nautical miles), or 'ft' (feet)"
raise KeyError(msg)
# Setting the center coordinates of the axes:
xmin, xmax, ymin, ymax = self.ax.get_extent()
if lon0 == None:
lon0 = np.mean([xmin, xmax])
if lat0 == None:
lat0 = np.mean([ymin, ymax])
# reference point and center of scale.
x0,y0 = self(lon0,lat0)
print('\n\n Central coords prior to transform')
print('lon0,lat0: ', [lon0,lat0])
print('\n\n central coordinates after transform')
print('x0,y0: ', [x0,y0])
xc,yc = self(lon,lat)
print('\n\n positional coordinates prior to transform')
print('lon, lat: ', [lon,lat])
print('\n\n central coordinates after transform')
print('xc,yc: ', [xc,yc])
print('-'*20, '\n')
# make sure lon_0 between -180 and 180
lon_0 = ((lon0+360) % 360) - 360
if lat0>0:
if lon>0:
lonlatstr = u'%g\N{DEGREE SIGN}N, %g\N{DEGREE SIGN}E' % (lat0,lon_0)
elif lon<0:
lonlatstr = u'%g\N{DEGREE SIGN}N, %g\N{DEGREE SIGN}W' % (lat0,lon_0)
else:
lonlatstr = u'%g\N{DEGREE SIGN}, %g\N{DEGREE SIGN}W' % (lat0,lon_0)
else:
if lon>0:
lonlatstr = u'%g\N{DEGREE SIGN}S, %g\N{DEGREE SIGN}E' % (lat0,lon_0)
elif lon<0:
lonlatstr = u'%g\N{DEGREE SIGN}S, %g\N{DEGREE SIGN}W' % (lat0,lon_0)
else:
lonlatstr = u'%g\N{DEGREE SIGN}S, %g\N{DEGREE SIGN}' % (lat0,lon_0)
# left edge of scale
lon1,lat1 = self(x0-length/2,y0, inverse=True)
x1,y1 = self(lon1,lat1)
# right edge of scale
lon4,lat4 = self(x0+length/2,y0, inverse=True)
x4,y4 = self(lon4,lat4)
x1 = x1-x0+xc
y1 = y1-y0+yc
print('\n\n positional coordinates prior to transform')
print('lon1,lat1: ', [lon1,lat1])
print('\n\n positional coordinates prior to transform')
print('x1, y1: ', [x1,y1])
print()
print('\n\n central coordinates after transform')
print('lon4,lat4: ', [lon4,lat4])
print('-'*20, '\n')
print('\n\n central coordinates after transform')
print('x4,y4: ', [x4,y4])
print('-'*20, '\n')
x4 = x4-x0+xc
y4 = y4-y0+yc
if x1 > 1.e20 or x4 > 1.e20 or y1 > 1.e20 or y4 > 1.e20:
raise ValueError("scale bar positioned outside projection limb")
# scale factor for true distance
gc = pyproj.Geod(a=self.rmajor,b=self.rminor)
az12,az21,dist = gc.inv(lon1,lat1,lon4,lat4)
scalefact = dist/length
# label to put on top of scale bar.
if labelstyle=='simple':
labelstr = units
elif labelstyle == 'fancy':
labelstr = units+" (scale factor %4.2f at %s)"%(scalefact,lonlatstr)
elif labelstyle == False:
labelstr = ''
else:
raise KeyError("labelstyle must be 'simple' or 'fancy'")
# default y offset is 2 percent of map height.
if yoffset is None:
yoffset = 0.02*(self.ymax-self.ymin)
rets = [] # will hold all plot objects generated.
# set linecolor
if linecolor is None:
linecolor = fontcolor
# 'fancy' style
if barstyle == 'fancy':
#we need 5 sets of x coordinates (in map units)
#quarter scale
lon2,lat2 = self(x0-length/4,y0,inverse=True)
x2,y2 = self(lon2,lat2)
x2 = x2-x0+xc; y2 = y2-y0+yc
#three quarter scale
lon3,lat3 = self(x0+length/4,y0,inverse=True)
x3,y3 = self(lon3,lat3)
x3 = x3-x0+xc; y3 = y3-y0+yc
#plot top line
ytop = yc+yoffset/2
ybottom = yc-yoffset/2
ytick = ybottom - yoffset/2
ytext = ytick - yoffset/2
lontext , lattext = self(lon0,ytext, inverse=True)
#lon_top, lat_top = self(lon4,ytop,inverse=True)
#lon_top, lat_bottom = self(lon4,ybottom,inverse=True)
transform = self.metric_ccrs # this crs projection is meant to be for metric data
rets.append(self.plot([x1,x4],
[ytop,ytop],
transform=transform,
color=linecolor,
linewidth=linewidth)[0])
#plot bottom line
rets.append(self.plot([x1,x4],
[ybottom,ybottom],
transform=transform,
color=linecolor,
linewidth=linewidth)[0])
#plot left edge
rets.append(self.plot([x1,x1],
[ybottom,ytop],
transform=transform,
color=linecolor,
linewidth=linewidth)[0])
#plot right edge
rets.append(self.plot([x4,x4],
[ybottom,ytop],
transform=transform,
color=linecolor,
linewidth=linewidth)[0])
#make a filled black box from left edge to 1/4 way across
rets.append(ax.fill([x1,x2,x2,x1,x1],
[ytop,ytop,ybottom,ybottom,ytop],
transform=transform,
ec=fontcolor,
fc=fillcolor1)[0])
#make a filled white box from 1/4 way across to 1/2 way across
rets.append(ax.fill([x2,x0,x0,x2,x2],
[ytop,ytop,ybottom,ybottom,ytop],
transform=transform,
ec=fontcolor,
fc=fillcolor2)[0])
#make a filled white box from 1/2 way across to 3/4 way across
rets.append(ax.fill([x0,x3,x3,x0,x0],
[ytop,ytop,ybottom,ybottom,ytop],
transform=transform,
ec=fontcolor,
fc=fillcolor1)[0])
#make a filled white box from 3/4 way across to end
rets.append(ax.fill([x3,x4,x4,x3,x3],
[ytop,ytop,ybottom,ybottom,ytop],
transform=transform,
ec=fontcolor,
fc=fillcolor2)[0])
#plot 3 tick marks at left edge, center, and right edge
rets.append(self.plot([x1,x1],
[ytick,ybottom],
color=linecolor,
transform=transform,
linewidth=linewidth)[0])
rets.append(self.plot([x0,x0],
[ytick,ybottom],
transform=transform,
color=linecolor,
linewidth=linewidth)[0])
rets.append(self.plot([x4,x4],
[ytick,ybottom],
transform=transform,
color=linecolor,
linewidth=linewidth)[0])
#label 3 tick marks
rets.append(ax.text(x1,lattext,format % (0),\
horizontalalignment='center',\
verticalalignment='top',\
fontsize=fontsize,color=fontcolor))
rets.append(ax.text(x0,lattext,format % (0.5*lenlab),\
horizontalalignment='center',\
verticalalignment='top',\
fontsize=fontsize,color=fontcolor))
rets.append(ax.text(x4,lattext,format % (lenlab),\
horizontalalignment='center',\
verticalalignment='top',\
fontsize=fontsize,color=fontcolor))
#put units, scale factor on top
rets.append(ax.text(x0,ytop+yoffset/2,labelstr,\
horizontalalignment='center',\
verticalalignment='bottom',\
fontsize=fontsize,color=fontcolor))
# 'simple' style
elif barstyle == 'simple':
rets.append(self.plot([x1,x4],[yc,yc],color=linecolor, linewidth=linewidth)[0])
rets.append(self.plot([x1,x1],[yc-yoffset,yc+yoffset],color=linecolor, linewidth=linewidth)[0])
rets.append(self.plot([x4,x4],[yc-yoffset,yc+yoffset],color=linecolor, linewidth=linewidth)[0])
rets.append(ax.text(xc,yc-yoffset,format % lenlab,\
verticalalignment='top',horizontalalignment='center',\
fontsize=fontsize,color=fontcolor))
#put units, scale factor on top
rets.append(ax.text(xc,yc+yoffset,labelstr,\
horizontalalignment='center',\
verticalalignment='bottom',\
fontsize=fontsize,color=fontcolor))
else:
raise KeyError("barstyle must be 'simple' or 'fancy'")
if zorder is not None:
for ret in rets:
try:
ret.set_zorder(zorder)
except:
pass
return rets
def plot(self, *args, **kwargs):
"""
Draw lines and/or markers on the map
(see matplotlib.pyplot.plot documentation).
If ``latlon`` keyword is set to True, x,y are intrepreted as
longitude and latitude in degrees. Data and longitudes are
automatically shifted to match map projection region for cylindrical
and pseudocylindrical projections, and x,y are transformed to map
projection coordinates. If ``latlon`` is False (default), x and y
are assumed to be map projection coordinates.
Extra keyword ``ax`` can be used to override the default axis instance.
Other \**kwargs passed on to matplotlib.pyplot.plot.
"""
ax = self.ax
self._save_use_hold(ax, kwargs)
try:
ret = ax.plot(*args,
**kwargs)
finally:
self._restore_hold(ax)
# set axes limits to fit map region.
self.set_axes_limits(ax=ax)
# clip to map limbs
ret,c = self._cliplimb(ax,ret)
return ret
def _save_use_hold(self, ax, kwargs):
h = kwargs.pop('hold', None)
if hasattr(ax, '_hold'):
self._tmp_hold = ax._hold
if h is not None:
ax._hold = h
def _restore_hold(self, ax):
if hasattr(ax, '_hold'):
ax._hold = self._tmp_hold
def set_axes_limits(self,ax=None):
"""
Final step in Basemap method wrappers of Axes plotting methods:
Set axis limits, fix aspect ratio for map domain using current
or specified axes instance. This is done only once per axes
instance.
In interactive mode, this method always calls draw_if_interactive
before returning.
"""
# get current axes instance (if none specified).
ax = ax or self._check_ax()
# If we have already set the axes limits, and if the user
# has not defeated this by turning autoscaling back on,
# then all we need to do is plot if interactive.
if (hash(ax) in self._initialized_axes
and not ax.get_autoscalex_on()
and not ax.get_autoscaley_on()):
if is_interactive():
import matplotlib.pyplot as plt
plt.draw_if_interactive()
return
self._initialized_axes.add(hash(ax))
# Take control of axis scaling:
ax.set_autoscale_on(False)
# update data limits for map domain.
corners = ((self.xmin, self.ymin), (self.xmax, self.ymax))
ax.update_datalim(corners)
ax.set_xlim((self.xmin, self.xmax))
ax.set_ylim((self.ymin, self.ymax))
# if map boundary not yet drawn for elliptical maps, draw it with default values.
# make sure aspect ratio of map preserved.
# plot is re-centered in bounding rectangle.
# (anchor instance var determines where plot is placed)
if self.fix_aspect:
ax.set_aspect('equal',anchor=self.anchor)
else:
ax.set_aspect('auto',anchor=self.anchor)
# make sure axis ticks are turned off.
if self.noticks:
ax.set_xticks([])
ax.set_yticks([])
# force draw if in interactive mode.
if is_interactive():
import matplotlib.pyplot as plt
plt.draw_if_interactive()
def _cliplimb(self,ax,coll):
if not self._mapboundarydrawn:
return coll, None
c = self._mapboundarydrawn
if c not in ax.patches:
p = ax.add_patch(c)
#p.set_clip_on(False)
try:
coll.set_clip_path(c)
except:
for item in coll:
item.set_clip_path(c)
return coll,c
# now the test
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import geopandas as gpd
def get_standard_gdf():
""" basic function for getting some geographical data in geopandas GeoDataFrame python's instance:
An example data can be downloaded from Brazilian IBGE:
ref: ftp://geoftp.ibge.gov.br/organizacao_do_territorio/malhas_territoriais/malhas_municipais/municipio_2017/Brasil/BR/br_municipios.zip
"""
gdf_path = r'C:\my_file_path\Shapefile.shp'
return gpd.read_file(gdf_path)
def format_ax(ax, projection, xlim, ylim):
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_global()
ax.coastlines()
def main():
fig = plt.figure(figsize=(8, 10))
# Label axes of a Plate Carree projection with a central longitude of 180:
#for enum, proj in enumerate(['Mercator, PlateCarree']):
gdf = get_standard_gdf()
xmin, ymin, xmax, ymax = gdf.total_bounds
xlim = [xmin, xmax]
ylim = [ymin, ymax]
lon_c = np.mean(xlim)
lat_c = np.mean(ylim)
projection = ccrs.PlateCarree(central_longitude=0)
ax1 = fig.add_subplot(3, 1, 1,
projection=projection,
xlim=[xmin, xmax],
ylim=[ymin, ymax])
gdf.plot(ax=ax1, transform=projection)
format_ax(ax1, projection, xlim, ylim)
Grider = ax1.gridlines(draw_labels=True)
Grider.xformatter = LONGITUDE_FORMATTER
Grider.yformatter = LATITUDE_FORMATTER
Grider.xlabels_top = False
Grider.ylabels_right = False
# Label axes of a Mercator projection without degree symbols in the labels
# and formatting labels to include 1 decimal place:
ax2 = fig.add_subplot(3, 1, 2,
projection=ccrs.Mercator(),
xlim=[xmin, xmax],
ylim=[ymin, ymax])
gdf.plot(ax=ax2, transform=projection)
format_ax(ax2, projection, xlim, ylim)
Grider = ax2.gridlines(draw_labels=True)
Grider.xformatter = LONGITUDE_FORMATTER
Grider.yformatter = LATITUDE_FORMATTER
Grider.xlabels_top = False
Grider.ylabels_right = False
ax3 = fig.add_subplot(3, 1, 3,
projection=ccrs.Robinson(central_longitude=lon_c,
#central_latitude=lat_c
),
xlim=[xmin, xmax],
ylim=[ymin, ymax])
gdf.plot(ax=ax3, transform=projection)
format_ax(ax3, projection, xlim, ylim)
ax3.set_xticks([-180, -120, -60, 0, 60, 120, 180])
ax3.set_yticks([-78.5, -60, -25.5, 25.5, 60, 80])
ax3.xaxis.set_major_formatter(LONGITUDE_FORMATTER)
ax3.yaxis.set_major_formatter(LATITUDE_FORMATTER)
plt.draw()
return fig, fig.get_axes()
if __name__ == '__main__':
length = 1000
fig, axes = main()
gdf = get_standard_gdf()
xmin, ymin, xmax, ymax = gdf.total_bounds
xoff = 0.3 * (xmax - xmin)
yoff = 0.2 * (ymax - ymin)
for ax in axes:
if hasattr(ax, 'projection'):
x0, x1, y0, y1 = np.ravel(ax.get_extent())
Scaler = Scalebar(ax=ax,
metric_ccrs=ccrs.Geodetic())
Scaler.drawmapscale(lon = xmin+xoff,
lat = ymin + yoff,
length=length,
units = 'km',
barstyle='fancy',
yoffset=0.2 * (ymax - ymin)
)
fig.suptitle('Using Cartopy')
fig.show()
When the above code is run, the scalebar is misplaced in the geoaxes. The scalebar xticks are misplaced, and its yaxis height proportion is also wrong.
Here is an example: the geopandas is plotted in blue. Note that the scalebar is only visible in the second and third geoaxes.

I found a solution for the current problem.
For sake of brevety, the code is presented in here.
Feel free to check it out. The algorithm still requires some adjustment in order to support other cartopy projections.
Meanwhile, it can be applied to PlateCarree projection.

Related

How can I get rid of this dummy mappable object and still draw my colorbar in Matplotlib?

I have the code below to plot circles add them to an ax.
I color the circles with respect to a colorbar.
However, to add the colorbar to my plot, I'm using sc=plot.scatter(...) and putting the colorbar using this dummy sc. Because plt.colorbar(sc,...) requires a mappable argument. How can I get rid of this dummy sc and still draw my colorbar?
import matplotlib
import numpy as np
import os
import matplotlib as mpl
from matplotlib.colors import Normalize
import matplotlib.cm as matplotlib_cm
from matplotlib import pyplot as plt
print(matplotlib.__version__)
row_list=['row1', 'row2', 'row3']
column_list=[2]
maxProcessiveGroupLength=2
index = column_list.index(maxProcessiveGroupLength)
plot1,panel1 = plt.subplots(figsize=(20+1.5*len(column_list), 10+1.5*len(row_list)))
plt.rc('axes', edgecolor='lightgray')
#make aspect ratio square
panel1.set_aspect(1.0)
panel1.text(0.1, 1.2, 'DEBUG', horizontalalignment='center', verticalalignment='top', fontsize=60, fontweight='bold', fontname='Arial',transform=panel1.transAxes)
if (len(column_list) > 1):
panel1.set_xlim([1, index + 1])
panel1.set_xticks(np.arange(0, index + 2, 1))
else:
panel1.set_xlim([0, len(column_list)])
panel1.set_xticks(np.arange(0, len(column_list)+1, 1))
if (len(row_list) > 1):
panel1.set_ylim([1, len(row_list)])
else:
panel1.set_ylim([0, len(row_list)])
panel1.set_yticks(np.arange(0, len(row_list) + 1, 1))
panel1.set_facecolor('white')
panel1.grid(color='black')
for edge, spine in panel1.spines.items():
spine.set_visible(True)
spine.set_color('black')
xlabels = None
if (index is not None):
xlabels = column_list[0:index + 1]
ylabels = row_list
cmap = matplotlib_cm.get_cmap('Blues') # Looks better
v_min = 2
v_max = 20
norm = Normalize(v_min, v_max)
bounds = np.arange(v_min, v_max+1, 2)
# Plot the circles with color
for row_index, row in enumerate(row_list):
for column_index, processive_group_length in enumerate(column_list):
radius=0.35
color=10+column_index*3+row_index*3
circle = plt.Circle((column_index + 0.5, row_index + 0.5), radius,color=cmap(norm(color)), fill=True)
panel1.add_patch(circle)
# Used for scatter plot
x = []
y = []
c = []
for row_index, processiveGroupLength in enumerate(row_list):
x.append(row_index)
y.append(row_index)
c.append(0.5)
# This code defines the ticks on the color bar
# plot the scatter plot
sc = plt.scatter(x, y, s=0, c=c, cmap=cmap, vmin=v_min, vmax=v_max, edgecolors='black')
# colorbar to the bottom
cb = plt.colorbar(sc ,orientation='horizontal') # this works because of the scatter
cb.ax.set_xlabel("colorbar label", fontsize=50, labelpad=25)
# common for horizontal colorbar and vertical colorbar
cbax = cb.ax
cbax.tick_params(labelsize=40)
text_x = cbax.xaxis.label
text_y = cbax.yaxis.label
font = mpl.font_manager.FontProperties(size=40)
text_x.set_font_properties(font)
text_y.set_font_properties(font)
# CODE GOES HERE TO CENTER X-AXIS LABELS...
panel1.set_xticklabels([])
mticks = panel1.get_xticks()
panel1.set_xticks((mticks[:-1] + mticks[1:]) / 2, minor=True)
panel1.tick_params(axis='x', which='minor', length=0, labelsize=50)
if xlabels is not None:
panel1.set_xticklabels(xlabels,minor=True)
panel1.xaxis.set_ticks_position('top')
plt.tick_params(
axis='x', # changes apply to the x-axis
which='major', # both major and minor ticks are affected
bottom=False, # ticks along the bottom edge are off
top=False) # labels along the bottom edge are off
# CODE GOES HERE TO CENTER Y-AXIS LABELS...
panel1.set_yticklabels([])
mticks = panel1.get_yticks()
panel1.set_yticks((mticks[:-1] + mticks[1:]) / 2, minor=True)
panel1.tick_params(axis='y', which='minor', length=0, labelsize=50)
panel1.set_yticklabels(ylabels, minor=True) # fontsize
plt.tick_params(
axis='y', # changes apply to the x-axis
which='major', # both major and minor ticks are affected
left=False) # labels along the bottom edge are off
plt.show()
From the documentation of colorbar:
Note that one can create a ScalarMappable "on-the-fly" to generate
colorbars not attached to a previously drawn artist
In your example, the following allows for creating the same colorbar without the scatter plot:
cb = plt.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), orientation='horizontal')

Shrink matplotlib parasite axis horizontally to take up approximately 25% of the image length

I have an image like the one below:
The issue is I need the curves to only take up about 25% - 30% of the image. In other words I need to shrink the size of the two parasite axes horizontally. Is this even possible?
Here is what I have so far:
"""
Plotting _____________________________________________________________________________________________________________
"""
fig = plt.figure(figsize=(20,15))
host1 = host_subplot(211, axes_class=AA.Axes)
plt.subplots_adjust(right=0.75)
#Create custom axes
cax1 = plt.axes(frameon=False)
# Now create parasite axis
par11 = host1.twiny()
par12 = host1.twiny()
top_offset = 50
new_fixed_axis1 = par12.get_grid_helper().new_fixed_axis
par12.axis["top"] = new_fixed_axis1(loc="top",
axes=par11,
offset=(0, top_offset))
par11.axis["top"].toggle(all=True)
par12.axis["top"].toggle(all=True)
# Bottom Axis
bottom_offset1 = -50
bottom_offset2 = -100
par21 = host1.twiny()
par22 = host1.twiny()
new_fixed_axis2 = par21.get_grid_helper().new_fixed_axis
par21.axis["bottom"] = new_fixed_axis2(loc="bottom",
axes=par12,
offset=(0, bottom_offset1))
# Set Host Axis Labels
host1.set_xlabel("UTC Time")
host1.set_ylabel("Elevation (km")
# Set Top Axis Labels
par11.set_xlabel("Sonde Potential Temperature (K)")
par12.set_xlabel("Sonde Relative Humidity %")
vmin, vmax = np.min(chan_1064), np.max(chan_1064)
im = host1.imshow(chan_1064, aspect="auto", cmap=get_a_color_map(), vmin=-2e-4, vmax=0.6e-2,
extent=(min(xs), max(xs), min(bin_alt_array), max(bin_alt_array)))
scatter = host1.scatter(xs, ys, s=100, color='gold')
host1.set_xlim(min(xs), max(xs))
fig.colorbar(im)
plt.draw()
leg = plt.legend( loc = 'lower right')
# Adjust Fonts
font = {'family' : 'normal',
'weight' : 'bold',
'size' : 12}
mpl.rc('font', **font)
plt.tight_layout()
plt.show()
Sorry if it's a simple solution but, I have not been able to figure it out for the life of me.

Matplotlib: different scale on negative side of the axis

Background
I am trying to show three variables on a single plot. I have connected the three points using lines of different colours based on some other variables. This is shown here
Problem
What I want to do is to have a different scale on the negative x-axis. This would help me in providing positive x_ticks, different axis label and also clear and uncluttered representation of the lines on left side of the image
Question
How to have a different positive x-axis starting from 0 towards negative direction?
Have xticks based on data plotted in that direction
Have a separate xlabel for this new axis
Additional information
I have checked other questions regarding inclusion of multiple axes e.g. this and this. However, these questions did not serve the purpose.
Code Used
font_size = 20
plt.rcParams.update({'font.size': font_size})
fig = plt.figure()
ax = fig.add_subplot(111)
#read my_data from file or create it
for case in my_data:
#Iterating over my_data
if condition1 == True:
local_linestyle = '-'
local_color = 'r'
local_line_alpha = 0.6
elif condition2 == 1:
local_linestyle = '-'
local_color = 'b'
local_line_alpha = 0.6
else:
local_linestyle = '--'
local_color = 'g'
local_line_alpha = 0.6
datapoint = [case[0], case[1], case[2]]
plt.plot(datapoint[0], 0, color=local_color)
plt.plot(-datapoint[2], 0, color=local_color)
plt.plot(0, datapoint[1], color=local_color)
plt.plot([datapoint[0], 0], [0, datapoint[1]], linestyle=local_linestyle, color=local_color)
plt.plot([-datapoint[2], 0], [0, datapoint[1]], linestyle=local_linestyle, color=local_color)
plt.show()
exit()
You can define a custom scale, where values below zero are scaled differently than those above zero.
import numpy as np
from matplotlib import scale as mscale
from matplotlib import transforms as mtransforms
from matplotlib.ticker import FuncFormatter
class AsymScale(mscale.ScaleBase):
name = 'asym'
def __init__(self, axis, **kwargs):
mscale.ScaleBase.__init__(self)
self.a = kwargs.get("a", 1)
def get_transform(self):
return self.AsymTrans(self.a)
def set_default_locators_and_formatters(self, axis):
# possibly, set a different locator and formatter here.
fmt = lambda x,pos: "{}".format(np.abs(x))
axis.set_major_formatter(FuncFormatter(fmt))
class AsymTrans(mtransforms.Transform):
input_dims = 1
output_dims = 1
is_separable = True
def __init__(self, a):
mtransforms.Transform.__init__(self)
self.a = a
def transform_non_affine(self, x):
return (x >= 0)*x + (x < 0)*x*self.a
def inverted(self):
return AsymScale.InvertedAsymTrans(self.a)
class InvertedAsymTrans(AsymTrans):
def transform_non_affine(self, x):
return (x >= 0)*x + (x < 0)*x/self.a
def inverted(self):
return AsymScale.AsymTrans(self.a)
Using this you would provide a scale parameter a that scales the negative part of the axes.
# Now that the Scale class has been defined, it must be registered so
# that ``matplotlib`` can find it.
mscale.register_scale(AsymScale)
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot([-2, 0, 5], [0,1,0])
ax.set_xscale("asym", a=2)
ax.annotate("negative axis", xy=(.25,0), xytext=(0,-30),
xycoords = "axes fraction", textcoords="offset points", ha="center")
ax.annotate("positive axis", xy=(.75,0), xytext=(0,-30),
xycoords = "axes fraction", textcoords="offset points", ha="center")
plt.show()
The question is not very clear about what xticks and labels are desired, so I left that out for now.
Here's how to get what you want. This solution uses two twined axes object to get different scaling to the left and right of the origin, and then hides all the evidence:
import matplotlib.pyplot as plt
import matplotlib as mpl
from numbers import Number
tickkwargs = {m+k:False for k in ('bottom','top','left','right') for m in ('','label')}
p = np.zeros((10, 3, 2))
p[:,0,0] -= np.arange(10)*.1 + .5
p[:,1,1] += np.repeat(np.arange(5), 2)*.1 + .3
p[:,2,0] += np.arange(10)*.5 + 2
fig = plt.figure(figsize=(8,6))
host = fig.add_subplot(111)
par = host.twiny()
host.set_xlim(-6, 6)
par.set_xlim(-1, 1)
for ps in p:
# mask the points with negative x values
ppos = ps[ps[:,0] >= 0].T
host.plot(*ppos)
# mask the points with positive x values
pneg = ps[ps[:,0] <= 0].T
par.plot(*pneg)
# hide all possible ticks/notation text that could be set by the second x axis
par.tick_params(axis="both", **tickkwargs)
par.xaxis.get_offset_text().set_visible(False)
# fix the x tick labels so they're all positive
host.set_xticklabels(np.abs(host.get_xticks()))
fig.show()
Output:
Here's what the set of points p I used in the code above look like when plotted normally:
fig = plt.figure(figsize=(8,6))
ax = fig.gca()
for ps in p:
ax.plot(*ps.T)
fig.show()
Output:
The method of deriving a class of mscale.ScaleBase as shown in other answers may be too complicated for your purpose.
You can pass two scale transform functions to set_xscale or set_yscale, something like the following.
def get_scale(a=1): # a is the scale of your negative axis
def forward(x):
x = (x >= 0) * x + (x < 0) * x * a
return x
def inverse(x):
x = (x >= 0) * x + (x < 0) * x / a
return x
return forward, inverse
fig, ax = plt.subplots()
forward, inverse = get_scale(a=3)
ax.set_xscale('function', functions=(forward, inverse)) # this is for setting x axis
# do plotting
More examples can be found in this doc.

grouped bar chart with broken axis in matplotlib [duplicate]

I'm trying to create a plot using pyplot that has a discontinuous x-axis. The usual way this is drawn is that the axis will have something like this:
(values)----//----(later values)
where the // indicates that you're skipping everything between (values) and (later values).
I haven't been able to find any examples of this, so I'm wondering if it's even possible. I know you can join data over a discontinuity for, eg, financial data, but I'd like to make the jump in the axis more explicit. At the moment I'm just using subplots but I'd really like to have everything end up on the same graph in the end.
Paul's answer is a perfectly fine method of doing this.
However, if you don't want to make a custom transform, you can just use two subplots to create the same effect.
Rather than put together an example from scratch, there's an excellent example of this written by Paul Ivanov in the matplotlib examples (It's only in the current git tip, as it was only committed a few months ago. It's not on the webpage yet.).
This is just a simple modification of this example to have a discontinuous x-axis instead of the y-axis. (Which is why I'm making this post a CW)
Basically, you just do something like this:
import matplotlib.pylab as plt
import numpy as np
# If you're not familiar with np.r_, don't worry too much about this. It's just
# a series with points from 0 to 1 spaced at 0.1, and 9 to 10 with the same spacing.
x = np.r_[0:1:0.1, 9:10:0.1]
y = np.sin(x)
fig,(ax,ax2) = plt.subplots(1, 2, sharey=True)
# plot the same data on both axes
ax.plot(x, y, 'bo')
ax2.plot(x, y, 'bo')
# zoom-in / limit the view to different portions of the data
ax.set_xlim(0,1) # most of the data
ax2.set_xlim(9,10) # outliers only
# hide the spines between ax and ax2
ax.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)
ax.yaxis.tick_left()
ax.tick_params(labeltop='off') # don't put tick labels at the top
ax2.yaxis.tick_right()
# Make the spacing between the two axes a bit smaller
plt.subplots_adjust(wspace=0.15)
plt.show()
To add the broken axis lines // effect, we can do this (again, modified from Paul Ivanov's example):
import matplotlib.pylab as plt
import numpy as np
# If you're not familiar with np.r_, don't worry too much about this. It's just
# a series with points from 0 to 1 spaced at 0.1, and 9 to 10 with the same spacing.
x = np.r_[0:1:0.1, 9:10:0.1]
y = np.sin(x)
fig,(ax,ax2) = plt.subplots(1, 2, sharey=True)
# plot the same data on both axes
ax.plot(x, y, 'bo')
ax2.plot(x, y, 'bo')
# zoom-in / limit the view to different portions of the data
ax.set_xlim(0,1) # most of the data
ax2.set_xlim(9,10) # outliers only
# hide the spines between ax and ax2
ax.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)
ax.yaxis.tick_left()
ax.tick_params(labeltop='off') # don't put tick labels at the top
ax2.yaxis.tick_right()
# Make the spacing between the two axes a bit smaller
plt.subplots_adjust(wspace=0.15)
# This looks pretty good, and was fairly painless, but you can get that
# cut-out diagonal lines look with just a bit more work. The important
# thing to know here is that in axes coordinates, which are always
# between 0-1, spine endpoints are at these locations (0,0), (0,1),
# (1,0), and (1,1). Thus, we just need to put the diagonals in the
# appropriate corners of each of our axes, and so long as we use the
# right transform and disable clipping.
d = .015 # how big to make the diagonal lines in axes coordinates
# arguments to pass plot, just so we don't keep repeating them
kwargs = dict(transform=ax.transAxes, color='k', clip_on=False)
ax.plot((1-d,1+d),(-d,+d), **kwargs) # top-left diagonal
ax.plot((1-d,1+d),(1-d,1+d), **kwargs) # bottom-left diagonal
kwargs.update(transform=ax2.transAxes) # switch to the bottom axes
ax2.plot((-d,d),(-d,+d), **kwargs) # top-right diagonal
ax2.plot((-d,d),(1-d,1+d), **kwargs) # bottom-right diagonal
# What's cool about this is that now if we vary the distance between
# ax and ax2 via f.subplots_adjust(hspace=...) or plt.subplot_tool(),
# the diagonal lines will move accordingly, and stay right at the tips
# of the spines they are 'breaking'
plt.show()
I see many suggestions for this feature but no indication that it's been implemented. Here is a workable solution for the time-being. It applies a step-function transform to the x-axis. It's a lot of code, but it's fairly simple since most of it is boilerplate custom scale stuff. I have not added any graphics to indicate the location of the break, since that is a matter of style. Good luck finishing the job.
from matplotlib import pyplot as plt
from matplotlib import scale as mscale
from matplotlib import transforms as mtransforms
import numpy as np
def CustomScaleFactory(l, u):
class CustomScale(mscale.ScaleBase):
name = 'custom'
def __init__(self, axis, **kwargs):
mscale.ScaleBase.__init__(self)
self.thresh = None #thresh
def get_transform(self):
return self.CustomTransform(self.thresh)
def set_default_locators_and_formatters(self, axis):
pass
class CustomTransform(mtransforms.Transform):
input_dims = 1
output_dims = 1
is_separable = True
lower = l
upper = u
def __init__(self, thresh):
mtransforms.Transform.__init__(self)
self.thresh = thresh
def transform(self, a):
aa = a.copy()
aa[a>self.lower] = a[a>self.lower]-(self.upper-self.lower)
aa[(a>self.lower)&(a<self.upper)] = self.lower
return aa
def inverted(self):
return CustomScale.InvertedCustomTransform(self.thresh)
class InvertedCustomTransform(mtransforms.Transform):
input_dims = 1
output_dims = 1
is_separable = True
lower = l
upper = u
def __init__(self, thresh):
mtransforms.Transform.__init__(self)
self.thresh = thresh
def transform(self, a):
aa = a.copy()
aa[a>self.lower] = a[a>self.lower]+(self.upper-self.lower)
return aa
def inverted(self):
return CustomScale.CustomTransform(self.thresh)
return CustomScale
mscale.register_scale(CustomScaleFactory(1.12, 8.88))
x = np.concatenate((np.linspace(0,1,10), np.linspace(9,10,10)))
xticks = np.concatenate((np.linspace(0,1,6), np.linspace(9,10,6)))
y = np.sin(x)
plt.plot(x, y, '.')
ax = plt.gca()
ax.set_xscale('custom')
ax.set_xticks(xticks)
plt.show()
Check the brokenaxes package:
import matplotlib.pyplot as plt
from brokenaxes import brokenaxes
import numpy as np
fig = plt.figure(figsize=(5,2))
bax = brokenaxes(
xlims=((0, .1), (.4, .7)),
ylims=((-1, .7), (.79, 1)),
hspace=.05
)
x = np.linspace(0, 1, 100)
bax.plot(x, np.sin(10 * x), label='sin')
bax.plot(x, np.cos(10 * x), label='cos')
bax.legend(loc=3)
bax.set_xlabel('time')
bax.set_ylabel('value')
A very simple hack is to
scatter plot rectangles over the axes' spines and
draw the "//" as text at that position.
Worked like a charm for me:
# FAKE BROKEN AXES
# plot a white rectangle on the x-axis-spine to "break" it
xpos = 10 # x position of the "break"
ypos = plt.gca().get_ylim()[0] # y position of the "break"
plt.scatter(xpos, ypos, color='white', marker='s', s=80, clip_on=False, zorder=100)
# draw "//" on the same place as text
plt.text(xpos, ymin-0.125, r'//', fontsize=label_size, zorder=101, horizontalalignment='center', verticalalignment='center')
Example Plot:
For those interested, I've expanded upon #Paul's answer and added it to the matplotlib wrapper proplot. It can do axis "jumps", "speedups", and "slowdowns".
There is no way currently to add "crosses" that indicate the discrete jump like in Joe's answer, but I plan to add this in the future. I also plan to add a default "tick locator" that sets sensible default tick locations depending on the CutoffScale arguments.
Adressing Frederick Nord's question how to enable parallel orientation of the diagonal "breaking" lines when using a gridspec with ratios unequal 1:1, the following changes based on the proposals of Paul Ivanov and Joe Kingtons may be helpful. Width ratio can be varied using variables n and m.
import matplotlib.pylab as plt
import numpy as np
import matplotlib.gridspec as gridspec
x = np.r_[0:1:0.1, 9:10:0.1]
y = np.sin(x)
n = 5; m = 1;
gs = gridspec.GridSpec(1,2, width_ratios = [n,m])
plt.figure(figsize=(10,8))
ax = plt.subplot(gs[0,0])
ax2 = plt.subplot(gs[0,1], sharey = ax)
plt.setp(ax2.get_yticklabels(), visible=False)
plt.subplots_adjust(wspace = 0.1)
ax.plot(x, y, 'bo')
ax2.plot(x, y, 'bo')
ax.set_xlim(0,1)
ax2.set_xlim(10,8)
# hide the spines between ax and ax2
ax.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)
ax.yaxis.tick_left()
ax.tick_params(labeltop='off') # don't put tick labels at the top
ax2.yaxis.tick_right()
d = .015 # how big to make the diagonal lines in axes coordinates
# arguments to pass plot, just so we don't keep repeating them
kwargs = dict(transform=ax.transAxes, color='k', clip_on=False)
on = (n+m)/n; om = (n+m)/m;
ax.plot((1-d*on,1+d*on),(-d,d), **kwargs) # bottom-left diagonal
ax.plot((1-d*on,1+d*on),(1-d,1+d), **kwargs) # top-left diagonal
kwargs.update(transform=ax2.transAxes) # switch to the bottom axes
ax2.plot((-d*om,d*om),(-d,d), **kwargs) # bottom-right diagonal
ax2.plot((-d*om,d*om),(1-d,1+d), **kwargs) # top-right diagonal
plt.show()
This is a hacky but pretty solution for x-axis breaks.
The solution is based on https://matplotlib.org/stable/gallery/subplots_axes_and_figures/broken_axis.html, which gets rid of the problem with positioning the break above the spine, solved by How can I plot points so they appear over top of the spines with matplotlib?
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
def axis_break(axis, xpos=[0.1, 0.125], slant=1.5):
d = slant # proportion of vertical to horizontal extent of the slanted line
anchor = (xpos[0], -1)
w = xpos[1] - xpos[0]
h = 1
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, zorder=3,
linestyle="none", color='k', mec='k', mew=1, clip_on=False)
axis.add_patch(Rectangle(
anchor, w, h, fill=True, color="white",
transform=axis.transAxes, clip_on=False, zorder=3)
)
axis.plot(xpos, [0, 0], transform=axis.transAxes, **kwargs)
fig, ax = plt.subplots(1,1)
plt.plot(np.arange(10))
axis_break(ax, xpos=[0.1, 0.12], slant=1.5)
axis_break(ax, xpos=[0.3, 0.31], slant=-10)
if you want to replace an axis label, this would do the trick:
from matplotlib import ticker
def replace_pos_with_label(fig, pos, label, axis):
fig.canvas.draw() # this is needed to set up the x-ticks
labs = axis.get_xticklabels()
labels = []
locs = []
for text in labs:
x = text._x
lab = text._text
if x == pos:
lab = label
labels.append(lab)
locs.append(x)
axis.xaxis.set_major_locator(ticker.FixedLocator(locs))
axis.set_xticklabels(labels)
fig, ax = plt.subplots(1,1)
plt.plot(np.arange(10))
replace_pos_with_label(fig, 0, "-10", axis=ax)
replace_pos_with_label(fig, 6, "$10^{4}$", axis=ax)
axis_break(ax, xpos=[0.1, 0.12], slant=2)

Annotating ranges of data in matplotlib

How can I annotate a range of my data? E.g., say the data from x = 5 to x = 10 is larger than some cut-off, how could I indicate that on the graph. If I was annotating by hand, I would just draw a large bracket above the range and write my annotation above the bracket.
The closest I've seen is using arrowstyle='<->' and connectionstyle='bar', to make two arrows pointing to the edges of your data with a line connecting their tails. But that doesn't quite do the right thing; the text that you enter for the annotation will end up under one of the arrows, rather than above the bar.
Here is my attempt, along with it's results:
annotate(' ', xy=(1,.5), xycoords='data',
xytext=(190, .5), textcoords='data',
arrowprops=dict(arrowstyle="<->",
connectionstyle="bar",
ec="k",
shrinkA=5, shrinkB=5,
)
)
Another problem with my attempted solution is that the squared shape of the annotating bracket does not really make it clear that I am highlighting a range (unlike, e.g., a curly brace). But I suppose that's just being nitpicky at this point.
As mentioned in this answer, you can construct curly brackets with sigmoidal functions. Below is a function that adds curly brackets just above the x-axis. The curly brackets it produces should look the same regardless of the axes limits, as long as the figure width and height don't vary.
import numpy as np
import matplotlib.pyplot as plt
def draw_brace(ax, xspan, text):
"""Draws an annotated brace on the axes."""
xmin, xmax = xspan
xspan = xmax - xmin
ax_xmin, ax_xmax = ax.get_xlim()
xax_span = ax_xmax - ax_xmin
ymin, ymax = ax.get_ylim()
yspan = ymax - ymin
resolution = int(xspan/xax_span*100)*2+1 # guaranteed uneven
beta = 300./xax_span # the higher this is, the smaller the radius
x = np.linspace(xmin, xmax, resolution)
x_half = x[:resolution//2+1]
y_half_brace = (1/(1.+np.exp(-beta*(x_half-x_half[0])))
+ 1/(1.+np.exp(-beta*(x_half-x_half[-1]))))
y = np.concatenate((y_half_brace, y_half_brace[-2::-1]))
y = ymin + (.05*y - .01)*yspan # adjust vertical position
ax.autoscale(False)
ax.plot(x, y, color='black', lw=1)
ax.text((xmax+xmin)/2., ymin+.07*yspan, text, ha='center', va='bottom')
ax = plt.gca()
ax.plot(range(10))
draw_brace(ax, (0, 8), 'large brace')
draw_brace(ax, (8, 9), 'small brace')
Output:
I modified Joooeey's answer to allow to change the vertical position of braces:
def draw_brace(ax, xspan, yy, text):
"""Draws an annotated brace on the axes."""
xmin, xmax = xspan
xspan = xmax - xmin
ax_xmin, ax_xmax = ax.get_xlim()
xax_span = ax_xmax - ax_xmin
ymin, ymax = ax.get_ylim()
yspan = ymax - ymin
resolution = int(xspan/xax_span*100)*2+1 # guaranteed uneven
beta = 300./xax_span # the higher this is, the smaller the radius
x = np.linspace(xmin, xmax, resolution)
x_half = x[:int(resolution/2)+1]
y_half_brace = (1/(1.+np.exp(-beta*(x_half-x_half[0])))
+ 1/(1.+np.exp(-beta*(x_half-x_half[-1]))))
y = np.concatenate((y_half_brace, y_half_brace[-2::-1]))
y = yy + (.05*y - .01)*yspan # adjust vertical position
ax.autoscale(False)
ax.plot(x, y, color='black', lw=1)
ax.text((xmax+xmin)/2., yy+.07*yspan, text, ha='center', va='bottom')
ax = plt.gca()
ax.plot(range(10))
draw_brace(ax, (0, 8), -0.5, 'large brace')
draw_brace(ax, (8, 9), 3, 'small brace')
Output:
Also note that in Joooeey's answer, line
x_half = x[:resolution/2+1]
should be
x_half = x[:int(resolution/2)+1]
Otherwise, the number that the script tries to use as index here is a float.
Finally, note that right now the brace will not show up if you move it out of bounds. You need to add parameter clip_on=False, like this:
ax.plot(x, y, color='black', lw=1, clip_on=False)
You can just wrap it all up in a function:
def add_range_annotation(ax, start, end, txt_str, y_height=.5, txt_kwargs=None, arrow_kwargs=None):
"""
Adds horizontal arrow annotation with text in the middle
Parameters
----------
ax : matplotlib.Axes
The axes to draw to
start : float
start of line
end : float
end of line
txt_str : string
The text to add
y_height : float
The height of the line
txt_kwargs : dict or None
Extra kwargs to pass to the text
arrow_kwargs : dict or None
Extra kwargs to pass to the annotate
Returns
-------
tuple
(annotation, text)
"""
if txt_kwargs is None:
txt_kwargs = {}
if arrow_kwargs is None:
# default to your arrowprops
arrow_kwargs = {'arrowprops':dict(arrowstyle="<->",
connectionstyle="bar",
ec="k",
shrinkA=5, shrinkB=5,
)}
trans = ax.get_xaxis_transform()
ann = ax.annotate('', xy=(start, y_height),
xytext=(end, y_height),
transform=trans,
**arrow_kwargs)
txt = ax.text((start + end) / 2,
y_height + .05,
txt_str,
**txt_kwargs)
if plt.isinteractive():
plt.draw()
return ann, txt
Alternately,
start, end = .6, .8
ax.axvspan(start, end, alpha=.2, color='r')
trans = ax.get_xaxis_transform()
ax.text((start + end) / 2, .5, 'test', transform=trans)
Here is a minor modification to guzey and jooeey's answer to plot the flower braces outside the axes.
def draw_brace(ax, xspan, yy, text):
"""Draws an annotated brace outside the axes."""
xmin, xmax = xspan
xspan = xmax - xmin
ax_xmin, ax_xmax = ax.get_xlim()
xax_span = ax_xmax - ax_xmin
ymin, ymax = ax.get_ylim()
yspan = ymax - ymin
resolution = int(xspan/xax_span*100)*2+1 # guaranteed uneven
beta = 300./xax_span # the higher this is, the smaller the radius
x = np.linspace(xmin, xmax, resolution)
x_half = x[:int(resolution/2)+1]
y_half_brace = (1/(1.+np.exp(-beta*(x_half-x_half[0])))
+ 1/(1.+np.exp(-beta*(x_half-x_half[-1]))))
y = np.concatenate((y_half_brace, y_half_brace[-2::-1]))
y = yy + (.05*y - .01)*yspan # adjust vertical position
ax.autoscale(False)
ax.plot(x, -y, color='black', lw=1, clip_on=False)
ax.text((xmax+xmin)/2., -yy-.17*yspan, text, ha='center', va='bottom')
# Sample code
fmax = 1
fstart = -100
fend = 0
frise = 50
ffall = 20
def S(x):
if x<=0:
return 0
elif x>=1:
return 1
else:
return 1/(1+np.exp((1/(x-1))+(1/x)))
x = np.linspace(700,1000,500)
lam = [fmax*(S((i-880)/60)-S(((i-1000)/25)+1)) for i in x]
fig = plt.figure(1)
ax = fig.add_subplot(111)
plt.plot(x,lam)
plt.xlim([850,1000])
ax.set_aspect(50,adjustable='box')
plt.ylabel('$\lambda$')
plt.xlabel('$x$')
ax.xaxis.set_label_coords(0.5, -0.35)
draw_brace(ax, (900,950),0.2, 'rise')
draw_brace(ax, (980,1000),0.2, 'fall')
plt.text(822,0.95,'$(\lambda_{\mathrm{max}})$')
Sample output
a minor modification of the draw_brace of #Joooeey and #guezy to have also the brace upside down
+argument upsidedown
def draw_brace(ax, xspan, yy, text, upsidedown=False):
"""Draws an annotated brace on the axes."""
# shamelessly copied from https://stackoverflow.com/questions/18386210/annotating-ranges-of-data-in-matplotlib
xmin, xmax = xspan
xspan = xmax - xmin
ax_xmin, ax_xmax = ax.get_xlim()
xax_span = ax_xmax - ax_xmin
ymin, ymax = ax.get_ylim()
yspan = ymax - ymin
resolution = int(xspan/xax_span*100)*2+1 # guaranteed uneven
beta = 300./xax_span # the higher this is, the smaller the radius
x = np.linspace(xmin, xmax, resolution)
x_half = x[:int(resolution/2)+1]
y_half_brace = (1/(1.+np.exp(-beta*(x_half-x_half[0])))
+ 1/(1.+np.exp(-beta*(x_half-x_half[-1]))))
if upsidedown:
y = np.concatenate((y_half_brace[-2::-1], y_half_brace))
else:
y = np.concatenate((y_half_brace, y_half_brace[-2::-1]))
y = yy + (.05*y - .01)*yspan # adjust vertical position
ax.autoscale(False)
line = ax.plot(x, y, color='black', lw=1)
if upsidedown:
text = ax.text((xmax+xmin)/2., yy+-.07*yspan, text, ha='center', va='bottom',fontsize=7)
else:
text = ax.text((xmax+xmin)/2., yy+.07*yspan, text, ha='center', va='bottom',fontsize=7)
return line, text
I updated the previous answers to have some of the features I wanted, like an option for a vertical brace, that I wanted to place in multi-plot figures. One still has to futz with the beta_scale parameter sometimes depending on the scale of the data that one is applying this to.
def rotate_point(x, y, angle_rad):
cos,sin = np.cos(angle_rad),np.sin(angle_rad)
return cos*x-sin*y,sin*x+cos*y
def draw_brace(ax, span, position, text, text_pos, brace_scale=1.0, beta_scale=300., rotate=False, rotate_text=False):
'''
all positions and sizes are in axes units
span: size of the curl
position: placement of the tip of the curl
text: label to place somewhere
text_pos: position for the label
beta_scale: scaling for the curl, higher makes a smaller radius
rotate: true rotates to place the curl vertically
rotate_text: true rotates the text vertically
'''
# get the total width to help scale the figure
ax_xmin, ax_xmax = ax.get_xlim()
xax_span = ax_xmax - ax_xmin
resolution = int(span/xax_span*100)*2+1 # guaranteed uneven
beta = beta_scale/xax_span # the higher this is, the smaller the radius
# center the shape at (0, 0)
x = np.linspace(-span/2., span/2., resolution)
# calculate the shape
x_half = x[:int(resolution/2)+1]
y_half_brace = (1/(1.+np.exp(-beta*(x_half-x_half[0])))
+ 1/(1.+np.exp(-beta*(x_half-x_half[-1]))))
y = np.concatenate((y_half_brace, y_half_brace[-2::-1]))
# put the tip of the curl at (0, 0)
max_y = np.max(y)
min_y = np.min(y)
y /= (max_y-min_y)
y *= brace_scale
y -= max_y
# rotate the trace before shifting
if rotate:
x,y = rotate_point(x, y, np.pi/2)
# shift to the user's spot
x += position[0]
y += position[1]
ax.autoscale(False)
ax.plot(x, y, color='black', lw=1, clip_on=False)
# put the text
ax.text(text_pos[0], text_pos[1], text, ha='center', va='bottom', rotation=90 if rotate_text else 0)