Quickly and Efficiently update Matplotlib Axes (Plot) - matplotlib

I've run into some performance issues with a GUI that I'm working on. Specificially I'm using wxPython as the backend for several matplotlib figures which I've embedded in the canvas. I've gotten the basic functionality working with a simple self.axes.clear() and self.axes.plot() command, but I can't seem to get any sort of reasonable frame rates using this method. After performing a search it seems that if I were using a plot object that I I could reset the xdata and ydata then redraw the figure to obtain a faster refresh rate. Unfortunately, the layout that I'm using precludes the use of the plot object, so I've implemented the code using the axes() object. As far as I can tell the axes() object does not have any equivalent methods for setting the xdata and ydata (see this post: How to update a plot in matplotlib?). Here's the code I'm using:
import sys,os,csv
import numpy as np
import wx
import matplotlib
import pylab
from matplotlib.figure import Figure
from matplotlib.backends.backend_wxagg import \
FigureCanvasWxAgg as FigCanvas, \
NavigationToolbar2WxAgg as NavigationToolbar
class UltrasoundDemoGUI(wx.Frame):
title = ' Ultrasound Demo '
def __init__(self):
wx.Frame.__init__(self, None, -1, self.title)
self.create_menu()
self.statusbar = self.CreateStatusBar()
self.create_main_panel()
self.dataMon = pylab.randn(100,1);
self.dataBi = pylab.randn(100,1);
self.dataLoc = pylab.randn(100,1);
self.redraw_timer = wx.Timer(self)
self.Bind(wx.EVT_TIMER, self.on_redraw_timer, self.redraw_timer)
self.redraw_timer.Start(300)
def create_menu(self):
self.menubar = wx.MenuBar()
menu_file = wx.Menu()
m_expt = menu_file.Append(-1, "&Save plot\tCtrl-S", "Save plot to file")
self.Bind(wx.EVT_MENU, self.on_save_plot, m_expt)
menu_file.AppendSeparator()
m_exit = menu_file.Append(-1, "E&xit\tCtrl-X", "Exit")
self.Bind(wx.EVT_MENU, self.on_exit, m_exit)
self.menubar.Append(menu_file, "&File")
self.SetMenuBar(self.menubar)
def create_main_panel(self):
self.panel = wx.Panel(self)
# self.init_plot()
self.dpi = 100
self.figSignals = Figure((12, 5.0), dpi=self.dpi)
self.canvas = FigCanvas(self.panel, -1, self.figSignals)
rectSubplotMono = .1, .55, .4, .4
self.axesMono = self.figSignals.add_axes(rectSubplotMono)
rectSubplotBi = .1, .05, .4, .4
self.axesBi = self.figSignals.add_axes(rectSubplotBi)
rectSubplotLoc = .55, .05, .4, .9
self.axesLoc = self.figSignals.add_axes(rectSubplotLoc)
self.axesMono.set_axis_bgcolor('white')
self.axesBi.set_axis_bgcolor('white')
self.axesLoc.set_axis_bgcolor('white')
self.axesMono.set_title('Raw Ultrasound Signal', size=12)
pylab.setp(self.axesMono.get_xticklabels(), fontsize=8)
pylab.setp(self.axesMono.get_yticklabels(), fontsize=8)
pylab.setp(self.axesBi.get_xticklabels(), fontsize=8)
pylab.setp(self.axesBi.get_yticklabels(), fontsize=8)
pylab.setp(self.axesLoc.get_xticklabels(), fontsize=8)
pylab.setp(self.axesLoc.get_yticklabels(), fontsize=8)
# plot the data as a line series, and save the reference
# to the plotted line series
#
self.dataMono = pylab.randn(100,1)
self.dataBi = pylab.randn(100,1)
self.dataLoc = pylab.randn(100,1)
self.plot_dataMono = self.axesMono.plot(
self.dataMono,
linewidth=1,
color=(1, 1, 0),
)[0]
self.plot_dataBi = self.axesBi.plot(
self.dataBi,
linewidth=1,
color=(1, 1, 0),
)[0]
self.plot_dataLoc = self.axesLoc.plot(
self.dataLoc,
linewidth=1,
color=(1, 1, 0),
)[0]
self.toolbar = NavigationToolbar(self.canvas)
self.vbox = wx.BoxSizer(wx.VERTICAL)
self.vbox.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
self.vbox.AddSpacer(10)
self.hbox = wx.BoxSizer(wx.HORIZONTAL)
flags = wx.ALIGN_LEFT | wx.ALL | wx.ALIGN_CENTER_VERTICAL
self.panel.SetSizer(self.vbox)
self.vbox.Fit(self)
def init_plot(self):
rectSubplotMono = .1, .55, .4, .4
self.axesMono = self.figSignals.add_axes(rectSubplotMono)
rectSubplotBi = .1, .05, .4, .4
self.axesBi = self.figSignals.add_axes(rectSubplotBi)
rectSubplotLoc = .55, .05, .4, .9
self.axesLoc = self.figSignals.add_axes(rectSubplotLoc)
self.axesMono.set_axis_bgcolor('white')
self.axesBi.set_axis_bgcolor('white')
self.axesLoc.set_axis_bgcolor('white')
self.axesMono.set_title('Raw Ultrasound Signal', size=12)
pylab.setp(self.axesMono.get_xticklabels(), fontsize=8)
pylab.setp(self.axesMono.get_yticklabels(), fontsize=8)
pylab.setp(self.axesBi.get_xticklabels(), fontsize=8)
pylab.setp(self.axesBi.get_yticklabels(), fontsize=8)
pylab.setp(self.axesLoc.get_xticklabels(), fontsize=8)
pylab.setp(self.axesLoc.get_yticklabels(), fontsize=8)
def on_redraw_timer(self, event):
self.draw_plot()
def draw_plot(self):
self.axesMono.clear()
self.axesBi.clear()
self.axesLoc.clear()
i = np.arange(1,100)
w = i;
x = pylab.randn(100,1);
y = pylab.randn(100, 1);
z = pylab.randn(100, 1);
# self.axesMono.set_xdata(np.arange(len(x)))
# self.axesMono.set_ydata(np.array(x))
self.axesMono.plot(x, 'red')
self.axesBi.plot(x,'yellow')
self.axesLoc.plot(x, z, 'black')
self.canvas.draw()
def on_save_plot(self, event):
file_choices = "PNG (*.png)|*.png"
dlg = wx.FileDialog(
self,
message="Save plot as...",
defaultDir=os.getcwd(),
defaultFile="plot.png",
wildcard=file_choices,
style=wx.SAVE)
if dlg.ShowModal() == wx.ID_OK:
path = dlg.GetPath()
self.canvas.print_figure(path, dpi=self.dpi)
self.flash_status_message("Saved to %s" % path)
def on_exit(self, event):
self.Destroy()
if __name__ == '__main__':
app = wx.PySimpleApp()
app.frame = UltrasoundDemoGUI()
app.frame.Show()
app.frame.draw_plot()
app.MainLoop()
del app
I appear to be limited to a refresh rate of approximate 3Hz. Ideally I'd like to visualize the data at a frame rate of 10Hz or higher. Does anyone have any idea how I can efficiently (quickly) update the plots using the axes object?
Thanks for your help,
-B

Related

Matplotlib cross hair cursor in PyQt5

I want to add a cross hair that snaps to data points and be updated on mouse move. I found this example that works well:
import numpy as np
import matplotlib.pyplot as plt
class SnappingCursor:
"""
A cross hair cursor that snaps to the data point of a line, which is
closest to the *x* position of the cursor.
For simplicity, this assumes that *x* values of the data are sorted.
"""
def __init__(self, ax, line):
self.ax = ax
self.horizontal_line = ax.axhline(color='k', lw=0.8, ls='--')
self.vertical_line = ax.axvline(color='k', lw=0.8, ls='--')
self.x, self.y = line.get_data()
self._last_index = None
# text location in axes coords
self.text = ax.text(0.72, 0.9, '', transform=ax.transAxes)
def set_cross_hair_visible(self, visible):
need_redraw = self.vertical_line.get_visible() != visible
self.vertical_line.set_visible(visible)
self.horizontal_line.set_visible(visible)
self.text.set_visible(visible)
return need_redraw
def on_mouse_move(self, event):
if not event.inaxes:
self._last_index = None
need_redraw = self.set_cross_hair_visible(False)
if need_redraw:
self.ax.figure.canvas.draw()
else:
self.set_cross_hair_visible(True)
x, y = event.xdata, event.ydata
index = min(np.searchsorted(self.y, y), len(self.y) - 1)
if index == self._last_index:
return # still on the same data point. Nothing to do.
self._last_index = index
x = self.x[index]
y = self.y[index]
# update the line positions
self.horizontal_line.set_ydata(y)
self.vertical_line.set_xdata(x)
self.text.set_text('x=%1.2f, y=%1.2f' % (x, y))
self.ax.figure.canvas.draw()
y = np.arange(0, 1, 0.01)
x = np.sin(2 * 2 * np.pi * y)
fig, ax = plt.subplots()
ax.set_title('Snapping cursor')
line, = ax.plot(x, y, 'o')
snap_cursor = SnappingCursor(ax, line)
fig.canvas.mpl_connect('motion_notify_event', snap_cursor.on_mouse_move)
plt.show()
But I get into trouble when I want to adapt the code with the PyQt5 and show the plot in a GUI. My code is:
from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout
import sys
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
import numpy as np
class SnappingCursor:
"""
A cross hair cursor that snaps to the data point of a line, which is
closest to the *x* position of the cursor.
For simplicity, this assumes that *x* values of the data are sorted.
"""
def __init__(self, ax, line):
self.ax = ax
self.horizontal_line = ax.axhline(color='k', lw=0.8, ls='--')
self.vertical_line = ax.axvline(color='k', lw=0.8, ls='--')
self.x, self.y = line.get_data()
self._last_index = None
# text location in axes coords
self.text = ax.text(0.72, 0.9, '', transform=ax.transAxes)
def set_cross_hair_visible(self, visible):
need_redraw = self.vertical_line.get_visible() != visible
self.vertical_line.set_visible(visible)
self.horizontal_line.set_visible(visible)
self.text.set_visible(visible)
return need_redraw
def on_mouse_move(self, event):
if not event.inaxes:
self._last_index = None
need_redraw = self.set_cross_hair_visible(False)
if need_redraw:
self.ax.figure.canvas.draw()
else:
self.set_cross_hair_visible(True)
x, y = event.xdata, event.ydata
index = min(np.searchsorted(self.y, y), len(self.y) - 1)
if index == self._last_index:
return # still on the same data point. Nothing to do.
self._last_index = index
x = self.x[index]
y = self.y[index]
# update the line positions
self.horizontal_line.set_ydata(y)
self.vertical_line.set_xdata(x)
self.text.set_text('x=%1.2f, y=%1.2f' % (x, y))
self.ax.figure.canvas.draw()
class Window(QMainWindow):
def __init__(self):
super().__init__()
widget=QWidget()
vbox=QVBoxLayout()
plot1 = FigureCanvas(Figure(tight_layout=True, linewidth=3))
ax = plot1.figure.subplots()
x = np.arange(0, 1, 0.01)
y = np.sin(2 * 2 * np.pi * x)
line, = ax.plot(x, y, 'o')
snap_cursor = SnappingCursor(ax, line)
plot1.mpl_connect('motion_notify_event', snap_cursor.on_mouse_move)
vbox.addWidget(plot1)
widget.setLayout(vbox)
self.setCentralWidget(widget)
self.setWindowTitle('Example')
self.show()
App = QApplication(sys.argv)
window = Window()
sys.exit(App.exec())
By running the above code, the data is plotted properly, but the cross hair is only shown in its initial position and does not move by mouse movement. Data values are also not displayed.
I have found a similar question here too, but the question is not answered clearly.
There are 2 problems:
snap_cursor is a local variable that will be removed when __init__ finishes executing. You must make him a member of the class.
The initial code of the tutorial is designed so that the point that information is displayed is the horizontal line that passes through the cursor and intersects the curve. In your initial code it differs from the example and also does not work for your new curve so I restored the logic of the tutorial.
import sys
from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget
import numpy as np
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
class SnappingCursor:
"""
A cross hair cursor that snaps to the data point of a line, which is
closest to the *x* position of the cursor.
For simplicity, this assumes that *x* values of the data are sorted.
"""
def __init__(self, ax, line):
self.ax = ax
self.horizontal_line = ax.axhline(color="k", lw=0.8, ls="--")
self.vertical_line = ax.axvline(color="k", lw=0.8, ls="--")
self.x, self.y = line.get_data()
self._last_index = None
# text location in axes coords
self.text = ax.text(0.72, 0.9, "", transform=ax.transAxes)
def set_cross_hair_visible(self, visible):
need_redraw = self.vertical_line.get_visible() != visible
self.vertical_line.set_visible(visible)
self.horizontal_line.set_visible(visible)
self.text.set_visible(visible)
return need_redraw
def on_mouse_move(self, event):
if not event.inaxes:
self._last_index = None
need_redraw = self.set_cross_hair_visible(False)
if need_redraw:
self.ax.figure.canvas.draw()
else:
self.set_cross_hair_visible(True)
x, y = event.xdata, event.ydata
index = min(np.searchsorted(self.x, x), len(self.x) - 1)
if index == self._last_index:
return # still on the same data point. Nothing to do.
self._last_index = index
x = self.x[index]
y = self.y[index]
# update the line positions
self.horizontal_line.set_ydata(y)
self.vertical_line.set_xdata(x)
self.text.set_text("x=%1.2f, y=%1.2f" % (x, y))
self.ax.figure.canvas.draw()
class Window(QMainWindow):
def __init__(self):
super().__init__()
widget = QWidget()
vbox = QVBoxLayout(widget)
x = np.arange(0, 1, 0.01)
y = np.sin(2 * 2 * np.pi * x)
canvas = FigureCanvas(Figure(tight_layout=True, linewidth=3))
ax = canvas.figure.subplots()
ax.set_title("Snapping cursor")
(line,) = ax.plot(x, y, "o")
self.snap_cursor = SnappingCursor(ax, line)
canvas.mpl_connect("motion_notify_event", self.snap_cursor.on_mouse_move)
vbox.addWidget(canvas)
self.setCentralWidget(widget)
self.setWindowTitle("Example")
app = QApplication(sys.argv)
w = Window()
w.show()
app.exec()

How to override mpl_toolkits.mplot3d.Axes3D.draw() method?

I'm doing a small project which requires to resolve a bug in matplotlib in order to fix zorders of some ax.patches and ax.collections. More exactly, ax.patches are symbols rotatable in space and ax.collections are sides of ax.voxels (so text must be placed on them). I know so far, that a bug is hidden in draw method of mpl_toolkits.mplot3d.Axes3D: zorder are recalculated each time I move my diagram in an undesired way. So I decided to change definition of draw method in these lines:
for i, col in enumerate(
sorted(self.collections,
key=lambda col: col.do_3d_projection(renderer),
reverse=True)):
#col.zorder = zorder_offset + i #comment this line
col.zorder = col.stable_zorder + i #add this extra line
for i, patch in enumerate(
sorted(self.patches,
key=lambda patch: patch.do_3d_projection(renderer),
reverse=True)):
#patch.zorder = zorder_offset + i #comment this line
patch.zorder = patch.stable_zorder + i #add this extra line
It's assumed that every object of ax.collection and ax.patch has a stable_attribute which is assigned manually in my project. So every time I run my project, I must be sure that mpl_toolkits.mplot3d.Axes3D.draw method is changed manually (outside my project). How to avoid this change and override this method in any way inside my project?
This is MWE of my project:
import matplotlib.pyplot as plt
import numpy as np
#from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.art3d as art3d
from matplotlib.text import TextPath
from matplotlib.transforms import Affine2D
from matplotlib.patches import PathPatch
class VisualArray:
def __init__(self, arr, fig=None, ax=None):
if len(arr.shape) == 1:
arr = arr[None,None,:]
elif len(arr.shape) == 2:
arr = arr[None,:,:]
elif len(arr.shape) > 3:
raise NotImplementedError('More than 3 dimensions is not supported')
self.arr = arr
if fig is None:
self.fig = plt.figure()
else:
self.fig = fig
if ax is None:
self.ax = self.fig.gca(projection='3d')
else:
self.ax = ax
self.ax.azim, self.ax.elev = -120, 30
self.colors = None
def text3d(self, xyz, s, zdir="z", zorder=1, size=None, angle=0, usetex=False, **kwargs):
d = {'-x': np.array([[-1.0, 0.0, 0], [0.0, 1.0, 0.0], [0, 0.0, -1]]),
'-y': np.array([[0.0, 1.0, 0], [-1.0, 0.0, 0.0], [0, 0.0, 1]]),
'-z': np.array([[1.0, 0.0, 0], [0.0, -1.0, 0.0], [0, 0.0, -1]])}
x, y, z = xyz
if "y" in zdir:
x, y, z = x, z, y
elif "x" in zdir:
x, y, z = y, z, x
elif "z" in zdir:
x, y, z = x, y, z
text_path = TextPath((-0.5, -0.5), s, size=size, usetex=usetex)
aff = Affine2D()
trans = aff.rotate(angle)
# apply additional rotation of text_paths if side is dark
if '-' in zdir:
trans._mtx = np.dot(d[zdir], trans._mtx)
trans = trans.translate(x, y)
p = PathPatch(trans.transform_path(text_path), **kwargs)
self.ax.add_patch(p)
art3d.pathpatch_2d_to_3d(p, z=z, zdir=zdir)
p.stable_zorder = zorder
return p
def on_rotation(self, event):
vrot_idx = [self.ax.elev > 0, True].index(True)
v_zorders = 10000 * np.array([(1, -1), (-1, 1)])[vrot_idx]
for side, zorder in zip((self.side1, self.side4), v_zorders):
for patch in side:
patch.stable_zorder = zorder
hrot_idx = [self.ax.azim < -90, self.ax.azim < 0, self.ax.azim < 90, True].index(True)
h_zorders = 10000 * np.array([(1, 1, -1, -1), (-1, 1, 1, -1),
(-1, -1, 1, 1), (1, -1, -1, 1)])[hrot_idx]
sides = (self.side3, self.side2, self.side6, self.side5)
for side, zorder in zip(sides, h_zorders):
for patch in side:
patch.stable_zorder = zorder
def voxelize(self):
shape = self.arr.shape[::-1]
x, y, z = np.indices(shape)
arr = (x < shape[0]) & (y < shape[1]) & (z < shape[2])
self.ax.voxels(arr, facecolors=self.colors, edgecolor='k')
for col in self.ax.collections:
col.stable_zorder = col.zorder
def labelize(self):
self.fig.canvas.mpl_connect('motion_notify_event', self.on_rotation)
s = self.arr.shape
self.side1, self.side2, self.side3, self.side4, self.side5, self.side6 = [], [], [], [], [], []
# labelling surfaces of side1 and side4
surf = np.indices((s[2], s[1])).T[::-1].reshape(-1, 2) + 0.5
surf_pos1 = np.insert(surf, 2, self.arr.shape[0], axis=1)
surf_pos2 = np.insert(surf, 2, 0, axis=1)
labels1 = (self.arr[0]).flatten()
labels2 = (self.arr[-1]).flatten()
for xyz, label in zip(surf_pos1, [f'${n}$' for n in labels1]):
t = self.text3d(xyz, label, zdir="z", zorder=10000, size=1, usetex=True, ec="none", fc="k")
self.side1.append(t)
for xyz, label in zip(surf_pos2, [f'${n}$' for n in labels2]):
t = self.text3d(xyz, label, zdir="-z", zorder=-10000, size=1, usetex=True, ec="none", fc="k")
self.side4.append(t)
# labelling surfaces of side2 and side5
surf = np.indices((s[2], s[0])).T[::-1].reshape(-1, 2) + 0.5
surf_pos1 = np.insert(surf, 1, 0, axis=1)
surf = np.indices((s[0], s[2])).T[::-1].reshape(-1, 2) + 0.5
surf_pos2 = np.insert(surf, 1, self.arr.shape[1], axis=1)
labels1 = (self.arr[:, -1]).flatten()
labels2 = (self.arr[::-1, 0].T[::-1]).flatten()
for xyz, label in zip(surf_pos1, [f'${n}$' for n in labels1]):
t = self.text3d(xyz, label, zdir="y", zorder=10000, size=1, usetex=True, ec="none", fc="k")
self.side2.append(t)
for xyz, label in zip(surf_pos2, [f'${n}$' for n in labels2]):
t = self.text3d(xyz, label, zdir="-y", zorder=-10000, size=1, usetex=True, ec="none", fc="k")
self.side5.append(t)
# labelling surfaces of side3 and side6
surf = np.indices((s[1], s[0])).T[::-1].reshape(-1, 2) + 0.5
surf_pos1 = np.insert(surf, 0, self.arr.shape[2], axis=1)
surf_pos2 = np.insert(surf, 0, 0, axis=1)
labels1 = (self.arr[:, ::-1, -1]).flatten()
labels2 = (self.arr[:, ::-1, 0]).flatten()
for xyz, label in zip(surf_pos1, [f'${n}$' for n in labels1]):
t = self.text3d(xyz, label, zdir="x", zorder=-10000, size=1, usetex=True, ec="none", fc="k")
self.side6.append(t)
for xyz, label in zip(surf_pos2, [f'${n}$' for n in labels2]):
t = self.text3d(xyz, label, zdir="-x", zorder=10000, size=1, usetex=True, ec="none", fc="k")
self.side3.append(t)
def vizualize(self):
self.voxelize()
self.labelize()
plt.axis('off')
arr = np.arange(60).reshape((2,6,5))
va = VisualArray(arr)
va.vizualize()
plt.show()
This is an output I get after external change of ...\mpl_toolkits\mplot3d\axes3d.py file:
This is an output (an unwanted one) I get if no change is done:
What you want to achieve is called Monkey Patching.
It has its downsides and has to be used with some care (there is plenty of information available under this keyword). But one option could look something like this:
from matplotlib import artist
from mpl_toolkits.mplot3d import Axes3D
# Create a new draw function
#artist.allow_rasterization
def draw(self, renderer):
# Your version
# ...
# Add Axes3D explicitly to super() calls
super(Axes3D, self).draw(renderer)
# Overwrite the old draw function
Axes3D.draw = draw
# The rest of your code
# ...
Caveats here are to import artist for the decorator and the explicit call super(Axes3D, self).method() instead of just using super().method().
Depending on your use case and to stay compatible with the rest of your code you could also save the original draw function and use the custom only temporarily:
def draw_custom():
...
draw_org = Axes3D.draw
Axes3D.draw = draw_custom
# Do custom stuff
Axes3D.draw = draw_org
# Do normal stuff

move facial landmarks using matplotlib

tried to include your suggestions, not sure why it doesn't work:
# face alignment
import face_alignment
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from skimage import io
# Run the 3D face alignment on a test image, without CUDA.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device='cpu', flip_input=True)
input = io.imread(r'C:/Users/Ihr Name/Pictures/Bewerbungsfotos/neuropic.jpg')
preds = fa.get_landmarks(input)[-1]
#landmarks == preds, input = pixels image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
class DraggablePoints(object):
def __init__(self, artists, tolerance=5):
for artist in artists:
artist.set_picker(tolerance)
self.artists = artists
self.currently_dragging = False
self.current_artist = None
self.offset = (0, 0)
for canvas in set(artist.figure.canvas for artist in self.artists):
canvas.mpl_connect('button_press_event', self.on_press)
canvas.mpl_connect('button_release_event', self.on_release)
canvas.mpl_connect('pick_event', self.on_pick)
canvas.mpl_connect('motion_notify_event', self.on_motion)
def on_press(self, event):
self.currently_dragging = True
def on_release(self, event):
self.currently_dragging = False
self.current_artist = None
def on_pick(self, event):
if self.current_artist is None:
self.current_artist = event.artist
x0, y0 = event.artist.center
x1, y1 = event.mouseevent.xdata, event.mouseevent.ydata
self.offset = (x0 - x1), (y0 - y1)
def on_motion(self, event):
if not self.currently_dragging:
return
if self.current_artist is None:
return
dx, dy = self.offset
self.current_artist.center = event.xdata + dx, event.ydata + dy
self.current_artist.figure.canvas.draw()
if __name__ == '__main__':
fig = plt.figure(figsize=plt.figaspect(.5))
ax = fig.add_subplot(1, 2, 1)
ax.imshow(input)
ax.plot(preds[0:17,0],preds[0:17,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.plot(preds[17:22,0],preds[17:22,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.plot(preds[22:27,0],preds[22:27,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.plot(preds[27:31,0],preds[27:31,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.plot(preds[31:36,0],preds[31:36,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.plot(preds[36:42,0],preds[36:42,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.plot(preds[42:48,0],preds[42:48,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.plot(preds[48:60,0],preds[48:60,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.plot(preds[60:68,0],preds[60:68,1],marker='o',markersize=6,linestyle='-',color='w',lw=2)
ax.axis('off')
ax = fig.add_subplot(1, 2, 2, projection='3d')
surf = ax.scatter(preds[:,0]*1.2,preds[:,1],preds[:,2],c="cyan", alpha=0.5, edgecolor='b')
ax.plot3D(preds[:17,0]*1.2,preds[:17,1], preds[:17,2], color='blue' )
ax.plot3D(preds[17:22,0]*1.2,preds[17:22,1],preds[17:22,2], color='blue')
ax.plot3D(preds[22:27,0]*1.2,preds[22:27,1],preds[22:27,2], color='blue')
ax.plot3D(preds[27:31,0]*1.2,preds[27:31,1],preds[27:31,2], color='blue')
ax.plot3D(preds[31:36,0]*1.2,preds[31:36,1],preds[31:36,2], color='blue')
ax.plot3D(preds[36:42,0]*1.2,preds[36:42,1],preds[36:42,2], color='blue')
ax.plot3D(preds[42:48,0]*1.2,preds[42:48,1],preds[42:48,2], color='blue')
ax.plot3D(preds[48:,0]*1.2,preds[48:,1],preds[48:,2], color='blue' )
ax.view_init(elev=90., azim=90.)
ax.set_xlim(ax.get_xlim()[::-1])
#we want to move preds (landmarks)
for p in preds:
ax.add_patch(p)
dr = DraggablePoints(preds)
plt.show()

Include matplotlib in pyqt5 with hover labels

I have a plot from matplotlib for which I would like to display labels on the marker points when hover over with the mouse.
I found this very helpful working example on SO and I was trying to integrate the exact same plot into a pyqt5 application.
Unfortunately when having the plot in the application the hovering doesn't work anymore.
Here is a full working example based on the mentioned SO post:
import matplotlib.pyplot as plt
import scipy.spatial as spatial
import numpy as np
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
import sys
pi = np.pi
cos = np.cos
def fmt(x, y):
return 'x: {x:0.2f}\ny: {y:0.2f}'.format(x=x, y=y)
class FollowDotCursor(object):
"""Display the x,y location of the nearest data point.
https://stackoverflow.com/a/4674445/190597 (Joe Kington)
https://stackoverflow.com/a/13306887/190597 (unutbu)
https://stackoverflow.com/a/15454427/190597 (unutbu)
"""
def __init__(self, ax, x, y, tolerance=5, formatter=fmt, offsets=(-20, 20)):
try:
x = np.asarray(x, dtype='float')
except (TypeError, ValueError):
x = np.asarray(mdates.date2num(x), dtype='float')
y = np.asarray(y, dtype='float')
mask = ~(np.isnan(x) | np.isnan(y))
x = x[mask]
y = y[mask]
self._points = np.column_stack((x, y))
self.offsets = offsets
y = y[np.abs(y-y.mean()) <= 3*y.std()]
self.scale = x.ptp()
self.scale = y.ptp() / self.scale if self.scale else 1
self.tree = spatial.cKDTree(self.scaled(self._points))
self.formatter = formatter
self.tolerance = tolerance
self.ax = ax
self.fig = ax.figure
self.ax.xaxis.set_label_position('top')
self.dot = ax.scatter(
[x.min()], [y.min()], s=130, color='green', alpha=0.7)
self.annotation = self.setup_annotation()
plt.connect('motion_notify_event', self)
def scaled(self, points):
points = np.asarray(points)
return points * (self.scale, 1)
def __call__(self, event):
ax = self.ax
# event.inaxes is always the current axis. If you use twinx, ax could be
# a different axis.
if event.inaxes == ax:
x, y = event.xdata, event.ydata
elif event.inaxes is None:
return
else:
inv = ax.transData.inverted()
x, y = inv.transform([(event.x, event.y)]).ravel()
annotation = self.annotation
x, y = self.snap(x, y)
annotation.xy = x, y
annotation.set_text(self.formatter(x, y))
self.dot.set_offsets((x, y))
bbox = ax.viewLim
event.canvas.draw()
def setup_annotation(self):
"""Draw and hide the annotation box."""
annotation = self.ax.annotate(
'', xy=(0, 0), ha = 'right',
xytext = self.offsets, textcoords = 'offset points', va = 'bottom',
bbox = dict(
boxstyle='round,pad=0.5', fc='yellow', alpha=0.75),
arrowprops = dict(
arrowstyle='->', connectionstyle='arc3,rad=0'))
return annotation
def snap(self, x, y):
"""Return the value in self.tree closest to x, y."""
dist, idx = self.tree.query(self.scaled((x, y)), k=1, p=1)
try:
return self._points[idx]
except IndexError:
# IndexError: index out of bounds
return self._points[0]
class MainWindow(QMainWindow):
def __init__(self):
super().__init__()
self.width = 1000
self.height = 800
self.setGeometry(0, 0, self.width, self.height)
canvas = self.get_canvas()
w = QWidget()
w.layout = QHBoxLayout()
w.layout.addWidget(canvas)
w.setLayout(w.layout)
self.setCentralWidget(w)
self.show()
def get_canvas(self):
fig, ax = plt.subplots()
x = np.linspace(0.1, 2*pi, 10)
y = cos(x)
markerline, stemlines, baseline = ax.stem(x, y, '-.')
plt.setp(markerline, 'markerfacecolor', 'b')
plt.setp(baseline, 'color','r', 'linewidth', 2)
cursor = FollowDotCursor(ax, x, y, tolerance=20)
canvas = FigureCanvas(fig)
return canvas
app = QApplication(sys.argv)
win = MainWindow()
sys.exit(app.exec_())
What would I have to do to make the labels also show when hovering over in the pyqt application?
The first problem may be that you don't keep a reference to the FollowDotCursor.
So to make sure the FollowDotCursor stays alive, you can make it a class variable
self.cursor = FollowDotCursor(ax, x, y, tolerance=20)
instead of cursor = ....
Next make sure you instatiate the Cursor class after giving the figure a canvas.
canvas = FigureCanvas(fig)
self.cursor = FollowDotCursor(ax, x, y, tolerance=20)
Finally, keep a reference to the callback inside the FollowDotCursor and don't use plt.connect but the canvas itself:
self.cid = self.fig.canvas.mpl_connect('motion_notify_event', self)

Heatmap with text in each cell with matplotlib's pyplot

I use matplotlib.pyplot.pcolor() to plot a heatmap with matplotlib:
import numpy as np
import matplotlib.pyplot as plt
def heatmap(data, title, xlabel, ylabel):
plt.figure()
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
c = plt.pcolor(data, edgecolors='k', linewidths=4, cmap='RdBu', vmin=0.0, vmax=1.0)
plt.colorbar(c)
def main():
title = "ROC's AUC"
xlabel= "Timeshift"
ylabel="Scales"
data = np.random.rand(8,12)
heatmap(data, title, xlabel, ylabel)
plt.show()
if __name__ == "__main__":
main()
Is any way to add the corresponding value in each cell, e.g.:
(from Matlab's Customizable Heat Maps)
(I don't need the additional % for my current application, though I'd be curious to know for the future)
You need to add all the text by calling axes.text(), here is an example:
import numpy as np
import matplotlib.pyplot as plt
title = "ROC's AUC"
xlabel= "Timeshift"
ylabel="Scales"
data = np.random.rand(8,12)
plt.figure(figsize=(12, 6))
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
c = plt.pcolor(data, edgecolors='k', linewidths=4, cmap='RdBu', vmin=0.0, vmax=1.0)
def show_values(pc, fmt="%.2f", **kw):
from itertools import izip
pc.update_scalarmappable()
ax = pc.get_axes()
for p, color, value in izip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
x, y = p.vertices[:-2, :].mean(0)
if np.all(color[:3] > 0.5):
color = (0.0, 0.0, 0.0)
else:
color = (1.0, 1.0, 1.0)
ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)
show_values(c)
plt.colorbar(c)
the output:
You could use Seaborn, which is a Python visualization library based on matplotlib that provides a high-level interface for drawing attractive statistical graphics.
Heatmap example:
import seaborn as sns
sns.set()
flights_long = sns.load_dataset("flights")
flights = flights_long.pivot("month", "year", "passengers")
sns.heatmap(flights, annot=True, fmt="d")
# To display the heatmap
import matplotlib.pyplot as plt
plt.show()
# To save the heatmap as a file:
fig = heatmap.get_figure()
fig.savefig('heatmap.pdf')
Documentation: https://seaborn.pydata.org/generated/seaborn.heatmap.html
If that's of interest to anyone, here is below the code I use to imitate the picture from Matlab's Customizable Heat Maps I had included in the question).
import numpy as np
import matplotlib.pyplot as plt
def show_values(pc, fmt="%.2f", **kw):
'''
Heatmap with text in each cell with matplotlib's pyplot
Source: http://stackoverflow.com/a/25074150/395857
By HYRY
'''
from itertools import izip
pc.update_scalarmappable()
ax = pc.get_axes()
for p, color, value in izip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
x, y = p.vertices[:-2, :].mean(0)
if np.all(color[:3] > 0.5):
color = (0.0, 0.0, 0.0)
else:
color = (1.0, 1.0, 1.0)
ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)
def cm2inch(*tupl):
'''
Specify figure size in centimeter in matplotlib
Source: http://stackoverflow.com/a/22787457/395857
By gns-ank
'''
inch = 2.54
if type(tupl[0]) == tuple:
return tuple(i/inch for i in tupl[0])
else:
return tuple(i/inch for i in tupl)
def heatmap(AUC, title, xlabel, ylabel, xticklabels, yticklabels):
'''
Inspired by:
- http://stackoverflow.com/a/16124677/395857
- http://stackoverflow.com/a/25074150/395857
'''
# Plot it out
fig, ax = plt.subplots()
c = ax.pcolor(AUC, edgecolors='k', linestyle= 'dashed', linewidths=0.2, cmap='RdBu', vmin=0.0, vmax=1.0)
# put the major ticks at the middle of each cell
ax.set_yticks(np.arange(AUC.shape[0]) + 0.5, minor=False)
ax.set_xticks(np.arange(AUC.shape[1]) + 0.5, minor=False)
# set tick labels
#ax.set_xticklabels(np.arange(1,AUC.shape[1]+1), minor=False)
ax.set_xticklabels(xticklabels, minor=False)
ax.set_yticklabels(yticklabels, minor=False)
# set title and x/y labels
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
# Remove last blank column
plt.xlim( (0, AUC.shape[1]) )
# Turn off all the ticks
ax = plt.gca()
for t in ax.xaxis.get_major_ticks():
t.tick1On = False
t.tick2On = False
for t in ax.yaxis.get_major_ticks():
t.tick1On = False
t.tick2On = False
# Add color bar
plt.colorbar(c)
# Add text in each cell
show_values(c)
# resize
fig = plt.gcf()
fig.set_size_inches(cm2inch(40, 20))
def main():
x_axis_size = 19
y_axis_size = 10
title = "ROC's AUC"
xlabel= "Timeshift"
ylabel="Scales"
data = np.random.rand(y_axis_size,x_axis_size)
xticklabels = range(1, x_axis_size+1) # could be text
yticklabels = range(1, y_axis_size+1) # could be text
heatmap(data, title, xlabel, ylabel, xticklabels, yticklabels)
plt.savefig('image_output.png', dpi=300, format='png', bbox_inches='tight') # use format='svg' or 'pdf' for vectorial pictures
plt.show()
if __name__ == "__main__":
main()
#cProfile.run('main()') # if you want to do some profiling
Output:
It looks nicer when there are some patterns:
Same as #HYRY aswer, but python3 compatible version:
import numpy as np
import matplotlib.pyplot as plt
title = "ROC's AUC"
xlabel= "Timeshift"
ylabel="Scales"
data = np.random.rand(8,12)
plt.figure(figsize=(12, 6))
plt.title(title)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
c = plt.pcolor(data, edgecolors='k', linewidths=4, cmap='RdBu', vmin=0.0, vmax=1.0)
def show_values(pc, fmt="%.2f", **kw):
pc.update_scalarmappable()
ax = pc.axes
for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
x, y = p.vertices[:-2, :].mean(0)
if np.all(color[:3] > 0.5):
color = (0.0, 0.0, 0.0)
else:
color = (1.0, 1.0, 1.0)
ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)
show_values(c)
plt.colorbar(c)