Discrete Color Bar with Tick labels in between colors - matplotlib

I am trying to plot some data with a discrete color bar. I was following this example (https://gist.github.com/jakevdp/91077b0cae40f8f8244a), but it does not carry over one-to-one to data with a different spacing: the values in the linked example increase in steps of 1, whereas my data increases in steps of 0.5. You can see the output from the code I have below. Any help with this would be appreciated. I know I am missing something key here but can't figure it out.
import matplotlib.pylab as plt
import numpy as np
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    N is the number of discrete colors. base_cmap may be a colormap name,
    a colormap instance, or None (matplotlib's default colormap).
    Returns a LinearSegmentedColormap named ``<base name><N>`` built from
    N colors sampled evenly across the base colormap.
    """
    # Local import keeps this function independent of extra file-level imports.
    from matplotlib.colors import LinearSegmentedColormap

    # plt.get_cmap accepts a string, None, or a colormap instance;
    # the older plt.cm.get_cmap has been removed from recent matplotlib.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # from_list is defined on LinearSegmentedColormap only, so calling it
    # through `base` fails when the base map is a ListedColormap
    # (e.g. 'viridis'). Call it on the class instead.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)
# Demo: 40 random points, each tagged with a random integer in [0, num).
# NOTE(review): `num` is not defined anywhere in this snippet; it must be set
# to the desired number of discrete colors before these lines run.
x = np.random.randn(40)
y = np.random.randn(40)
c = np.random.randint(num, size=40)
plt.scatter(x, y, c=c, s=50, cmap=discrete_cmap(num, 'jet'))
# Color limits run from -0.5 to num - 0.5, i.e. half a unit beyond the
# smallest and largest possible integer values in `c`.
plt.clim(-0.5, num - 0.5)

Not sure what version of matplotlib/pyplot introduced this, but plt.get_cmap now supports an int argument specifying the number of colors you want to get, for discrete colormaps.
This automatically results in the colorbar being discrete.
By the way, pandas has an even better handling of the colorbar.
import numpy as np
from matplotlib import pyplot as plt
# remove if not using Jupyter/IPython
%matplotlib inline
# Problem size: number of clusters and number of points per cluster
n_clusters = 5
n_samples = 20
# Cluster label of every point: n_samples zeros, then n_samples ones, ...
clusters = np.repeat(np.arange(n_clusters), n_samples)
# Center coordinates: cluster k is centered at (k, k)
clusters_x = np.arange(n_clusters)
clusters_y = np.arange(n_clusters)
# Lookup tables mapping cluster label -> center coordinate
x_dict = {label: center for label, center in enumerate(clusters_x)}
y_dict = {label: center for label, center in enumerate(clusters_y)}
# Start every point at the center of its own cluster ...
x = np.array([x_dict[k] for k in clusters], dtype=float)
y = np.array([y_dict[k] for k in clusters], dtype=float)
# ... then scatter it around that center with Gaussian noise
x += np.random.normal(scale=0.5, size=n_clusters * n_samples)
y += np.random.normal(scale=0.5, size=n_clusters * n_samples)
### Plotting
fig, ax = plt.subplots(figsize=(12, 8))
# Colormap resampled to exactly one color per cluster
cmap = plt.get_cmap('viridis', n_clusters)
# Points, colored by cluster label
scatter = ax.scatter(x, y, c=clusters, cmap=cmap)
# Cluster centers, drawn in red
ax.scatter(clusters_x, clusters_y, c='red')
# Colorbar for the cluster labels
cbar = plt.colorbar(scatter)
# Tick locations: shift each index by 0.5, then rescale by
# (n_clusters - 1) / n_clusters so the last tick falls at the center of
# the last color band (same arithmetic as the original snippet)
tick_locs = (np.arange(n_clusters) + 0.5) * (n_clusters - 1) / n_clusters
# set tick labels (as before)

Ok so this is the hack I found for my own question. I am sure there is a better way to do this but this works for what I am doing. Feel free to suggest a better way to do this.
import numpy as np
import matplotlib.pylab as plt
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    N is the number of discrete colors. base_cmap may be a colormap name,
    a colormap instance, or None (matplotlib's default colormap).
    Returns a LinearSegmentedColormap named ``<base name><N>`` built from
    N colors sampled evenly across the base colormap.
    """
    # Local import keeps this function independent of extra file-level imports.
    from matplotlib.colors import LinearSegmentedColormap

    # plt.get_cmap accepts a string, None, or a colormap instance;
    # the older plt.cm.get_cmap has been removed from recent matplotlib.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # from_list is defined on LinearSegmentedColormap only, so calling it
    # through `base` fails when the base map is a ListedColormap
    # (e.g. 'viridis'). Call it on the class instead.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)
# Demo: 40 random points, each tagged with a random integer in [0, num).
# NOTE(review): `num` is not defined anywhere in this snippet; it must be set
# to the desired number of discrete colors before these lines run.
x = np.random.randn(40)
y = np.random.randn(40)
c = np.random.randint(num, size=40)
plt.scatter(x, y, c=c, s=50, cmap=discrete_cmap(num, 'jet'))
# Color limits run from -0.5 to num - 0.5, i.e. half a unit beyond the
# smallest and largest possible integer values in `c`.
plt.clim(-0.5, num - 0.5)
For some reason I cannot upload the image associated with the code above. I get an error when uploading so not sure how to show the final example. But simply I set the color bar axes for tick labels for a vertical color bar and passed in the labels I want and it produced the correct output.


How to plot a 3D function with colors given spacing 2D input

Let's assume I have 3 arrays defined as:
Then I have a function that takes those 3 values and gives me the desired output, let's assume it is like:
f = (v1 + v2*10)/v3
I want to plot that function on a 3D plot with axes v1, v2, v3 and color its surface depending on its value.
Beyond the best way to plot it, I am also interested in how to iterate over all the values in the input vectors and build the function point by point.
I have been trying with for loops inside other for loops but I am always getting one error.
I tried this but i'm always getting a line instead of a surface
import mpl_toolkits.mplot3d.axes3d as axes3d
import sympy
from sympy import symbols, Function
# Parameters I use in the function
L = 132
alpha = 45*math.pi/180
beta = 0
s,t = symbols('s,t')
z = Function('z')(s,t)
figure = plt.figure(figsize=(8,8))
ax = figure.add_subplot(1, 1, 1, projection='3d')
# experiment with various range of data in x and y
x1 = np.linspace(-40,-40,100)
y1 = np.linspace(-40,40,100)
x,y = np.meshgrid(x1,y1)
# My function Z
den = math.sqrt((c1*c2)+s1)
In this example I'm going to show you how to achieve your goal. Specifically, I use Numpy because it supports vectorized operations, hence I avoid for loops.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import matplotlib.cm as cm
# Parameters I use in the function
L = 132
alpha = 45*np.pi/180
beta = 0
figure = plt.figure()
ax = figure.add_subplot(1, 1, 1, projection='3d')
# experiment with various range of data in x and y
x1 = np.linspace(-40,40,100)
y1 = np.linspace(-40,40,100)
x,y = np.meshgrid(x1,y1)
# My function Z
den = np.sqrt((c1*c2)+s1)
# compute the color values according to some other function
color_values = np.sqrt(x**2 + y**2 + z**2)
# normalize color values between 0 and 1
norm = Normalize(vmin=color_values.min(), vmax=color_values.max())
norm_color_values = norm(color_values)
# chose a colormap and create colors starting from the normalized values
cmap = cm.rainbow
colors = cmap(norm_color_values)
surf = ax.plot_surface(x,y,z,facecolors=colors)
# add a colorbar
figure.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), label="radius")

how to add one more plot in matplotlib script

My matplotlib script plots a file "band.hdf5", which is in hdf5 format, with
f = h5py.File('band.hdf5', 'r')
I want to add one more hdf5 file, "band-new.hdf5", in such a way that the output figure has a second plot on the right-hand side for the new file. The Y-axis label should be omitted for "band-new.hdf5", and the X-axis label should be common to both files.
The header of the script is this
import h5py
import matplotlib.pyplot as plt
import warnings
import matplotlib
This script is taken from the accepted answer
Is this the solution you needed?
I take the code from
and adapted it to draw two plots side-to-side from the data you shared.
import h5py
import matplotlib.pyplot as plt
import warnings
import matplotlib
warnings.filterwarnings("ignore") # Ignore all warnings
cmap = matplotlib.cm.get_cmap('jet', 4)
params = {
'mathtext.default': 'regular',
'axes.linewidth': 1.2,
'axes.edgecolor': 'Black',
'font.family' : 'serif'
#get the viridis cmap with a resolution of 3
#apply a scale to the y axis. I'm just picking an arbritrary number here
scale = 10
offset = 0 #set this to a non-zero value if you want to have your lines offset in a waterfall style effect
f_left = h5py.File('band-222.hdf5', 'r')
f_right = h5py.File('band-332.hdf5', 'r')
print ('datasets from left are:')
print ('datasets from right are:')
fig = plt.figure(figsize=(16,8))
ax1 = fig.add_subplot(121)
lbl = {0:'AB', 1:'BC', 2:'CD', 3:'fourth'}
for i, section in enumerate(dist):
for nbnd, _ in enumerate(freq[i][0]):
x = section # to_list() you may need to convert sample to list.
y = (freq[i, :, nbnd] + offset*nbnd) * scale
if (nbnd<3):
ax1.plot(x, y, c=color, lw=2.0, alpha=0.8, label = lbl[nbnd] if nbnd < 3 and i == 0 else None)
# Labels and axis limit and ticks
ax1.set_ylabel(r'Frequency (THz)', fontsize=12)
ax1.set_xlabel(r'Wave Vector (q)', fontsize=12)
xticks=[dist[i][0] for i in range(len(dist))]
# Plot grid
ax1.grid(which='major', axis='x', c='green', lw=2.5, linestyle='--', alpha=0.8)
ax2 = fig.add_subplot(122)
lbl = {0:'AB', 1:'BC', 2:'CD', 3:'fourth'}
for i, section in enumerate(dist):
for nbnd, _ in enumerate(freq[i][0]):
x = section # to_list() you may need to convert sample to list.
y = (freq[i, :, nbnd] + offset*nbnd) * scale
if (nbnd<3):
ax2.plot(x, y, c=color, lw=2.0, alpha=0.8, label = lbl[nbnd] if nbnd < 3 and i == 0 else None)
# remove y axis
ax2.set_xlabel(r'Wave Vector (q)', fontsize=12)
xticks=[dist[i][0] for i in range(len(dist))]
# Plot grid
ax2.grid(which='major', axis='x', c='green', lw=2.5, linestyle='--', alpha=0.8)
fig.tight_layout() # Or equivalently, "plt.tight_layout()"
# Save to pdf
plt.savefig('plots.pdf', bbox_inches='tight')
The final figure is like this.

Scatterplot with marginal KDE plots and multiple categories in Matplotlib

I'd like a function in Matplotlib similar to the Matlab 'scatterhist' function which takes continuous values for 'x' and 'y' axes, plus a categorical variable as input; and produces a scatter plot with marginal KDE plots and two or more categorical variables in different colours as output:
I've found examples of scatter plots with marginal histograms in Matplotlib, marginal histograms in Seaborn jointplot, overlapping histograms in Matplotlib, and marginal KDE plots in Matplotlib; but I haven't found any examples which combine scatter plots with marginal KDE plots and are colour coded to indicate different categories.
If possible, I'd like a solution which uses 'vanilla' Matplotlib without Seaborn, as this will avoid dependencies and allow complete control and customisation of the plot appearance using standard Matplotlib commands.
I was going to try to write something based on the above examples; but before doing so wanted to check whether a similar function was already available, and if not then would be grateful for any guidance on the best approach to use.
@ImportanceOfBeingEarnest: Many thanks for your help.
Here's my first attempt at a solution.
It's a bit hacky but achieves my objectives, and is fully customisable using standard matplotlib commands. I'm posting the code here with annotations in case anyone else wishes to use it or develop it further. If there are any improvements or neater ways of writing the code I'm always keen to learn and would be grateful for guidance.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from scipy import stats
# Category names and the colour used to draw each category
label = ['Setosa', 'Versicolor', 'Virginica']
cl = ['b', 'r', 'y']
categories = len(label)
sample_size = 20  # points generated per category

# Dummy data: one row per category, one column per sample
x = np.zeros(shape=(categories, sample_size))
y = np.zeros(shape=(categories, sample_size))
for n in range(categories):
    # Shift each category so the clouds are offset from one another
    x[n, :] = np.random.randn(sample_size) + 4 + n
    y[n, :] = np.random.randn(sample_size) + 6 - n

# 2x2 grid: wide/tall scatter cell top-right, narrow marginal cells beside
# and below it, and a small corner cell for the legend
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 3], height_ratios=[3, 1])
# Leave room between the cells for axis labels
gs.update(hspace=0.3, wspace=0.3)

fig = plt.figure()

# Main scatter axes, limited to the full data range
ax = plt.subplot(gs[0, 1])
ax.set_xlim(x.min(), x.max())
ax.set_ylim(y.min(), y.max())

# Left marginal axes (KDE of y), sharing the y axis with the scatter plot
axl = plt.subplot(gs[0, 0], sharey=ax)
axl.get_xaxis().set_visible(False)  # hide its x axis

# Bottom marginal axes (KDE of x), sharing the x axis with the scatter plot
axb = plt.subplot(gs[1, 1], sharex=ax)
axb.get_xaxis().set_visible(False)  # hide its x axis

# Corner axes used only to hold the legend
axc = plt.subplot(gs[1, 0])
axc.axis('off')

# Evaluation grids for the density curves (the same for every category)
xx = np.linspace(x.min(), x.max(), 1000)
yy = np.linspace(y.min(), y.max(), 1000)

# Draw each category: hollow scatter markers plus its two marginal KDEs
for n in range(categories):
    colour = cl[n]
    ax.scatter(x[n], y[n], color='none', label=label[n], s=100,
               edgecolor=colour)
    kde_x = stats.gaussian_kde(x[n, :])
    axb.plot(xx, kde_x(xx), color=colour)
    kde_y = stats.gaussian_kde(y[n, :])
    axl.plot(kde_y(yy), yy, color=colour)

# Re-use the scatter plot's legend entries in the corner axes;
# scatterpoints=1 shows a single marker per entry
handles, labels = ax.get_legend_handles_labels()
axc.legend(handles, labels, scatterpoints=1, loc='center', fontsize=12)
Version 2, using Pandas to import 'real' data from a csv file, with a different number of entries in each category. (csv file format: row 0 = headers; col 0 = x values, col 1 = y values, col 2 = category labels). Scatterplot axis and legend labels are generated from column headers.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from scipy import stats
import pandas as pd
Create scatter plot with marginal KDE plots
from csv file with 3 cols of data
formatted as following example (first row of
data are headers):
'x_label', 'y_label', 'category_label'
4,6, 'virginica' etc...
# Load the data. Layout described for this script: first row = headers,
# column 0 = x values, column 1 = y values, column 2 = category labels.
df = pd.read_csv('iris_2.csv')  # file name, relative to the current working directory
cl = ['b', 'r', 'y', 'g', 'm', 'k']  # colours for the categories (cycled if there are more categories)
headers = list(df.columns)  # column headers, used for axis and legend titles

# Overall x / y range. Columns are selected by position with .iloc:
# DataFrame.ix no longer exists in current pandas, and integer indexing into
# the label-indexed result of df.min(axis=0) is not reliable positional access.
x_all = df.iloc[:, 0]
y_all = df.iloc[:, 1]
xmin, xmax = x_all.min(), x_all.max()
ymin, ymax = y_all.min(), y_all.max()

# All distinct categories in the third column, in order of first appearance
category_list = df.iloc[:, 2].unique()

# 2x2 grid: scatter plot top-right, marginal KDEs left and below, legend in the corner
gs = gridspec.GridSpec(2, 2, width_ratios=[1, 3], height_ratios=[3, 1])
# Add space between scatter plot and KDE plots to accommodate axis labels
gs.update(hspace=0.3, wspace=0.3)

fig = plt.figure()

# Scatter plot area, limited to the full data range and labelled from the headers
ax = plt.subplot(gs[0, 1])
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
ax.set_xlabel(headers[0], fontsize=14)
ax.set_ylabel(headers[1], fontsize=14)
ax.yaxis.labelpad = 10  # space between the y axis and its label

# Left KDE plot area (density of y), sharing the y axis with the scatter plot
axl = plt.subplot(gs[0, 0], sharey=ax)
axl.get_xaxis().set_visible(False)  # hide its x axis

# Bottom KDE plot area (density of x), sharing the x axis with the scatter plot
axb = plt.subplot(gs[1, 1], sharex=ax)
axb.get_xaxis().set_visible(False)  # hide its x axis

# Legend area
axc = plt.subplot(gs[1, 0])
axc.axis('off')

# Evaluation grids for the density curves (identical for every category)
xx = np.linspace(xmin, xmax, 1000)
yy = np.linspace(ymin, ymax, 1000)

for n, category in enumerate(category_list):
    # Sub-table containing only the rows of the current category
    st = df.loc[df[headers[2]] == category]
    # First two columns of the sub-table are the x and y values to plot
    x = st[headers[0]]
    y = st[headers[1]]
    # Wrap around the colour list instead of failing with IndexError when
    # there are more categories than colours
    colour = cl[n % len(cl)]
    # Hollow scatter markers plus the two marginal KDE curves
    ax.scatter(x, y, color='none', s=100, edgecolor=colour, label=category)
    kde = stats.gaussian_kde(x)
    axb.plot(xx, kde(xx), color=colour)
    kde = stats.gaussian_kde(y)
    axl.plot(kde(yy), yy, color=colour)

# Copy the legend entries from the scatter plot to the lower-left subplot;
# scatterpoints=1 shows only one marker per label
handles, labels = ax.get_legend_handles_labels()
axc.legend(handles, labels, title=headers[2], scatterpoints=1, loc='center', fontsize=12)

grouped bar chart with broken axis in matplotlib [duplicate]

I'm trying to create a plot using pyplot that has a discontinuous x-axis. The usual way this is drawn is that the axis will have something like this:
(values)----//----(later values)
where the // indicates that you're skipping everything between (values) and (later values).
I haven't been able to find any examples of this, so I'm wondering if it's even possible. I know you can join data over a discontinuity for, eg, financial data, but I'd like to make the jump in the axis more explicit. At the moment I'm just using subplots but I'd really like to have everything end up on the same graph in the end.
Paul's answer is a perfectly fine method of doing this.
However, if you don't want to make a custom transform, you can just use two subplots to create the same effect.
Rather than put together an example from scratch, there's an excellent example of this written by Paul Ivanov in the matplotlib examples (It's only in the current git tip, as it was only committed a few months ago. It's not on the webpage yet.).
This is just a simple modification of this example to have a discontinuous x-axis instead of the y-axis. (Which is why I'm making this post a CW)
Basically, you just do something like this:
import matplotlib.pylab as plt
import numpy as np
# If you're not familiar with np.r_, don't worry too much about this. It's just
# a series with points from 0 to 1 spaced at 0.1, and 9 to 10 with the same spacing.
# Sample data in two separate x ranges: 0 to 1 and 9 to 10, spaced 0.1 apart
x = np.r_[0:1:0.1, 9:10:0.1]
y = np.sin(x)

# Two side-by-side axes sharing the y axis: one per x range
fig, (ax, ax2) = plt.subplots(1, 2, sharey=True)

# plot the same data on both axes
ax.plot(x, y, 'bo')
ax2.plot(x, y, 'bo')

# limit each view to its own portion of the data
ax.set_xlim(0, 1)    # left axes: the low x range
ax2.set_xlim(9, 10)  # right axes: the high x range

# hide the spines between ax and ax2, and keep y ticks on the outer sides only
ax.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)
ax.yaxis.tick_left()
# labeltop must be a bool: the string 'off' is truthy and would not
# suppress the labels in current matplotlib
ax.tick_params(labeltop=False)  # don't put tick labels at the top
ax2.yaxis.tick_right()

# Make the spacing between the two axes a bit smaller
plt.subplots_adjust(wspace=0.15)
To add the broken axis lines // effect, we can do this (again, modified from Paul Ivanov's example):
import matplotlib.pylab as plt
import numpy as np
# If you're not familiar with np.r_, don't worry too much about this. It's just
# a series with points from 0 to 1 spaced at 0.1, and 9 to 10 with the same spacing.
# Sample data in two separate x ranges: 0 to 1 and 9 to 10, spaced 0.1 apart
x = np.r_[0:1:0.1, 9:10:0.1]
y = np.sin(x)

# Two side-by-side axes sharing the y axis: one per x range
fig, (ax, ax2) = plt.subplots(1, 2, sharey=True)

# plot the same data on both axes
ax.plot(x, y, 'bo')
ax2.plot(x, y, 'bo')

# limit each view to its own portion of the data
ax.set_xlim(0, 1)    # left axes: the low x range
ax2.set_xlim(9, 10)  # right axes: the high x range

# hide the spines between ax and ax2, and keep y ticks on the outer sides only
ax.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)
ax.yaxis.tick_left()
# labeltop must be a bool: the string 'off' is truthy and would not
# suppress the labels in current matplotlib
ax.tick_params(labeltop=False)  # don't put tick labels at the top
ax2.yaxis.tick_right()

# Make the spacing between the two axes a bit smaller
plt.subplots_adjust(wspace=0.15)

# This looks pretty good, and was fairly painless, but you can get that
# cut-out diagonal lines look with just a bit more work. The important
# thing to know here is that in axes coordinates, which are always
# between 0-1, spine endpoints are at these locations (0,0), (0,1),
# (1,0), and (1,1). Thus, we just need to put the diagonals in the
# appropriate corners of each of our axes, use the right transform,
# and disable clipping.
d = .015  # how big to make the diagonal lines in axes coordinates
# arguments to pass to plot, just so we don't keep repeating them
kwargs = dict(transform=ax.transAxes, color='k', clip_on=False)
ax.plot((1 - d, 1 + d), (-d, +d), **kwargs)        # left axes, bottom-right corner
ax.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)  # left axes, top-right corner

kwargs.update(transform=ax2.transAxes)  # switch to the right-hand axes
ax2.plot((-d, d), (-d, +d), **kwargs)        # right axes, bottom-left corner
ax2.plot((-d, d), (1 - d, 1 + d), **kwargs)  # right axes, top-left corner

# What's cool about this is that now if we vary the distance between
# ax and ax2 via fig.subplots_adjust(wspace=...) or plt.subplot_tool(),
# the diagonal lines will move accordingly, and stay right at the tips
# of the spines they are 'breaking'
I see many suggestions for this feature but no indication that it's been implemented. Here is a workable solution for the time-being. It applies a step-function transform to the x-axis. It's a lot of code, but it's fairly simple since most of it is boilerplate custom scale stuff. I have not added any graphics to indicate the location of the break, since that is a matter of style. Good luck finishing the job.
from matplotlib import pyplot as plt
from matplotlib import scale as mscale
from matplotlib import transforms as mtransforms
import numpy as np
def CustomScaleFactory(l, u):
class CustomScale(mscale.ScaleBase):
name = 'custom'
def __init__(self, axis, **kwargs):
self.thresh = None #thresh
def get_transform(self):
return self.CustomTransform(self.thresh)
def set_default_locators_and_formatters(self, axis):
class CustomTransform(mtransforms.Transform):
input_dims = 1
output_dims = 1
is_separable = True
lower = l
upper = u
def __init__(self, thresh):
self.thresh = thresh
def transform(self, a):
aa = a.copy()
aa[a>self.lower] = a[a>self.lower]-(self.upper-self.lower)
aa[(a>self.lower)&(a<self.upper)] = self.lower
return aa
def inverted(self):
return CustomScale.InvertedCustomTransform(self.thresh)
class InvertedCustomTransform(mtransforms.Transform):
input_dims = 1
output_dims = 1
is_separable = True
lower = l
upper = u
def __init__(self, thresh):
self.thresh = thresh
def transform(self, a):
aa = a.copy()
aa[a>self.lower] = a[a>self.lower]+(self.upper-self.lower)
return aa
def inverted(self):
return CustomScale.CustomTransform(self.thresh)
return CustomScale
mscale.register_scale(CustomScaleFactory(1.12, 8.88))
x = np.concatenate((np.linspace(0,1,10), np.linspace(9,10,10)))
xticks = np.concatenate((np.linspace(0,1,6), np.linspace(9,10,6)))
y = np.sin(x)
plt.plot(x, y, '.')
ax = plt.gca()
Check the brokenaxes package:
import matplotlib.pyplot as plt
from brokenaxes import brokenaxes
import numpy as np
fig = plt.figure(figsize=(5,2))
bax = brokenaxes(
xlims=((0, .1), (.4, .7)),
ylims=((-1, .7), (.79, 1)),
x = np.linspace(0, 1, 100)
bax.plot(x, np.sin(10 * x), label='sin')
bax.plot(x, np.cos(10 * x), label='cos')
A very simple hack is to
scatter plot rectangles over the axes' spines and
draw the "//" as text at that position.
Worked like a charm for me:
# plot a white rectangle on the x-axis-spine to "break" it
# (a white square marker drawn at the lower y-limit, unclipped and with a
# high zorder)
xpos = 10 # x position of the "break"
ypos = plt.gca().get_ylim()[0] # y position of the "break"
plt.scatter(xpos, ypos, color='white', marker='s', s=80, clip_on=False, zorder=100)
# draw "//" on the same place as text
# NOTE(review): `ymin` and `label_size` are not defined in this snippet;
# presumably `ymin` is the lower y-limit (the value stored in `ypos`) and
# `label_size` is the tick-label font size -- confirm before use.
plt.text(xpos, ymin-0.125, r'//', fontsize=label_size, zorder=101, horizontalalignment='center', verticalalignment='center')
Example Plot:
For those interested, I've expanded upon @Paul's answer and added it to the matplotlib wrapper proplot. It can do axis "jumps", "speedups", and "slowdowns".
There is no way currently to add "crosses" that indicate the discrete jump like in Joe's answer, but I plan to add this in the future. I also plan to add a default "tick locator" that sets sensible default tick locations depending on the CutoffScale arguments.
Addressing Frederick Nord's question about how to keep the diagonal "breaking" lines parallel when using a gridspec with width ratios other than 1:1, the following changes, based on the proposals of Paul Ivanov and Joe Kington, may be helpful. The width ratio can be varied using the variables n and m.
import matplotlib.pylab as plt
import numpy as np
import matplotlib.gridspec as gridspec
# Sample data in two separate x ranges: 0 to 1 and 9 to 10, spaced 0.1 apart
x = np.r_[0:1:0.1, 9:10:0.1]
y = np.sin(x)

# Width ratio of the left (n) and right (m) axes
n = 5
m = 1
gs = gridspec.GridSpec(1, 2, width_ratios=[n, m])
ax = plt.subplot(gs[0, 0])
ax2 = plt.subplot(gs[0, 1], sharey=ax)
plt.setp(ax2.get_yticklabels(), visible=False)
plt.subplots_adjust(wspace=0.1)

# plot the same data on both axes
ax.plot(x, y, 'bo')
ax2.plot(x, y, 'bo')

# limit each view to its own portion of the data; without this both axes
# would show the full x range and there would be no break to draw
ax.set_xlim(0, 1)
ax2.set_xlim(9, 10)

# hide the spines between ax and ax2
ax.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)
ax2.tick_params(axis='y', which='both', left=False)  # no ticks on the hidden spine
# labeltop must be a bool: the string 'off' is truthy and would not
# suppress the labels in current matplotlib
ax.tick_params(labeltop=False)  # don't put tick labels at the top

d = .015  # how big to make the diagonal lines in axes coordinates
# arguments to pass to plot, just so we don't keep repeating them
kwargs = dict(transform=ax.transAxes, color='k', clip_on=False)

# Axes coordinates are relative to each axes' own width, so the horizontal
# extent is scaled by (total width / own width) to give the diagonals the
# same slope on both axes
on = (n + m) / n
om = (n + m) / m
ax.plot((1 - d * on, 1 + d * on), (-d, d), **kwargs)        # left axes, bottom diagonal
ax.plot((1 - d * on, 1 + d * on), (1 - d, 1 + d), **kwargs)  # left axes, top diagonal
kwargs.update(transform=ax2.transAxes)  # switch to the right-hand axes
ax2.plot((-d * om, d * om), (-d, d), **kwargs)        # right axes, bottom diagonal
ax2.plot((-d * om, d * om), (1 - d, 1 + d), **kwargs)  # right axes, top diagonal
This is a hacky but pretty solution for x-axis breaks.
The solution is based on https://matplotlib.org/stable/gallery/subplots_axes_and_figures/broken_axis.html. The problem of positioning the break marker on top of the spine is solved as described in "How can I plot points so they appear over top of the spines with matplotlib?"
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
def axis_break(axis, xpos=(0.1, 0.125), slant=1.5):
    """Draw an x-axis break marker on *axis*.

    axis  -- the matplotlib axes to draw on.
    xpos  -- (start, end) of the break along the x axis, in axes coordinates.
    slant -- proportion of vertical to horizontal extent of the slanted
             marker lines; a negative value flips the slant direction.
    """
    d = slant
    # White rectangle of height 1 anchored one axes-height below the bottom
    # edge, spanning xpos[0]..xpos[1] horizontally
    anchor = (xpos[0], -1)
    w = xpos[1] - xpos[0]
    h = 1

    # Custom two-point marker: a short slanted stroke through the marker position
    kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, zorder=3,
                  linestyle="none", color='k', mec='k', mew=1, clip_on=False)
    # The rectangle must be added to the axes as a patch; it is drawn
    # unclipped so it can extend outside the axes area
    axis.add_patch(Rectangle(
        anchor, w, h, fill=True, color="white",
        transform=axis.transAxes, clip_on=False, zorder=3))
    # One slanted marker at each end of the break, on the bottom edge (y = 0)
    axis.plot(xpos, [0, 0], transform=axis.transAxes, **kwargs)
fig, ax = plt.subplots(1,1)
axis_break(ax, xpos=[0.1, 0.12], slant=1.5)
axis_break(ax, xpos=[0.3, 0.31], slant=-10)
if you want to replace an axis label, this would do the trick:
from matplotlib import ticker
def replace_pos_with_label(fig, pos, label, axis):
fig.canvas.draw() # this is needed to set up the x-ticks
labs = axis.get_xticklabels()
labels = []
locs = []
for text in labs:
x = text._x
lab = text._text
if x == pos:
lab = label
fig, ax = plt.subplots(1,1)
replace_pos_with_label(fig, 0, "-10", axis=ax)
replace_pos_with_label(fig, 6, "$10^{4}$", axis=ax)
axis_break(ax, xpos=[0.1, 0.12], slant=2)

How can I plot the label on the line of a lineplot?

I would like to plot labels on a line of a lineplot in matplotlib.
Minimal example
#!/usr/bin/env python
import numpy as np
import seaborn as sns
sns.set_palette(sns.color_palette("Greens", 8))
from scipy.ndimage.filters import gaussian_filter1d
for i in range(8):
# Create data
y = np.roll(np.cumsum(np.random.randn(1000, 1)),
np.random.randint(0, 1000))
y = gaussian_filter1d(y, 10)
sns.plt.plot(y, label=str(i))
instead, I would prefer something like
Maybe a bit hacky, but does this solve your problem?
#!/usr/bin/env python
import numpy as np
import seaborn as sns
sns.set_palette(sns.color_palette("Greens", 8))
from scipy.ndimage.filters import gaussian_filter1d
for i in range(8):
# Create data
y = np.roll(np.cumsum(np.random.randn(1000, 1)),
np.random.randint(0, 1000))
y = gaussian_filter1d(y, 10)
p = sns.plt.plot(y, label=str(i))
color = p[0].get_color()
for x in [250, 500, 750]:
y2 = y[x]
sns.plt.plot(x, y2, 'o', color='white', markersize=9)
sns.plt.plot(x, y2, 'k', marker="$%s$" % str(i), color=color,
Here's the result I get:
Edit: I gave it a little more thought and came up with a solution that automatically tries to find the best possible position for the labels in order to avoid the labels being positioned at x-values where two lines are very close to each other (which could e.g. lead to overlap between the labels):
#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_palette(sns.color_palette("Greens", 8))
from scipy.ndimage.filters import gaussian_filter1d
# -----------------------------------------------------------------------------
def inline_legend(lines, n_markers=1):
Take a list containing the lines of a plot (typically the result of
calling plt.gca().get_lines()), and add the labels for those lines on the
lines themselves; more precisely, put each label n_marker times on the
[Source of problem: https://stackoverflow.com/q/43573623/4100721]
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from math import fabs
def chunkify(a, n):
    """
    Split list a into n approximately equally sized chunks and return the
    indices (start/end) of those chunks.

    Returns a list of n (start, end) tuples with half-open index ranges;
    when len(a) is not divisible by n, the first len(a) % n chunks are one
    element longer than the rest.

    [Idea: Props to http://stackoverflow.com/a/2135920/4100721 :)]
    """
    # k = base chunk length, m = number of chunks that get one extra element
    k, m = divmod(len(a), n)
    return [(i * k + min(i, m), (i + 1) * k + min(i + 1, m))
            for i in range(n)]
# Calculate linear interpolations of every line. This is necessary to
# compare the values of the lines if they use different x-values
interpolations = [interp1d(_.get_xdata(), _.get_ydata())
for _ in lines]
# Loop over all lines
for idx, line in enumerate(lines):
# Get basic properties of the current line
label = line.get_label()
color = line.get_color()
x_values = line.get_xdata()
y_values = line.get_ydata()
# Get all lines that are not the current line, as well as the
# functions that are linear interpolations of them
other_lines = lines[0:idx] + lines[idx+1:]
other_functions = interpolations[0:idx] + interpolations[idx+1:]
# Split the x-values in chunks to get regions in which to put
# labels. Creating 3 times as many chunks as requested and using only
# every third ensures that no two labels for the same line are too
# close to each other.
chunks = list(chunkify(line.get_xdata(), 3*n_markers))[::3]
# For each chunk, find the optimal position of the label
for chunk_nr in range(n_markers):
# Start and end index of the current chunk
chunk_start = chunks[chunk_nr][0]
chunk_end = chunks[chunk_nr][1]
# For the given chunk, loop over all x-values of the current line,
# evaluate the value of every other line at every such x-value,
# and store the result.
other_values = [[fabs(y_values[int(x)] - f(x)) for x in
for f in other_functions]
# Now loop over these values and find the minimum, i.e. for every
# x-value in the current chunk, find the distance to the closest
# other line ("closest" meaning abs_value(value(current line at x)
# - value(other lines at x)) being at its minimum)
distances = [min([_ for _ in [row[i] for row in other_values]])
for i in range(len(other_values[0]))]
# Now find the value of x in the current chunk where the distance
# is maximal, i.e. the best position for the label and add the
# necessary offset to take into account that the index obtained
# from "distances" is relative to the current chunk
best_pos = distances.index(max(distances)) + chunks[chunk_nr][0]
# Short notation for the position of the label
x = best_pos
y = y_values[x]
# Actually plot the label onto the line at the calculated position
plt.plot(x, y, 'o', color='white', markersize=9)
plt.plot(x, y, 'k', marker="$%s$" % label, color=color,
# -----------------------------------------------------------------------------
# Draw 8 smoothed random-walk lines, then label them directly on the lines
for i in range(8):
    # Create data: a random walk of 1000 steps, rolled by a random offset
    y = np.roll(np.cumsum(np.random.randn(1000, 1)),
                np.random.randint(0, 1000))
    # Smooth the walk with a Gaussian filter (sigma = 10)
    y = gaussian_filter1d(y, 10)
    # Use pyplot directly: the `sns.plt` alias is no longer provided by seaborn
    plt.plot(y, label=str(i))
inline_legend(plt.gca().get_lines(), n_markers=3)
Example output of this solution (note how the x-positions of the labels are no longer all the same):
If one wants to avoid the use of scipy.interpolate.interp1d, one might consider a solution where for a given x-value of line A, one finds the x-value of line B that is closest to that. I think this might be problematic though if the lines use very different and/or sparse grids?