Grid of histograms according to filtered data - matplotlib

Consider this kind of data file:
data-file.txt
75,15,1,57.5,9.9,5
75,15,1,58.1,10.0,5
75,15,2,37.9,8.3,5
75,15,2,18.2,7.3,5
150,15,1,26.4,8.3,10
150,15,1,31.6,7.9,10
150,15,2,30.6,7.5,10
150,15,2,25.1,7.1,10
Observe that 3rd column values are only 1,2.
I would like to produce a 3x2 grid of histograms. The subplots below look right, but each row should contain 2 histograms from different data sets; that is, I split the data into groups according to the last column.
The important code is `ax.hist(X[(y==grp) & (X[:,2]==1), cols])`, where the filtering happens.
I want 2 histograms on each row:
the 1st row with (X[:,2]== * ), where * is any value of the 3rd column (1 or 2),
the 2nd row with (X[:,2]==1) and
the 3rd row with (X[:,2]==2).
In summary, I expect the 2nd and 3rd rows to show histograms for the following filtered data:
3rd column value = 1
75,15,1,57.5,9.9,5
75,15,1,58.1,10.0,5
150,15,1,26.4,8.3,10
150,15,1,31.6,7.9,10
3rd column value = 2
75,15,2,37.9,8.3,5
75,15,2,18.2,7.3,5
150,15,2,30.6,7.5,10
150,15,2,25.1,7.1,10
Code:
#!/usr/bin/python
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import math
from matplotlib import pyplot as plt
from itertools import combinations
data_file = 'data-file.txt'
# The data file has no header line; header=None keeps the first record as
# data instead of silently turning it into column names.
df = pd.read_csv(data_file, header=None)
N = df.shape[1]
# 1-based column number -> column name.
feature_dict = {i + 1: label for i, label in
                zip(range(N), ('L', 'A', 'G', 'P', 'T', 'PP'))}
df.columns = [label for _, label in sorted(feature_dict.items())]
# All columns but the last are features; select them by position
# (df[range(...)] looks columns up by label and fails once they are named).
X = df.iloc[:, :N - 1].values
y = df['PP'].values
# 1-based group id -> distinct value of the 'PP' column.
label_dict = {idx + 1: value for idx, value in enumerate(sorted(set(y)))}
# Every unordered pair of group ids.
grps_to_hist_list = [list(pair) for pair in combinations(label_dict, 2)]
cols_to_hist = [3, 4]  # 0-based positions of the 'P' and 'T' features
for grps_to_hist in grps_to_hist_list:
    grps_str = [label_dict[grps_to_hist[0]], label_dict[grps_to_hist[1]]]
    print('creating histogram for groups %s from data file %s'
          % (grps_str, data_file))
    fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(18, 8))
    for ax, cols in zip(axes.ravel(), cols_to_hist):
        # 39 equal-width bins spanning the whole column (all groups).
        min_b = math.floor(np.min(X[:, cols]))
        max_b = math.ceil(np.max(X[:, cols]))
        bins = np.linspace(min_b, max_b, 40)
        # Overlaid, semi-transparent histogram of each group, restricted to
        # rows whose 3rd column ('G') equals 1.
        for grp, color in zip(grps_str, ('blue', 'red')):
            ax.hist(X[(y == grp) & (X[:, 2] == 1), cols],
                    color=color,
                    label='%s' % grp,
                    bins=bins,
                    alpha=0.3)
        ylims = ax.get_ylim()
        # plot annotation
        leg = ax.legend(loc='upper right', fancybox=True, fontsize=8)
        leg.get_frame().set_alpha(0.5)
        ax.set_ylim([0, max(ylims) + 2])
        ax.set_xlabel(feature_dict[cols + 1])
        ax.set_title('%s' % str(data_file))
        # Hide tick marks but keep tick labels; current Matplotlib expects
        # booleans here, not "on"/"off" strings.
        ax.tick_params(axis="both", which="both", bottom=False, top=False,
                       labelbottom=True, left=False, right=False,
                       labelleft=True)
        # remove axis spines
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(False)
    fig.tight_layout()
    plt.show()
Here is a screenshot produced by the code above with the filter `(y==grp) & (X[:,2]==1)` (these histograms should appear on the 2nd row).

My logic is to iterate over rows with corresponding masks of your choice, [(X[:,2]==1) | (X[:,2]==2), X[:,2]==1, X[:,2]==2]. Hopefully this is what you want:
#!/usr/bin/python
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import math
from matplotlib import pyplot as plt
from itertools import combinations
data_file = 'data-file.txt'
# The data file has no header line; header=None keeps the first record as
# data instead of silently turning it into column names.
df = pd.read_csv(data_file, header=None)
N = df.shape[1]
# 1-based column number -> column name.
feature_dict = {i + 1: label for i, label in
                zip(range(N), ('L', 'A', 'G', 'P', 'T', 'PP'))}
df.columns = [label for _, label in sorted(feature_dict.items())]
# All columns but the last are features; select them by position
# (df[range(...)] looks columns up by label and fails once they are named).
X = df.iloc[:, :N - 1].values
y = df['PP'].values
# 1-based group id -> distinct value of the 'PP' column.
label_dict = {idx + 1: value for idx, value in enumerate(sorted(set(y)))}
# Every unordered pair of group ids.
grps_to_hist_list = [list(pair) for pair in combinations(label_dict, 2)]
cols_to_hist = [3, 4]  # 0-based positions of the 'P' and 'T' features
# One subplot row per filter on the 3rd column ('G'): 1 or 2, only 1, only 2.
row_masks = [(X[:, 2] == 1) | (X[:, 2] == 2), X[:, 2] == 1, X[:, 2] == 2]
for grps_to_hist in grps_to_hist_list:
    grps_str = [label_dict[grps_to_hist[0]], label_dict[grps_to_hist[1]]]
    print('creating histogram for groups %s from data file %s'
          % (grps_str, data_file))
    fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(18, 8))
    for row_ax, row_mask in zip(axes, row_masks):
        for ax, cols in zip(row_ax, cols_to_hist):
            # Same 39 equal-width bins in every row, spanning the column.
            min_b = math.floor(np.min(X[:, cols]))
            max_b = math.ceil(np.max(X[:, cols]))
            bins = np.linspace(min_b, max_b, 40)
            # Overlaid, semi-transparent histogram of each group within
            # this row's subset.
            for grp, color in zip(grps_str, ('blue', 'red')):
                ax.hist(X[(y == grp) & row_mask, cols],
                        color=color,
                        label='%s' % grp,
                        bins=bins,
                        alpha=0.3)
            ylims = ax.get_ylim()
            # plot annotation
            leg = ax.legend(loc='upper right', fancybox=True, fontsize=8)
            leg.get_frame().set_alpha(0.5)
            ax.set_ylim([0, max(ylims) + 2])
            ax.set_xlabel(feature_dict[cols + 1])
            ax.set_title('%s' % str(data_file))
            # Hide tick marks but keep tick labels; current Matplotlib
            # expects booleans here, not "on"/"off" strings.
            ax.tick_params(axis="both", which="both", bottom=False,
                           top=False, labelbottom=True, left=False,
                           right=False, labelleft=True)
            # remove axis spines
            for side in ("top", "right", "bottom", "left"):
                ax.spines[side].set_visible(False)
    fig.tight_layout()
    plt.show()

Related

Showing Matplotlib pie chart only top 3 item's percentage [duplicate]

I have the following code:
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(123456)
import pandas as pd
df = pd.DataFrame(3 * np.random.rand(4, 4), index=['a', 'b', 'c', 'd'],
                  columns=['x', 'y', 'z', 'w'])
plt.style.use('ggplot')
# 'axes.color_cycle' was removed in Matplotlib 2.0; the active style's
# colours now live in the 'axes.prop_cycle' cycler.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
fig, axes = plt.subplots(nrows=2, ncols=3)
# Hide every frame first; the 2x3 grid has more cells than there are columns.
for ax in axes.flat:
    ax.axis('off')
# One pie per column, wedges labelled by the index, percentages with 2 decimals.
for ax, col in zip(axes.flat, df.columns):
    ax.pie(df[col], labels=df.index, autopct='%.2f', colors=colors)
    ax.set(ylabel='', title=col, aspect='equal')
axes[0, 0].legend(bbox_to_anchor=(0, 0.5))
fig.savefig('your_file.png')  # Or whichever format you'd like
plt.show()
Which produce the following:
My question is: how can I remove a label based on a condition? For example, I'd only like to display labels whose percentage is greater than 20%, so that the labels and values of a, c and d are not displayed in the x pie, and so on.
The autopct argument of pie can be a callable, which receives the current percentage. So you only need to provide a function that returns an empty string for the values whose percentage you want to omit.
Function
def my_autopct(pct):
    """Format a pie-wedge percentage for ``Axes.pie(autopct=...)``.

    Percentages above 20 are rendered with two decimals; anything else
    yields an empty string, so no text is drawn on that wedge.
    """
    if pct > 20:
        return '%.2f' % pct
    return ''
Plot with matplotlib.axes.Axes.pie
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 6))
for ax, col in zip(axes.flat, df.columns):
    # One pie per column; my_autopct suppresses the small percentages.
    ax.pie(df[col], autopct=my_autopct, labels=df.index)
    ax.set_ylabel('')
    ax.set_title(col)
    ax.set_aspect('equal')
fig.tight_layout()
Plot directly with the dataframe
axes = df.plot(kind='pie', autopct=my_autopct, figsize=(8, 6),
               subplots=True, layout=(2, 2), legend=False)
# pandas writes each column name as the pie's y-label; move it to the title.
for ax in axes.flat:
    ax.set(ylabel='', title=ax.get_ylabel())
fig = axes[0, 0].get_figure()
fig.tight_layout()
If you need to parametrize the value on the autopct argument, you'll need a function that returns a function, like:
def autopct_generator(limit):
    """Return an ``autopct`` callable that hides percentages up to *limit*.

    The returned function formats a percentage greater than *limit* with
    two decimals and returns an empty string otherwise.
    """
    def _autopct(pct):
        if pct > limit:
            return '%.2f' % pct
        return ''
    return _autopct
# Same pie call with a threshold produced by autopct_generator: wedges at or
# below 20 percent get no percentage text.
ax.pie(df[col], labels=df.index, autopct=autopct_generator(20), colors=colors)
For the labels, the best thing I can come up with is using list comprehension:
for ax, col in zip(axes.flat, df.columns):
    data = df[col]
    # Only wedges larger than 20% of this column's total keep their name.
    # The cut-off is computed once per column, not once per label.
    threshold = data.sum() * 0.2
    labels = [n if v > threshold else ''
              for n, v in zip(df.index, data)]
    ax.pie(data, autopct=my_autopct, colors=colors, labels=labels)
Note, however, that the legend by default is being generated from the first passed labels, so you'll need to pass all values explicitly to keep it intact.
# The legend would otherwise be built from the (partly blank) wedge labels;
# passing df.index explicitly keeps every category in it.
axes[0, 0].legend(df.index, bbox_to_anchor=(0, 0.5))
For labels I have used:
def my_level_list(data, threshold=2):
    """Return pie labels, blanking out the small wedges.

    Parameters
    ----------
    data : sequence of numbers
        Wedge sizes, in the order they are passed to ``pie``. A pandas
        Series is read positionally, whatever its index.
    threshold : float, optional
        Percentage of the total a wedge must exceed to keep its label
        (default 2, i.e. 2%).

    Returns
    -------
    list of str
        ``'Label <k>'`` (1-based position) for wedges above *threshold*
        percent of the total, ``''`` for all others.
    """
    values = np.asarray(data, dtype=float)
    total = values.sum()
    if total == 0:
        # No meaningful percentages (empty or all-zero data): label nothing.
        return [''] * len(values)
    # Sum once, not once per element.
    percentages = values * 100 / total
    return ['Label %d' % (i + 1) if pct > threshold else ''
            for i, pct in enumerate(percentages)]
# Pie with thresholded wedge labels and percentages. Because autopct is
# given, pie returns three sequences: wedges, label texts, percentage texts.
patches, texts, autotexts = plt.pie(data, radius = 1, labels=my_level_list(data), autopct=my_autopct, shadow=True)
You can make the labels function a little shorter using list comprehension:
def my_autopct(pct):
    """Return *pct* with one decimal, or '' when it does not exceed 1 percent."""
    if pct > 1:
        return '%1.1f' % pct
    return ''
def get_new_labels(sizes, labels, threshold=1):
    """Blank out the labels of wedges that are too small to annotate.

    A label is kept only when its wedge is larger than *threshold* percent
    of the total of *sizes*. The default of 1 matches the 1 percent cut-off
    of ``my_autopct``. Comparing percentages rather than raw sizes keeps
    labels and percentages consistent whatever units *sizes* use.

    Returns a list with one entry per (size, label) pair: the label, or ''.
    """
    sizes = list(sizes)
    total = sum(sizes)
    if total == 0:
        # No meaningful shares: hide every label instead of dividing by zero.
        return ['' for _ in zip(sizes, labels)]
    return [label if size * 100 / total > threshold else ''
            for size, label in zip(sizes, labels)]
fig, ax = plt.subplots()
# Small wedges get neither a name (get_new_labels) nor a percentage (my_autopct).
small_hidden = get_new_labels(sizes, labels)
_, _, _ = ax.pie(sizes, labels=small_hidden, colors=colors,
                 autopct=my_autopct, startangle=90, rotatelabels=False)

How to set values of a vertical stem plot as xticks labels?

I would like to plot grouped data side by side in reverse group order, using each group name as an x-tick label. The demo below mostly works, but the label positions are not as expected.
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np


def main():
    """Plot each name group as stems, in reverse name order, on alternating bands."""
    records = [['AAAAAA', 8], ['AAAAAA', 9], ['AAAAAA', 10],
               ['BBBBBB', 5], ['BBBBBB', 6], ['BBBBBB', 7],
               ['CCCCCC', 1], ['CCCCCC', 2], ['CCCCCC', 3], ['CCCCCC', 4]]
    df = pd.DataFrame(records, columns=['name', 'value'])
    fig, ax = plt.subplots(figsize=(8, 4))
    ymin = df['value'].min()
    band_colors = ('#ececec', '#bcbcbc')
    start = 0
    for k, (ix, row) in enumerate(reversed(tuple(df.groupby('name')))):
        print(ix, row)
        n = len(row['name'])
        # n stems spread evenly over [start, start + n]
        ax.stem(np.linspace(start, start + n, n), row['value'])
        # Group name drawn vertically at the band centre, at the data minimum.
        font_dict = {'family': 'serif', 'color': 'darkred', 'size': 8}
        ax.text(start + n / 2, ymin, ix, ha='right', va='top', rotation=90,
                fontdict=font_dict)
        # Bands alternate, starting with the darker shade.
        plt.axvspan(start, start + n, facecolor=band_colors[(k + 1) % 2],
                    alpha=0.5)
        start += len(row)
    ax.xaxis.set_ticks_position('none')
    plt.setp(ax.get_xticklabels(), visible=False)
    ax.grid(axis='y', color='gray', linestyle='dashed', alpha=1)
    ax.spines[["top", "right"]].set_visible(False)
    fig.tight_layout()
    plt.show()


main()
Output:
Comments on a more proper way to do this are welcome. I would especially like to improve the x-tick labels: placing the text at ymin does not seem like a good approach.
If my understanding of what you are trying to achieve is correct, here is one way to do it:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
records = [
    ["AAAAAA", 8],
    ["AAAAAA", 9],
    ["AAAAAA", 10],
    ["BBBBBB", 5],
    ["BBBBBB", 6],
    ["BBBBBB", 7],
    ["CCCCCC", 1],
    ["CCCCCC", 2],
    ["CCCCCC", 3],
    ["CCCCCC", 4],
]
df = pd.DataFrame(records, columns=["name", "value"])
fig, ax = plt.subplots(figsize=(8, 4))
shades = ("#ececec", "#bcbcbc")
ticks = {}  # stem x-position -> tick label (names joined by newlines)
start = 0
for k, (ix, row) in enumerate(reversed(tuple(df.groupby("name")))):
    # Stem plot: n stems spread evenly over [start, start + n].
    n = len(row["name"])
    x = np.linspace(start, start + n, n)
    ax.stem(x, row["value"])
    # Background band, alternating shades starting with the darker one.
    ax.axvspan(start, start + n, facecolor=shades[(k + 1) % 2], alpha=0.5)
    # Neighbouring groups share their boundary position (start + n), so one
    # position can collect several names.
    for key, name in zip(x, row["name"]):
        ticks[key] = f"{ticks[key]}\n{name}" if key in ticks else name
    start += len(row)
# Add ticks and tick labels
ax.set_xticks(ticks=list(ticks))
ax.set_xticklabels(list(ticks.values()), fontsize=10, rotation="vertical")
# In Jupyter notebook
fig
Output:
And to avoid repeating the labels, you can, for instance, do:
# One hand-written label per tick position created above: repeated names are
# blanked and the last name is left-padded with spaces.
tick_labels = [
    "",
    "CCCCCC",
    "",
    "CCCCCC\nBBBBBB",
    "BBBBBB",
    "BBBBBB\nAAAAAA",
    " " * 20 + "AAAAAA",
    "",
]
ax.set_xticklabels(tick_labels, fontsize=10)
# In Jupyter notebook
fig
Output:

Matplotlib - How to show coordinates in scatterplot? [duplicate]

I am using matplotlib to make scatter plots. Each point on the scatter plot is associated with a named object. I would like to be able to see the name of an object when I hover my cursor over the point on the scatter plot associated with that object. In particular, it would be nice to be able to quickly see the names of the points that are outliers. The closest thing I have been able to find while searching here is the annotate command, but that appears to create a fixed label on the plot. Unfortunately, with the number of points that I have, the scatter plot would be unreadable if I labeled each point. Does anyone know of a way to create labels that only appear when the cursor hovers in the vicinity of that point?
It seems none of the other answers here actually answer the question. So here is a code that uses a scatter and shows an annotation upon hovering over the scatter points.
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(1)

x = np.random.rand(15)
y = np.random.rand(15)
names = np.array(list("ABCDEFGHIJKLMNO"))
c = np.random.randint(1, 5, size=15)

norm = plt.Normalize(1, 4)
cmap = plt.cm.RdYlGn

fig, ax = plt.subplots()
sc = plt.scatter(x, y, c=c, s=100, cmap=cmap, norm=norm)

# A single annotation, hidden until a point is hovered, then moved around.
annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->"))
annot.set_visible(False)


def update_annot(ind):
    """Move the annotation to the first hovered point and list all hovered ones.

    The box is tinted with the colour of the first hovered point.
    """
    hovered = ind["ind"]
    first = hovered[0]
    annot.xy = sc.get_offsets()[first]
    indices = " ".join(str(n) for n in hovered)
    point_names = " ".join([names[n] for n in hovered])
    annot.set_text("{}, {}".format(indices, point_names))
    box = annot.get_bbox_patch()
    box.set_facecolor(cmap(norm(c[first])))
    box.set_alpha(0.4)


def hover(event):
    """Show the annotation over a hovered point; hide it when leaving the points."""
    if event.inaxes != ax:
        return
    cont, ind = sc.contains(event)
    if cont:
        update_annot(ind)
        annot.set_visible(True)
        fig.canvas.draw_idle()
    elif annot.get_visible():
        annot.set_visible(False)
        fig.canvas.draw_idle()


fig.canvas.mpl_connect("motion_notify_event", hover)
plt.show()
Because people also want to use this solution for a line plot instead of a scatter, the following would be the same solution for plot (which works slightly differently).
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(1)

x = np.sort(np.random.rand(15))
y = np.sort(np.random.rand(15))
names = np.array(list("ABCDEFGHIJKLMNO"))
# (No norm/cmap here: a Line2D has one colour, so they were dead code.)

fig, ax = plt.subplots()
line, = plt.plot(x, y, marker="o")

# One hidden annotation, moved to whichever vertex is hovered.
annot = ax.annotate("", xy=(0, 0), xytext=(-20, 20), textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->"))
annot.set_visible(False)


def update_annot(ind):
    """Place the annotation on the first hovered vertex and list all hovered ones.

    ``ind["ind"]`` holds the indices of the vertices under the pointer, as
    returned by ``Line2D.contains``.
    """
    xdata, ydata = line.get_data()
    first = ind["ind"][0]
    annot.xy = (xdata[first], ydata[first])
    text = "{}, {}".format(" ".join(map(str, ind["ind"])),
                           " ".join([names[n] for n in ind["ind"]]))
    annot.set_text(text)
    annot.get_bbox_patch().set_alpha(0.4)


def hover(event):
    """Show the annotation over a hovered vertex; hide it when leaving the line."""
    vis = annot.get_visible()
    if event.inaxes == ax:
        cont, ind = line.contains(event)
        if cont:
            update_annot(ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
        elif vis:
            annot.set_visible(False)
            fig.canvas.draw_idle()


fig.canvas.mpl_connect("motion_notify_event", hover)
plt.show()
In case someone is looking for a solution for lines in twin axes, refer to How to make labels appear when hovering over a point in multiple axis?
In case someone is looking for a solution for bar plots, please refer to e.g. this answer.
This solution works when hovering a line without the need to click it:
import matplotlib.pyplot as plt

# Module-level so that the on_plot_hover callback can reach them.
fig = plt.figure()
plot = fig.add_subplot(111)
# Four curves; gid gives each one a unique id to report on hover.
for i in range(4):
    plot.plot([i * 1, i * 2, i * 3, i * 4], gid=i)


def on_plot_hover(event):
    """Print the gid of every curve under the mouse pointer."""
    under_mouse = (curve for curve in plot.get_lines()
                   if curve.contains(event)[0])
    for curve in under_mouse:
        print("over %s" % curve.get_gid())


fig.canvas.mpl_connect('motion_notify_event', on_plot_hover)
plt.show()
From http://matplotlib.sourceforge.net/examples/event_handling/pick_event_demo.html :
from matplotlib.pyplot import figure, show
import numpy as np
from numpy.random import rand

# Picking on a scatter plot (Axes.scatter returns a
# matplotlib.collections.PathCollection).
x, y, c, s = rand(4, 100)


def onpick3(event):
    """Print the indices and coordinates of the picked scatter points."""
    ind = event.ind
    print('onpick3 scatter:', ind, np.take(x, ind), np.take(y, ind))


fig = figure()
ax1 = fig.add_subplot(111)
col = ax1.scatter(x, y, 100 * s, c, picker=True)
fig.canvas.mpl_connect('pick_event', onpick3)
show()
This recipe draws an annotation on picking a data point: http://scipy-cookbook.readthedocs.io/items/Matplotlib_Interactive_Plotting.html .
This recipe draws a tooltip, but it requires wxPython:
Point and line tooltips in matplotlib?
The easiest option is to use the mplcursors package.
mplcursors: read the docs
mplcursors: github
If using Anaconda, install with these instructions, otherwise use these instructions for pip.
This must be plotted in an interactive window, not inline.
For jupyter, executing something like %matplotlib qt in a cell will turn on interactive plotting. See How can I open the interactive matplotlib window in IPython notebook?
Tested in python 3.10, pandas 1.4.2, matplotlib 3.5.1, seaborn 0.11.2
import matplotlib.pyplot as plt
import pandas_datareader as web  # only for test data; must be installed with conda or pip
from mplcursors import cursor  # separate package must be installed

# Reproducible sample data as a pandas DataFrame.
# NOTE(review): this depends on the 'yahoo' source of pandas_datareader being
# reachable; if it fails, substitute any DataFrame with a 'Close' column.
start, end = '2021-03-09', '2022-06-13'
df = web.DataReader('aapl', data_source='yahoo', start=start, end=end)
plt.figure(figsize=(12, 7))
plt.plot(df.index, df.Close)
# Annotate the data point under the mouse while hovering.
cursor(hover=True)
plt.show()
Pandas
# Same hover cursor, on a line drawn through the pandas plotting API.
ax = df.plot(figsize=(10, 7), y='Close')
cursor(hover=True)
plt.show()
Seaborn
Works with axes-level plots like sns.lineplot, and figure-level plots like sns.relplot.
import seaborn as sns

# load sample data (seaborn's bundled "tips" dataset)
tips = sns.load_dataset('tips')
# Figure-level scatter: one panel per value of 'time', coloured by 'day'.
sns.relplot(data=tips, x="total_bill", y="tip", hue="day", col="time")
cursor(hover=True)
plt.show()
The other answers did not address my need for properly showing tooltips in a recent version of Jupyter inline matplotlib figure. This one works though:
import matplotlib.pyplot as plt
import numpy as np
import mplcursors

np.random.seed(42)

fig, ax = plt.subplots()
points = np.random.random((2, 26))  # row 0: x, row 1: y
ax.scatter(points[0], points[1])
ax.set_title("Mouse over a point")

crs = mplcursors.cursor(ax, hover=True)


def _label_selection(sel):
    """Replace the default annotation text with the point's coordinates."""
    sel.annotation.set_text(
        'Point {},{}'.format(sel.target[0], sel.target[1]))


crs.connect("add", _label_selection)
plt.show()
Leading to something like the following picture when going over a point with mouse:
A slight edit on an example provided in http://matplotlib.org/users/shell.html:
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_title('click on points')
# picker=5: a click within 5 points of the line triggers a pick event.
line, = ax.plot(np.random.rand(100), '-', picker=5)


def onpick(event):
    """Print the (x, y) data of every vertex picked on the line."""
    picked_line = event.artist
    xs = picked_line.get_xdata()
    ys = picked_line.get_ydata()
    chosen = event.ind
    print('onpick points:', *zip(xs[chosen], ys[chosen]))


fig.canvas.mpl_connect('pick_event', onpick)
plt.show()
This plots a straight line plot, as Sohaib was asking
mpld3 solved it for me.
EDIT (CODE ADDED):
import matplotlib.pyplot as plt
import numpy as np
import mpld3

# 'axisbg' was removed from Matplotlib (2.2); 'facecolor' sets the
# axes background colour.
fig, ax = plt.subplots(subplot_kw=dict(facecolor='#EEEEEE'))
N = 100
scatter = ax.scatter(np.random.normal(size=N),
                     np.random.normal(size=N),
                     c=np.random.random(size=N),
                     s=1000 * np.random.random(size=N),
                     alpha=0.3,
                     cmap=plt.cm.jet)
ax.grid(color='white', linestyle='solid')
ax.set_title("Scatter Plot (with tooltips!)", size=20)
# One tooltip text per scatter point, in plotting order.
labels = ['point {0}'.format(i + 1) for i in range(N)]
tooltip = mpld3.plugins.PointLabelTooltip(scatter, labels=labels)
mpld3.plugins.connect(fig, tooltip)
mpld3.show()
You can check this example
mplcursors worked for me. mplcursors provides clickable annotations for matplotlib. It is heavily inspired by mpldatacursor (https://github.com/joferkington/mpldatacursor), with a much simplified API.
import matplotlib.pyplot as plt
import numpy as np
import mplcursors

# 10x4 table whose column k is 0..9 scaled by k + 1: four lines.
data = np.outer(range(10), range(1, 5))

fig, ax = plt.subplots()
lines = ax.plot(data)
title = ("Click somewhere on a line.\n"
         "Right-click to deselect.\n"
         "Annotations can be dragged.")
ax.set_title(title)

mplcursors.cursor(lines)  # or just mplcursors.cursor()

plt.show()
showing object information in matplotlib statusbar
Features
no extra libraries needed
clean plot
no overlap of labels and artists
supports multi artist labeling
can handle artists from different plotting calls (like scatter, plot, add_patch)
code in library style
Code
### imports
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# https://stackoverflow.com/a/47166787/7128154
# https://matplotlib.org/3.3.3/api/collections_api.html#matplotlib.collections.PathCollection
# https://matplotlib.org/3.3.3/api/path_api.html#matplotlib.path.Path
# https://stackoverflow.com/questions/15876011/add-information-to-matplotlib-navigation-toolbar-status-bar
# https://stackoverflow.com/questions/36730261/matplotlib-path-contains-point
# https://stackoverflow.com/a/36335048/7128154


class StatusbarHoverManager:
    """
    Manage hover information for mpl.axes.Axes object based on appearing
    artists.

    While the mouse is over ``ax``, the coordinate text shows the cursor
    position followed by the label of every registered artist (or artist
    item) under the cursor.

    Attributes
    ----------
    ax : mpl.axes.Axes
        subplot to show status information
    artists : list of mpl.artist.Artist
        elements on the subplot, which react to mouse over
    labels : list (list of strings) or strings
        each element on the top level corresponds to an artist.
        if the artist has items
        (i.e. second return value of contains() has key 'ind'),
        the element has to be of type list.
        otherwise the element is of type string
    cid : int
        connection id of the motion_notify_event callback
    """

    def __init__(self, ax):
        """Attach to *ax*; raise TypeError if it is not a Matplotlib Axes."""
        # Explicit check instead of assert, which is stripped under -O.
        if not isinstance(ax, mpl.axes.Axes):
            raise TypeError('ax must be a matplotlib Axes, not %s'
                            % type(ax).__name__)
        self.ax = ax
        self.artists = []
        self.labels = []
        # One callback for the manager's lifetime. It reads self.artists and
        # self.labels on every event, so adding artists needs no
        # disconnect/reconnect.
        self.cid = ax.figure.canvas.mpl_connect(
            "motion_notify_event", self._hover)

    def add_artist_labels(self, artist, label):
        """Register *artist* with *label* (a string, or one label per item).

        A one-element list, as returned by ``ax.plot``, is unwrapped; a
        longer list raises ValueError.
        """
        if isinstance(artist, list):
            if len(artist) != 1:
                raise ValueError('expected a single artist, got a list of %d'
                                 % len(artist))
            artist = artist[0]
        self.artists.append(artist)
        self.labels.append(label)

    def _hover(self, event):
        """Rebuild the coordinate text for the current mouse position."""
        if event.inaxes != self.ax:
            return
        info = 'x={:.2f}, y={:.2f}'.format(event.xdata, event.ydata)
        for aa, (artist, label) in enumerate(zip(self.artists, self.labels)):
            cont, dct = artist.contains(event)
            if not cont:
                continue
            inds = dct.get('ind')
            if inds is not None:  # artist contains items
                for ii in inds:
                    info += '; artist [{:d}, {:d}]: {:}'.format(
                        aa, ii, label[ii])
            else:
                info += '; artist [{:d}]: {:}'.format(aa, label)
        self.ax.format_coord = lambda x, y: info


def demo_StatusbarHoverManager():
    """Show a polygon, a scatter and a line, each labelled in the status bar."""
    fig, ax = plt.subplots()
    shm = StatusbarHoverManager(ax)

    poly = mpl.patches.Polygon(
        [[0, 0], [3, 5], [5, 4], [6, 1]], closed=True, color='green', zorder=0)
    artist = ax.add_patch(poly)
    shm.add_artist_labels(artist, 'polygon')

    artist = ax.scatter([2.5, 1, 2, 3], [6, 1, 1, 7], c='blue', s=10**2)
    lbls = ['point ' + str(ii) for ii in range(4)]
    shm.add_artist_labels(artist, lbls)

    artist = ax.plot(
        [0, 0, 1, 5, 3], [0, 1, 1, 0, 2], marker='o', color='red')
    lbls = ['segment ' + str(ii) for ii in range(5)]
    shm.add_artist_labels(artist, lbls)

    plt.show()


# --- main
if __name__ == "__main__":
    demo_StatusbarHoverManager()
I have made a multi-line annotation system to add to: https://stackoverflow.com/a/47166787/10302020.
for the most up to date version:
https://github.com/AidenBurgess/MultiAnnotationLineGraph
Simply change the data in the bottom section.
import matplotlib.pyplot as plt


def update_annot(ind, line, annot, ydata):
    """Place *annot* on the first hovered vertex of *line* and fill its text.

    The text is "<x values>, <y values>" for all hovered vertices: x values
    come from the line's own x data, y values from *ydata*.
    """
    xdata, line_ydata = line.get_data()
    hovered = ind["ind"]
    annot.xy = (xdata[hovered[0]], line_ydata[hovered[0]])
    # Get x and y values, then format them to be displayed. Use the real x
    # coordinates, not the vertex indices (they only coincide for x = 0, 1, ...).
    x_values = " ".join(str(xdata[n]) for n in hovered)
    y_values = " ".join(str(ydata[n]) for n in hovered)
    text = "{}, {}".format(x_values, y_values)
    annot.set_text(text)
    annot.get_bbox_patch().set_alpha(0.4)


def hover(event, line_info):
    """Show this line's annotation while the pointer is on one of its vertices."""
    line, annot, ydata = line_info
    vis = annot.get_visible()
    # Use the line's own axes/figure rather than module-level globals.
    if event.inaxes == line.axes:
        # Draw annotations if cursor in right position
        cont, ind = line.contains(event)
        if cont:
            update_annot(ind, line, annot, ydata)
            annot.set_visible(True)
            line.figure.canvas.draw_idle()
        elif vis:
            # Don't draw annotations
            annot.set_visible(False)
            line.figure.canvas.draw_idle()


def plot_line(x, y):
    """Plot (x, y) on the current axes with its own hover annotation."""
    line, = plt.plot(x, y, marker="o")
    # Annotation style may be changed here
    annot = line.axes.annotate("", xy=(0, 0), xytext=(-20, 20),
                               textcoords="offset points",
                               bbox=dict(boxstyle="round", fc="w"),
                               arrowprops=dict(arrowstyle="->"))
    annot.set_visible(False)
    line_info = [line, annot, y]
    line.figure.canvas.mpl_connect("motion_notify_event",
                                   lambda event: hover(event, line_info))


# Your data values to plot
x1 = range(21)
y1 = range(0, 21)
x2 = range(21)
y2 = range(0, 42, 2)
# Plot line graphs
fig, ax = plt.subplots()
plot_line(x1, y1)
plot_line(x2, y2)
plt.show()
Based on the answers by "Markus Dutschke" and "ImportanceOfBeingErnest", I simplified the code (in my opinion) and made it more modular.
Also this doesn't require additional packages to be installed.
import matplotlib.pylab as plt
import numpy as np

plt.close('all')
fh, ax = plt.subplots()
# Generate some data: a 500-bin histogram of normal samples plus two
# derived series on the same x positions.
y, x = np.histogram(np.random.randn(10000), bins=500)
x = x[:-1]  # left bin edges, same length as the counts
colors = ['#0000ff', '#00ff00', '#ff0000']
x2, y2 = x, y / 10
x3, y3 = x, np.random.randn(500) * 10 + 40
# Plot
h1 = ax.plot(x, y, color=colors[0])
h2 = ax.plot(x2, y2, color=colors[1])
h3 = ax.scatter(x3, y3, color=colors[2], s=1)
artists = h1 + h2 + [h3]  # ax.plot returns lists, ax.scatter a single artist
labels = [list('ABCDE' * 100), list('FGHIJ' * 100), list('klmno' * 100)]  # one label per data point
# ___ Initialize annotation arrow
annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->"))
annot.set_visible(False)


def on_plot_hover(event):
    """Annotate the first artist under the mouse; hide the annotation otherwise.

    Stopping at the first hit fixes the original loop, where every
    non-hovered artist could hide the annotation a previous artist had
    just shown (making it flicker).
    """
    if event.inaxes != ax:  # exit if mouse is not on the axes
        return
    for ii, artist in enumerate(artists):
        is_contained, dct = artist.contains(event)
        if not is_contained:
            continue
        if hasattr(artist, 'get_data'):  # for plot
            data = list(zip(*artist.get_data()))
        else:  # for scatter; asarray also works when offsets are not masked
            data = np.asarray(artist.get_offsets())
        inds = dct['ind']  # data indices under the mouse
        # ___ Set Annotation settings
        xy = data[inds[0]]  # first position only
        annot.xy = xy
        annot.set_text(f'pos={xy},text={labels[ii][inds[0]]}')
        annot.get_bbox_patch().set_edgecolor(colors[ii])
        annot.get_bbox_patch().set_alpha(0.7)
        annot.set_visible(True)
        fh.canvas.draw_idle()
        return
    if annot.get_visible():
        annot.set_visible(False)  # disable when not hovering any artist
        fh.canvas.draw_idle()


fh.canvas.mpl_connect('motion_notify_event', on_plot_hover)
Giving the following result:
Maybe this helps somebody: I have adapted @ImportanceOfBeingErnest's answer to work with patches and classes. Features:
The entire framework is contained inside of a single class, so all of the used variables are only available within their relevant scopes.
Can create multiple distinct sets of patches
Hovering over a patch prints patch collection name and patch subname
Hovering over a patch highlights all patches of that collection by changing their edge color to black
Note: For my applications, the overlap is not relevant, thus only one object's name is displayed at a time. Feel free to extend to multiple objects if you wish, it is not too hard.
Usage
fig, ax = plt.subplots(tight_layout=True)
ap = annotated_patches(fig, ax)
# (group name, kind, colour, anchor points, per-patch names, *shape params)
patch_groups = [
    ('Azure', 'circle', 'blue', np.random.uniform(0, 1, (4, 2)), 'ABCD', 0.1),
    ('Lava', 'rect', 'red', np.random.uniform(0, 1, (3, 2)), 'EFG', 0.1, 0.05),
    ('Emerald', 'rect', 'green', np.random.uniform(0, 1, (3, 2)), 'HIJ', 0.05, 0.1),
]
for group_args in patch_groups:
    ap.add_patches(*group_args)
plt.axis('equal')
plt.axis('off')
plt.show()
Implementation
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
np.random.seed(1)


class annotated_patches:
    """Named groups of patches that are annotated while hovered.

    Each group is one PatchCollection of circles or rectangles. Hovering a
    patch shows "<group name>: <patch name>" at the patch's anchor point and
    draws that group's edges in black. The edges are reset once the pointer
    leaves the group.
    """

    def __init__(self, fig, ax):
        """Create the shared, initially hidden annotation and connect hover."""
        self.fig = fig
        self.ax = ax
        self.annot = self.ax.annotate("", xy=(0, 0),
                                      xytext=(20, 20),
                                      textcoords="offset points",
                                      bbox=dict(boxstyle="round", fc="w"),
                                      arrowprops=dict(arrowstyle="->"))
        self.annot.set_visible(False)
        # Per-group state, keyed by group name.
        self.collectionsDict = {}  # PatchCollection drawn for the group
        self.coordsDict = {}       # anchor xy of each patch
        self.namesDict = {}        # per-patch names, indexed like coordsDict
        self.isActiveDict = {}     # True while the group's edges are highlighted
        self.motionCallbackID = self.fig.canvas.mpl_connect(
            "motion_notify_event", self.hover)

    def add_patches(self, groupName, kind, color, xyCoords, names, *params):
        """Add a group of patches of one *kind* ('circle' or 'rect').

        *xyCoords* holds one anchor point per patch (passed as ``xy`` to the
        patch), *names* one name per patch. *params* is forwarded to
        ``mpatches.Circle`` (radius) or ``mpatches.Rectangle`` (width,
        height). Any other *kind* raises ValueError.
        """
        if kind == 'circle':
            patches = [mpatches.Circle(xy, *params, ec="none") for xy in xyCoords]
        elif kind == 'rect':
            patches = [mpatches.Rectangle(xy, *params, ec="none") for xy in xyCoords]
        else:
            raise ValueError('Unexpected kind', kind)
        thisCollection = PatchCollection(patches, facecolor=color, alpha=0.5,
                                         edgecolor=None)
        # Add to the axes this instance manages, not a module-level ``ax``.
        self.ax.add_collection(thisCollection)
        self.collectionsDict[groupName] = thisCollection
        self.coordsDict[groupName] = xyCoords
        self.namesDict[groupName] = names
        self.isActiveDict[groupName] = False

    def update_annot(self, groupName, patchIdxs):
        """Point the annotation at the first hovered patch and highlight its group."""
        self.annot.xy = self.coordsDict[groupName][patchIdxs[0]]
        self.annot.set_text(groupName + ': ' + self.namesDict[groupName][patchIdxs[0]])
        # Set edge color
        self.collectionsDict[groupName].set_edgecolor('black')
        self.isActiveDict[groupName] = True

    def hover(self, event):
        """Annotate hovered patches; reset highlights of groups no longer hovered."""
        vis = self.annot.get_visible()
        updatedAny = False
        if event.inaxes == self.ax:
            for groupName, collection in self.collectionsDict.items():
                cont, ind = collection.contains(event)
                if cont:
                    self.update_annot(groupName, ind["ind"])
                    self.annot.set_visible(True)
                    self.fig.canvas.draw_idle()
                    updatedAny = True
                elif self.isActiveDict[groupName]:
                    collection.set_edgecolor(None)
                    # Mark the group inactive so it is reset only once.
                    self.isActiveDict[groupName] = False
        if (not updatedAny) and vis:
            self.annot.set_visible(False)
            self.fig.canvas.draw_idle()

Colormap is not categorizing the data properly

Here is my script to plot data from a GeoTIFF file using Basemap. The data is categorical and there are 13 categories within this domain. The problem is that some categories get bunched up into one colour, so some resolution is lost.
Unfortunately, I do not know how to fix this. I read that plt.cm.get_cmap is better for discrete datasets, but I have not gotten it to work.
gtif = 'some_dir'
ds = gdal.Open(gtif)
data = ds.ReadAsArray()
gt = ds.GetGeoTransform()
proj = ds.GetProjection()
xres = gt[1]
yres = gt[5]
xmin = gt[0] + xres
xmax = gt[0] + (xres * ds.RasterXSize) - xres
ymin = gt[3] + (yres * ds.RasterYSize) + yres
ymax = gt[3] - yres
xy_source = np.mgrid[xmin:xmax + xres:xres, ymax + yres:ymin:yres]
ds = None  # drop the GDAL dataset reference
fig2 = plt.figure(figsize=[12, 11])
ax2 = fig2.add_subplot(111)
ax2.set_title("Land use plot")
bm2 = Basemap(ax=ax2, projection='cyl', llcrnrlat=ymin, urcrnrlat=ymax,
              llcrnrlon=xmin, urcrnrlon=xmax, resolution='l')
bm2.drawcoastlines(linewidth=0.2)
bm2.drawcountries(linewidth=0.2)
# NOTE(review): 255 is presumably a no-data code; it is folded into class 0.
data_new = np.copy(data)
data_new[data_new == 255] = 0
categories = np.unique(data_new)
nbins = categories.size
# plt.get_cmap keeps working where the deprecated plt.cm.get_cmap was removed.
cb = plt.get_cmap('jet', nbins + 1)
img2 = bm2.imshow(np.flipud(data_new), cmap=cb)
ax2.set_xlim(3, 6)
ax2.set_ylim(50, 53)
labels = [str(i) for i in categories]
cb2 = bm2.colorbar(img2, "right", size="5%", pad='3%',
                   label='NOAH Land Use Category')
# Pin the tick positions before labelling them; Matplotlib pairs fixed
# labels with fixed tick locations.
cb2.set_ticks(categories)
cb2.set_ticklabels(labels)
# Show only once the colorbar exists; in a script plt.show() blocks, and
# anything added afterwards never appears in the displayed figure.
plt.show()
Here are the categories that are found within the domain (numbered classes):
np.unique(data_new)
array([ 0, 1, 4, 5, 7, 10, 11, 12, 13, 14, 15, 16, 17], dtype=uint8)
Thanks so much for any help here. I have also attached the output image that shows the mismatch. (not working)
First, this colormap problem is independent of the use of basemap. The following is therefore applicable to any matplotlib plot.
The problem here is that creating a colormap from n values distributes those values equally over the colormap range. Some values from the image therefore fall into the same color range within the colormap.
To prevent this, one can generate a colormap with the initial number of categories as shown below.
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors

# generate some data: an 8x6 grid of six category codes
data = np.array([0, 1, 4, 5, 7, 10] * 8)
np.random.shuffle(data)
data = data.reshape((8, 6))

# One colormap entry per integer code 0..max, so every code gets its own
# colour even if some codes do not occur in the data.
unique = np.unique(data)
top = int(unique.max())
n_colors = top + 1
cols = plt.cm.jet(np.arange(n_colors) / float(top))
cmap = matplotlib.colors.ListedColormap(cols, n_colors)
# Centre each integer code in its own colour bin.
norm = matplotlib.colors.Normalize(vmin=-0.5, vmax=top + 0.5)

fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(data, cmap=cmap, norm=norm)
# Write each cell's code on top of it.
for (i, j), value in np.ndenumerate(data):
    ax.text(j, i, value, color="w", ha="center", va="center")

cb = fig.colorbar(im, ax=ax, norm=norm)
cb.set_ticks(unique)
plt.show()
This can be extended to exclude the colors not present in the image as follows:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors

# generate some data: an 8x6 grid of six category codes
data = np.array([0, 1, 4, 5, 7, 10] * 8)
np.random.shuffle(data)
data = data.reshape((8, 6))
# Relabel the codes as 0..k-1 (k = number of distinct codes) so that only
# codes actually present get a colour.
unique, newdata = np.unique(data, return_inverse=True)
newdata = newdata.reshape(data.shape)
n_codes = len(unique)

# generate colormap and norm: k evenly spaced colours. linspace also copes
# with a single category, where arange(k) / (k - 1) divides by zero.
cols = plt.cm.jet(np.linspace(0, 1, n_codes))
cmap = matplotlib.colors.ListedColormap(cols, n_codes)
norm = matplotlib.colors.Normalize(vmin=-0.5, vmax=n_codes - 0.5)

fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(newdata, cmap=cmap, norm=norm)
# Annotate each cell with its original code.
for i in range(newdata.shape[0]):
    for j in range(newdata.shape[1]):
        ax.text(j, i, data[i, j], color="w", ha="center", va="center")

cb = fig.colorbar(im, ax=ax, norm=norm)
# Pin one tick per relabelled code before naming it with the original code.
# Relabelling the default ticks mislabels the bar whenever the locator does
# not pick exactly 0, 1, ..., k-1 (e.g. with many categories).
cb.set_ticks(np.arange(n_codes))
cb.set_ticklabels(unique)
plt.show()

Pretty confusion matrix visualisation with matplotlib

I'm wondering whether there are templates for displaying confusion matrices in matplotlib with a rendering similar to this one (I don't know the specific name of this style).
I have tried doing something similar to your fig 2. Here is my code using the handwritten digits data.
import numpy as np
from scipy import ndimage
from matplotlib import pyplot as plt
from sklearn import manifold, datasets
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import leaves_list, linkage
def get_small_Xy(X, y, n=8):
    """Keep the first *n* samples of every class, grouped by class.

    Classes appear in ``np.unique`` (sorted) order. The returned label
    vector repeats each class *n* times. It only matches the returned
    samples when every class has at least *n* samples.
    """
    classes = np.unique(y)
    X_small = np.vstack([X[y == label][0:n] for label in classes])
    y_small = np.hstack([[label] * n for label in classes])
    return X_small, y_small
# Load digit data
X_, y_ = datasets.load_digits(return_X_y=True)
# get a small set of data
X, y = get_small_Xy(X_, y_)
# Cosine similarity between every pair of samples.
D = 1 - squareform(pdist(X, metric='cosine'))
# Order samples by hierarchical clustering.
# NOTE(review): linkage() treats a 2-D array as observation vectors, so the
# rows of D are clustered as feature vectors rather than D being used as a
# distance matrix — confirm this is intended.
Z = linkage(D, method='ward')
ind = leaves_list(Z)
D = D[np.ix_(ind, ind)]
# Reorder the labels together with the matrix so that colours, text and
# separators describe the same samples as the reordered rows/columns.
y = y[ind]
# labels and colors related: class id where row and column share a class,
# 10 (grey) otherwise.
lbs = np.array([i if i == j else 10 for i in y for j in y])
# Build the list before converting: a fixed-width string array (dtype '<U4'
# here) would truncate '#413c39' to '#413', a different colour.
color_list = ['C{}'.format(i) for i in range(10)] + ['gray']
color_list[7] = '#413c39'
colors = np.array(color_list)
c = colors[lbs]
font1 = {'family': 'Arial',
         'weight': 'normal',
         'size': 8,
         }
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# np.product was removed in NumPy 2.0 and was a no-op on a scalar anyway.
n = X.shape[0]
xx, yy = np.meshgrid(range(n), range(n))
xy = np.stack([xx.ravel(), yy.ravel()]).T
# Marker area grows steeply with similarity (D**4).
ax.scatter(xy[:, 0], xy[:, 1], s=D**4 * 30, fc=c, ec=None, alpha=0.8)
ax.set_xlim(-1, n)
ax.set_ylim(n, -1)
ax.tick_params(top=False, bottom=False, left=False, right=False,
               labelleft=False, labelbottom=False)
# place text
for i, e in enumerate(y):
    ax.text(-1.2, i, e, ha='right', va='center', fontdict=font1, c=colors[e])
for i, e in enumerate(y):
    ax.text(i, -1, e, ha='center', va='bottom', fontdict=font1, c=colors[e])
# draw lines between consecutive samples of different classes
for e in np.where(np.diff(y))[0]:
    ax.axhline(e + 0.5, color='gray', lw=0.5, alpha=0.8)
    ax.axvline(e + 0.5, color='gray', lw=0.5, alpha=0.8)
One issue is the alpha of the points: it does not seem possible to give different points different alpha values in a single scatter call.