How to override mpl_toolkits.mplot3d.Axes3D.draw() method? - matplotlib

I'm doing a small project that requires working around a bug in matplotlib in order to fix the zorders of some ax.patches and ax.collections. More precisely, ax.patches are symbols that can be rotated in space, and ax.collections are the sides of ax.voxels (so text must be placed on them). So far I know that the bug is hidden in the draw method of mpl_toolkits.mplot3d.Axes3D: zorders are recalculated in an undesired way each time I move my diagram. So I decided to change the definition of the draw method in these lines:
# Modified z-ordering loops from inside mpl_toolkits.mplot3d.Axes3D.draw.
# This is a fragment of that method: `self`, `renderer` and `zorder_offset`
# come from the enclosing draw() and are not defined here.
# Collections and patches are each sorted by descending do_3d_projection(renderer)
# value (presumably depth, i.e. farthest first -- confirm for your matplotlib
# version), and each artist's zorder is then set relative to its own
# `stable_zorder` attribute instead of the shared zorder_offset.
# Every collection and patch must therefore define `stable_zorder`, otherwise
# an AttributeError is raised here.
for i, col in enumerate(
        sorted(self.collections,
               key=lambda col: col.do_3d_projection(renderer),
               reverse=True)):
    # stock matplotlib line (disabled): col.zorder = zorder_offset + i
    col.zorder = col.stable_zorder + i
for i, patch in enumerate(
        sorted(self.patches,
               key=lambda patch: patch.do_3d_projection(renderer),
               reverse=True)):
    # stock matplotlib line (disabled): patch.zorder = zorder_offset + i
    patch.zorder = patch.stable_zorder + i
It's assumed that every object in ax.collections and ax.patches has a stable_zorder attribute, which is assigned manually in my project. So every time I run my project, I must make sure that the mpl_toolkits.mplot3d.Axes3D.draw method has been changed manually (outside my project). How can I avoid this manual change and override this method from inside my project instead?
This is MWE of my project:
import matplotlib.pyplot as plt
import numpy as np
#from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.art3d as art3d
from matplotlib.text import TextPath
from matplotlib.transforms import Affine2D
from matplotlib.patches import PathPatch
class VisualArray:
    """Render a NumPy array (up to 3-D) as a block of voxels with value labels.

    Every element becomes one unit voxel, and the values on the six outer
    faces are written onto those faces as flat LaTeX text patches. Each label
    and voxel collection receives a ``stable_zorder`` attribute; it only
    affects drawing order when ``Axes3D.draw`` has been patched to read it.
    """

    # Extra rotation applied to labels on the "dark" (negative-direction)
    # sides, keyed by zdir; composed with the in-plane rotation in text3d.
    _DARK_SIDE_FLIPS = {
        '-x': np.array([[-1.0, 0.0, 0], [0.0, 1.0, 0.0], [0, 0.0, -1]]),
        '-y': np.array([[0.0, 1.0, 0], [-1.0, 0.0, 0.0], [0, 0.0, 1]]),
        '-z': np.array([[1.0, 0.0, 0], [0.0, -1.0, 0.0], [0, 0.0, -1]]),
    }
    # Magnitude of the stable_zorder used to push labels in front of / behind
    # the voxel collections.
    _LABEL_Z = 10000

    def __init__(self, arr, fig=None, ax=None):
        """Store *arr* as a 3-D array and set up the figure and 3-D axes.

        Parameters
        ----------
        arr : array_like
            Data with at most 3 dimensions. Lower-dimensional input is
            promoted by prepending length-1 axes, e.g. ``(n,) -> (1, 1, n)``
            and ``(r, c) -> (1, r, c)``; a scalar becomes ``(1, 1, 1)``.
        fig, ax : optional
            Existing figure / 3-D axes to draw into; created when omitted.

        Raises
        ------
        NotImplementedError
            If *arr* has more than 3 dimensions.
        """
        arr = np.asarray(arr)
        if arr.ndim > 3:
            raise NotImplementedError('More than 3 dimensions is not supported')
        if arr.ndim < 3:
            arr = arr.reshape((1,) * (3 - arr.ndim) + arr.shape)
        self.arr = arr
        self.fig = plt.figure() if fig is None else fig
        self.ax = self._get_3d_axes() if ax is None else ax
        self.ax.azim, self.ax.elev = -120, 30
        self.colors = None

    def _get_3d_axes(self):
        """Return the figure's current 3-D axes, creating one if there is none."""
        # Figure.gca(projection=...) no longer accepts keyword arguments
        # (removed in matplotlib 3.6), so reuse or create the axes explicitly.
        if self.fig.axes and getattr(self.fig.gca(), 'name', None) == '3d':
            return self.fig.gca()
        return self.fig.add_subplot(111, projection='3d')

    def text3d(self, xyz, s, zdir="z", zorder=1, size=None, angle=0, usetex=False, **kwargs):
        """Add text *s* as a flat PathPatch at *xyz*, lying in the plane normal to *zdir*.

        *angle* is an in-plane rotation in radians; extra keyword arguments go
        to PathPatch. The returned patch gets ``stable_zorder = zorder``.
        """
        x, y, z = xyz
        # Permute so (x, y) are the in-plane coordinates and z is the offset
        # along zdir, as expected by art3d.pathpatch_2d_to_3d.
        if "y" in zdir:
            x, y, z = x, z, y
        elif "x" in zdir:
            x, y, z = y, z, x
        text_path = TextPath((-0.5, -0.5), s, size=size, usetex=usetex)
        trans = Affine2D().rotate(angle)
        # apply additional rotation of text_paths if side is dark
        if '-' in zdir:
            # Public matrix API instead of writing the private ``_mtx``:
            # set_matrix() also invalidates the transform's cached state.
            trans.set_matrix(np.dot(self._DARK_SIDE_FLIPS[zdir], trans.get_matrix()))
        trans = trans.translate(x, y)
        p = PathPatch(trans.transform_path(text_path), **kwargs)
        self.ax.add_patch(p)
        art3d.pathpatch_2d_to_3d(p, z=z, zdir=zdir)
        p.stable_zorder = zorder
        return p

    def on_rotation(self, event):
        """Refresh the labels' stable_zorder for the current view angles.

        Connected to 'motion_notify_event'; *event* itself is not used.
        Sides 1 and 4 swap between +/-10000 with the sign of the elevation;
        sides 3, 2, 6 and 5 take +/-10000 depending on which azimuth band
        (split at -90, 0 and 90 degrees) the view is in.
        """
        elev_band = 0 if self.ax.elev > 0 else 1
        v_zorders = self._LABEL_Z * np.array([(1, -1), (-1, 1)])[elev_band]
        self._set_stable_zorders((self.side1, self.side4), v_zorders)
        azim = self.ax.azim
        azim_band = [azim < -90, azim < 0, azim < 90, True].index(True)
        h_zorders = self._LABEL_Z * np.array([(1, 1, -1, -1), (-1, 1, 1, -1),
                                              (-1, -1, 1, 1), (1, -1, -1, 1)])[azim_band]
        self._set_stable_zorders((self.side3, self.side2, self.side6, self.side5), h_zorders)

    @staticmethod
    def _set_stable_zorders(sides, zorders):
        """Assign ``zorders[k]`` as the stable_zorder of every patch in ``sides[k]``."""
        for side, zorder in zip(sides, zorders):
            for patch in side:
                patch.stable_zorder = zorder

    def voxelize(self):
        """Draw one unit voxel per array element; record each collection's zorder as stable_zorder."""
        # voxels() takes the grid in (x, y, z) order, hence the reversed shape;
        # every cell is filled.
        filled = np.ones(self.arr.shape[::-1], dtype=bool)
        self.ax.voxels(filled, facecolors=self.colors, edgecolor='k')
        for col in self.ax.collections:
            col.stable_zorder = col.zorder

    def _label_face(self, positions, values, zdir, zorder):
        """Add one ``$value$`` text patch per position on a face; return the patches in order."""
        return [self.text3d(xyz, f'${n}$', zdir=zdir, zorder=zorder, size=1,
                            usetex=True, ec="none", fc="k")
                for xyz, n in zip(positions, values)]

    def labelize(self):
        """Write the array values onto the six outer faces and track view rotation."""
        s = self.arr.shape
        # labelling surfaces of side1 and side4
        surf = np.indices((s[2], s[1])).T[::-1].reshape(-1, 2) + 0.5
        self.side1 = self._label_face(np.insert(surf, 2, s[0], axis=1),
                                      self.arr[0].flatten(), "z", self._LABEL_Z)
        self.side4 = self._label_face(np.insert(surf, 2, 0, axis=1),
                                      self.arr[-1].flatten(), "-z", -self._LABEL_Z)
        # labelling surfaces of side2 and side5
        surf = np.indices((s[2], s[0])).T[::-1].reshape(-1, 2) + 0.5
        self.side2 = self._label_face(np.insert(surf, 1, 0, axis=1),
                                      self.arr[:, -1].flatten(), "y", self._LABEL_Z)
        surf = np.indices((s[0], s[2])).T[::-1].reshape(-1, 2) + 0.5
        self.side5 = self._label_face(np.insert(surf, 1, s[1], axis=1),
                                      self.arr[::-1, 0].T[::-1].flatten(), "-y", -self._LABEL_Z)
        # labelling surfaces of side3 and side6
        surf = np.indices((s[1], s[0])).T[::-1].reshape(-1, 2) + 0.5
        self.side6 = self._label_face(np.insert(surf, 0, s[2], axis=1),
                                      self.arr[:, ::-1, -1].flatten(), "x", -self._LABEL_Z)
        self.side3 = self._label_face(np.insert(surf, 0, 0, axis=1),
                                      self.arr[:, ::-1, 0].flatten(), "-x", self._LABEL_Z)
        # Connect only once all six side lists exist, since on_rotation reads them.
        self.fig.canvas.mpl_connect('motion_notify_event', self.on_rotation)

    def vizualize(self):
        """Draw voxels and labels, then hide the axis decorations of self.ax."""
        self.voxelize()
        self.labelize()
        # plt.axis('off') acts on pyplot's *current* axes, which need not be
        # self.ax when fig/ax were passed in; target our axes explicitly.
        self.ax.set_axis_off()
# Demo: a 2 x 6 x 5 block holding 0..59, drawn as labelled voxels.
arr = np.arange(2 * 6 * 5).reshape(2, 6, 5)
va = VisualArray(arr)
va.vizualize()
plt.show()
This is the output I get after manually editing the ...\mpl_toolkits\mplot3d\axes3d.py file:
This is the (unwanted) output I get if no change is made:

What you want to achieve is called Monkey Patching.
It has its downsides and has to be used with some care (there is plenty of information available under this keyword). But one option could look something like this:
from matplotlib import artist
from mpl_toolkits.mplot3d import Axes3D
# Create a new draw function.
# The decorator must actually be applied (in the pasted code it had lost its
# "@" and become a comment). matplotlib wraps Artist.draw implementations with
# allow_rasterization, which does rasterization/filter setup around the draw.
# NOTE(review): matplotlib >= 3.5 offers Axes3D.computed_zorder; setting it to
# False keeps each artist's own zorder and may make this patch unnecessary.
@artist.allow_rasterization
def draw(self, renderer):
    """Replacement for mpl_toolkits.mplot3d.Axes3D.draw (installed below)."""
    # Your version
    # ...
    # This function is defined outside the class body, so zero-argument
    # super() cannot work here; Axes3D has to be named explicitly.
    super(Axes3D, self).draw(renderer)


# Overwrite the old draw function
Axes3D.draw = draw
# The rest of your code
# ...
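There are two caveats: `artist` must be imported for the `@artist.allow_rasterization` decorator, and the parent method must be called explicitly as super(Axes3D, self).method() instead of just super().method(), because zero-argument super() only works in functions defined inside a class body.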
Depending on your use case and to stay compatible with the rest of your code you could also save the original draw function and use the custom only temporarily:
def draw_custom(self, renderer):
    """Temporary replacement for Axes3D.draw.

    matplotlib calls the installed method as Axes3D.draw(self, renderer), so
    the replacement must accept those arguments; a zero-argument function
    would fail with a TypeError as soon as a figure is drawn.
    """
    ...


draw_org = Axes3D.draw
Axes3D.draw = draw_custom
try:
    # Do custom stuff.
    # NOTE: figures are typically rendered lazily (plt.show(), savefig(), ...),
    # so the drawing you want customised must happen inside this window.
    pass
finally:
    # Always restore the original, even if the custom stuff raises.
    Axes3D.draw = draw_org
# Do normal stuff

Related

Real-time plotting of a custom turning marker

Is it somehow possible to plot a custom marker (like this) interactively, but have it turn in real-time? It seems that the scatter graph does not grant any access to the markers.
You can create a custom marker with a FancyArrowPatch. Many styles and options are possible. Such a patch is not easy to update, but you could just remove the patch and create it again to create an animation.
The easiest way to create an animation is via plt.pause(), but that doesn't work in all environments. Another way is via FuncAnimation, which involves a few more lines, but makes controlling the animation easier.
Here is some example code to show the concepts:
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib.collections import PatchCollection
from matplotlib import animation
import numpy as np
fig, ax = plt.subplots()
# Fixed, equal-aspect view that all arrows move around in.
ax.set_xlim(-40, 40)
ax.set_ylim(-30, 30)
ax.set_aspect('equal')

N = 50
# Random start positions and unit-length headings for the N arrows.
x = np.random.uniform(-20, 20, (N, 2))
dx = np.random.uniform(-1, 1, (N, 2))
dx /= np.linalg.norm(dx, axis=1, keepdims=True)
colors = plt.cm.magma(np.random.uniform(0, 1, N))
arrow_style = "Simple,head_length=2,head_width=3,tail_width=1"
# Collection drawn in the previous frame; removed before drawing the next one.
old_arrows = None
def animate(i):
    """Advance every arrow one step and redraw them as a fresh PatchCollection.

    FancyArrowPatch objects are awkward to update in place, so the previous
    frame's collection is removed and a new one is added. Returns a 1-tuple
    with the new collection, as FuncAnimation's blitting expects.
    """
    global old_arrows, x, dx
    if old_arrows is not None:
        old_arrows.remove()
    # Move, then randomly perturb each heading and renormalise it to unit length.
    x += dx
    dx += np.random.uniform(-.1, .1, (N, 2))
    dx /= np.linalg.norm(dx, axis=1, keepdims=True)
    arrows = []
    for (xi, yi), (dxi, dyi) in zip(x, dx):
        tip = (xi + dxi * 10, yi + dyi * 10)
        arrows.append(patches.FancyArrowPatch((xi, yi), tip, arrowstyle=arrow_style))
    old_arrows = ax.add_collection(PatchCollection(arrows, facecolors=colors))
    return (old_arrows,)
# Play frames 1..199, 25 ms apart, with blitting and no looping.
ani = animation.FuncAnimation(
    fig,
    animate,
    frames=np.arange(1, 200),
    interval=25,
    repeat=False,
    blit=True,
)
plt.show()
I solved it using remove() and function attributes (which act like static variables), like this:
class pltMarker:
    """Scatter marker built from an SVG path that can be rotated in place.

    NOTE(review): relies on ``parse_path`` (presumably svgpath2mpl) and ``mpl``
    (matplotlib) being in scope; neither is imported in this snippet.
    """

    def __init__(self, angle=None, pathString=None):
        # Default to 0 degrees: the previous fallback ([] or None) cannot be
        # passed to rotate_deg, so omitting *angle* raised a TypeError.
        self.angle = angle or 0
        self.pathString = pathString or """simply make and svg, open in a text editor and copy the path XML string in here"""
        self.path = parse_path(self.pathString)
        # Centre the vertices on the origin so rotations spin the marker in place.
        self.path.vertices -= self.path.vertices.mean(axis=0)
        self.marker = mpl.markers.MarkerStyle(marker=self.path)
        # NOTE: _transform is private matplotlib API (newer releases offer
        # MarkerStyle.rotated()); kept for compatibility with older versions.
        self.marker._transform = self.marker.get_transform().rotate_deg(self.angle)

    def rotate(self, angle=0):
        """Rotate the marker by *angle* degrees on top of its current rotation."""
        self.marker._transform = self.marker.get_transform().rotate_deg(angle)
def animate(k):
    """Rotate the shared marker and replace last frame's scatter with a new one.

    The previous scatter artist lives on the function itself
    (``animate.Scatter``) so it can be removed on the next call.
    """
    angle = ...  # new angle
    myPltMarker.rotate(angle)
    animate.Scatter.remove()
    new_scatter = plt.scatter(1, 0, marker=myPltMarker.marker, s=100)
    animate.Scatter = new_scatter
    return (new_scatter,)
angle = ...
myPltMarker = pltMarker(angle=angle)
# The initial scatter must be stored on ``animate`` itself: animate() reads and
# replaces animate.Scatter (this used to write to the undefined ``animatePlt``).
animate.Scatter = plt.scatter(1, 0, marker=myPltMarker.marker, s=100)
# Animate the figure plt.scatter drew on; ``fig`` is not defined in this snippet.
anm = animation.FuncAnimation(plt.gcf(), animate, blit=False, interval=1)
plt.show()

AttributeError: 'Axes' object has no attribute 'get_proj' in matplotlib

I am making a representation of a polynomial function.
I get an error in my matplotlib code and cannot understand where it comes from; any advice is welcome.
I already tried the GTK3Agg backend, but nothing changed.
Below is the failing code.
For some reason, 'get_proj' doesn't work here for creating labels.
And when I use ax.get_proj() instead:
a) all labels appear at the bottom left;
b) not all labels appear at the bottom left (all points are identified by the cursor, but the labels are not written at the bottom left).
The final project will be (a few things are still to be done):
- on button -> a label with the coordinates appears at each cursor movement (temporary)
- click the right button -> the labels stay until the clear button is clicked
- off button -> no labelling appears
My feeling: creating the three buttons is messing everything up.
# -*- coding: utf-8 -*-
import matplotlib as mpl
from mpl_toolkits.mplot3d.proj3d import proj_transform
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
import numpy as np
# Select the Tk ("TkAgg") backend.
# NOTE(review): this runs after matplotlib.pyplot has already been imported
# above; depending on the matplotlib version a late backend switch may warn or
# have no effect -- consider calling it before importing pyplot. Confirm.
mpl.use('tkagg')
def distance(point, event):
    """Return the on-screen (pixel) distance between a 3-D data point and a mouse event.

    The point is projected with the 3-D axes' projection matrix, mapped to
    display coordinates, and compared with event.x / event.y.
    """
    # Use the 3-D axes directly. plt.gca() returns pyplot's *current* axes,
    # which is the last Button axes created -- a plain 2-D Axes without
    # get_proj(), hence "'Axes' object has no attribute 'get_proj'".
    x2, y2, _ = proj_transform(point[0], point[1], point[2], ax.get_proj())
    x3, y3 = ax.transData.transform((x2, y2))
    return np.hypot(x3 - event.x, y3 - event.y)


def calcClosestDatapoint(X, event):
    """Return the index of the row of X whose point is closest to the cursor on screen."""
    # Iterate over X itself rather than range(Sol), so the result cannot get
    # out of sync with the array that is passed in.
    distances = [distance(point, event) for point in X[:, 0:3]]
    return np.argmin(distances)


def _annotate_point(text, X, index):
    """Attach a yellow callout with *text* to the projected position of X[index]; return it."""
    x2, y2, _ = proj_transform(X[index, 0], X[index, 1], X[index, 2], ax.get_proj())
    # Annotate the 3-D axes itself: plt.annotate() targets pyplot's current
    # axes (a Button axes here), which put the labels at the bottom left.
    return ax.annotate(text,
                       xy=(x2, y2), xytext=(-20, 20), textcoords='offset points',
                       ha='right', va='bottom',
                       bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                       arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))


def annotatePlot(X, index):
    """If labelling is active, label point X[index] and keep it in the global last_mark."""
    global last_mark
    if activated_labelling:
        last_mark = _annotate_point(generated_labels[index], X, index)
        # draw_idle() coalesces the flood of motion events into one repaint
        # instead of redrawing the whole figure synchronously on every move.
        fig.canvas.draw_idle()


def onMouseMotion(event):
    """Motion handler: move the label to the data point nearest the mouse cursor."""
    if not activated_labelling or len(Coord) == 0:
        return
    closestIndex = calcClosestDatapoint(Coord, event)
    last_mark.remove()
    annotatePlot(Coord, closestIndex)


def show_on(event):
    """'on' button callback: show a first label and start following the mouse."""
    global activated_labelling, last_mark, mid
    # Nothing to do if labelling is already active or the scan found no points.
    if activated_labelling or len(Coord) == 0:
        return
    activated_labelling = True
    last_mark = _annotate_point("3D measurement on " + generated_labels[0], Coord, 0)
    mid = fig.canvas.mpl_connect('motion_notify_event', onMouseMotion)
    fig.canvas.draw_idle()


def show_off(event):
    """'off' button callback: remove the current label and stop following the mouse."""
    global activated_labelling
    if activated_labelling:
        activated_labelling = False
        last_mark.remove()
        fig.canvas.draw_idle()
        fig.canvas.mpl_disconnect(mid)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
activated_labelling = False
Wide = 100
Minimum = -50
ScanLimit = 3  # searching between 0 and 3; 4 and 5 are no solutions
Search = 45
Coord = []
values = []
generated_labels = []
# bounding box of the solutions, used for centering the window
# (it starts at the origin, so the origin is always inside the box)
XMin = XMax = 0
YMin = YMax = 0
ZMin = ZMax = 0
# count the solutions found in the scan area defined above
Sol = 0
# Integer coordinates [Minimum, Minimum + Wide] along each axis, shaped as open
# (broadcastable) grids: evaluating the polynomial on them fills the whole cube
# at once instead of running ~(Wide + 1)**3 Python-level loop iterations.
span = np.arange(Minimum, Minimum + Wide + 1, dtype='int64')
gx = span[:, None, None]
gy = span[None, :, None]
gz = span[None, None, :]
########################################################################
####
#### THIS IS THE POLYNOM TO BE REPRESENTED (gx, gy, gz stand for X, Y, Z)
####
param_grid = gx**3 + gy**3 + gz**3 - Search
########################################################################
hit_mask = np.abs(param_grid) <= abs(ScanLimit)
# np.argwhere and boolean indexing both walk the cube in C order, i.e. the same
# X-outer / Z-inner order as nested i/j/k loops, so the printed output and the
# order of Coord match the loop-based version.
hit_points = (np.argwhere(hit_mask) + Minimum).tolist()
hit_values = param_grid[hit_mask].tolist()
for (x, y, z), param_dens in zip(hit_points, hit_values):
    Coord.append([x, y, z])
    if ScanLimit != 0:
        values.append([abs(param_dens)])
    labelling = "value {}\nin X:{} Y:{} Z:{}".format(Search + param_dens, x, y, z)
    generated_labels.append(labelling)
    print(labelling + "\n")
    # increase the number indicating the solutions found
    Sol += 1
    # grow the bounding box used for centering the window
    XMin, XMax = min(XMin, x), max(XMax, x)
    YMin, YMax = min(YMin, y), max(YMax, y)
    ZMin, ZMax = min(ZMin, z), max(ZMax, z)
print('######################################################')
print('## statistics / move this to a parallel search engine?')
print('## search ')
print("## total solution %d for searching center %d" % (Sol, Search))
print("## from %d to %d" % (Search - ScanLimit, Search + ScanLimit))
print("## from %d to %d" % (Minimum, Wide + Minimum))
print('##')
print('#######################################################')
#
values = np.array(values, dtype='int64')
Coord = np.array(Coord, dtype='int64')
#
if ScanLimit != 0:
    cmap = plt.cm.jet  # define the colormap
    # extract all colors from the .jet map
    cmaplist = [cmap(i) for i in range(cmap.N)]
    # force the first color entry to be black
    cmaplist[0] = (0, 0, 0, 1.0)
    # create the new map
    cmap = mpl.colors.LinearSegmentedColormap.from_list('Custom cmap', cmaplist, cmap.N)
    # define the bins and normalize
    bounds = np.linspace(0, ScanLimit, ScanLimit + 1)
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    # create a second axes for the colorbar
    ax2 = fig.add_axes([0.95, 0.1, 0.03, 0.8])
    cb = mpl.colorbar.ColorbarBase(ax2, cmap=cmap, norm=norm,
                                   spacing='proportional', ticks=bounds,
                                   boundaries=bounds, format='%1i')
#
ax.set_xlim3d(XMin - 5, XMax + 5)
ax.set_ylim3d(YMin - 5, YMax + 5)
ax.set_zlim3d(ZMin - 5, ZMax + 5)
#
ax.set_xlabel('X X')
ax.set_ylabel('Y Y')
ax.set_zlabel('Z Z')
# A numeric aspect (set_aspect(aspect=1)) is rejected on 3-D axes by current
# matplotlib; 'equal' is the supported spelling there.
try:
    ax.set_aspect('equal')
except NotImplementedError:
    # matplotlib 3.1-3.5 refuse any aspect other than 'auto' on 3-D axes;
    # keep the default aspect there rather than aborting the whole script.
    pass
# extract the scatterplot drawing in a separate function so we can re-use the code
def draw_scatterplot():
    """Scatter all solution points (Coord) onto the 3-D axes.

    With a non-zero ScanLimit the points are coloured by ``values`` (the
    absolute distance from the search value) through the custom colormap;
    otherwise they are drawn plain green.
    """
    xs, ys, zs = Coord[:, 0], Coord[:, 1], Coord[:, 2]
    if ScanLimit == 0:
        ax.scatter3D(xs, ys, zs, s=20, c='green')
        return
    ax.scatter3D(xs, ys, zs, s=20, c=values[:, 0], cmap=cmap, norm=norm)
# draw the initial scatterplot
draw_scatterplot()
# "on" button in the bottom-left corner, wired to show_on.
# Note: each plt.axes() call also makes the new axes pyplot's current axes.
ax_on = plt.axes([0.0, 0.0, 0.1, 0.05])
button_on = Button(ax_on, 'on')
button_on.on_clicked(show_on)
# "off" button just to its right, wired to show_off
ax_off = plt.axes([0.12, 0.0, 0.1, 0.05])
button_off = Button(ax_off, 'off')
button_off.on_clicked(show_off)
plt.show()
Traceback (most recent call last):
File "C:\Program Files\Anaconda3\lib\site-packages\matplotlib\cbook\__init__.py", line 388, in process
proxy(*args, **kwargs)
File "C:\Program Files\Anaconda3\lib\site-packages\matplotlib\cbook\__init__.py", line 228, in __call__
return mtd(*args, **kwargs)
File "C:/Users/../Desktop/heat.py", line 137, in onClick
closestIndex,LowestDistance = calcClosestDatapoint(Coord, event)
File "C:/Users/../Desktop/heat.py", line 50, in calcClosestDatapoint
distances = [distance(X[i, 0:3], event) for i in range(Sol)]
File "C:/Users/../Desktop/heat.py", line 50, in <listcomp>
distances = [distance(X[i, 0:3], event) for i in range(Sol)]
File "C:/Users/../Desktop/heat.py", line 35, in distance
x2, y2, _ = proj_transform(point[0], point[1], point[2], plt.gca().get_proj())
AttributeError: 'Axes' object has no attribute 'get_proj'

Add colorbar to python 3D/2D quiver plot

I am very confused about how to add a color bar to my 3D/2D vector fields.
My source code is below:
import copy
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import fieldmap_utils
# Per-position measurement records, keyed by (x, y, z).
data = {}
# Template record for one position: amplitude and phase per axis, all zero.
dot = dict.fromkeys(("ampX", "ampY", "ampZ", "phaseX", "phaseY", "phaseZ"), 0)
# One (initially empty) coordinate list per axis; not used further in this script.
location = {axis: [] for axis in ("X", "Y", "Z")}
def buildFrame(origin, orientation, length):
    """Return a new list: *origin* followed by the items of ``np.dot(orientation, length)``.

    The result of np.dot is converted with ``tolist()`` before being appended;
    *origin* itself is copied, not modified.
    """
    frame = origin.copy()
    frame.extend(np.dot(orientation, length).tolist())
    return frame
dic = fieldmap_utils.load_single_ampphase("EM averaged.csv")
lst = sorted(dic.index.get_level_values('freq (hz)').unique())


def _print_inline(items):
    """Print *items* on one line, each followed by a space, then end the line."""
    for item in items:
        print(item, end=" ")
    print()


print("List of available frequencies: ")
_print_inline(lst)
freq = input("Enter selected frequency: ")
dic = dic[dic.index.get_level_values('freq (hz)') == int(freq)]
# Measurement name -> (amplitude key, phase key) in the per-position record.
_TARGET_KEYS = {
    "TARGET-X": ("ampX", "phaseX"),
    "TARGET-Y": ("ampY", "phaseY"),
    "TARGET-Z": ("ampZ", "phaseZ"),
}
# Single pass over the rows: create the record for a position the first time
# it is seen (a deep copy of the `dot` template), then fill in this row's values.
# Looks like this:
# data[(x,y,z)] = {"ampX":0,,"ampY":0,"ampZ":0,"phaseX":0,"phaseY":0,"phaseZ":0}
for index, row in dic.iterrows():
    key = (index[4][0], index[4][1], index[4][2])
    if key not in data:
        data[key] = copy.deepcopy(dot)
    target = _TARGET_KEYS.get(index[7])
    if target is not None:
        amp_key, phase_key = target
        data[key][amp_key] = row['ampl']
        data[key][phase_key] = row['phase_deg']
# matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains.
colorMap = plt.get_cmap('Greys')  # currently unused (see the cmap note below)
fig = plt.figure(figsize=(12, 12))
ax3D = fig.add_subplot(2, 2, 1, projection='3d')
axX = fig.add_subplot(2, 2, 2)
axY = fig.add_subplot(2, 2, 3)
axZ = fig.add_subplot(2, 2, 4)
positions = list(data.keys())
xlst = sorted({p[0] for p in positions})
ylst = sorted({p[1] for p in positions})
zlst = sorted({p[2] for p in positions})
print("Unique X coordinates:")
_print_inline(xlst)
sliceX = int(input("Enter X slice position: "))
print("Unique Y coordinates:")
_print_inline(ylst)  # previously listed the X coordinates here by mistake
sliceY = int(input("Enter Y slice position: "))
print("Unique Z coordinates:")
_print_inline(zlst)
sliceZ = int(input("Enter Z slice position: "))
scale = 500
for position in positions:
    x, y, z = position
    record = data[position]
    u, v, w = record["ampX"] * scale, record["ampY"] * scale, record["ampZ"] * scale
    # orientation = [[data[position]["ampX"], 0, 0],
    #                [0, data[position]["ampY"], 0],
    #                [0, 0, data[position]["ampZ"]]]
    # x, y, z, u, v, w = buildFrame([position[0], position[1], position[2]], orientation, 5000)
    ax3D.quiver(x, y, z, u, v, w)
    # 2-D views: only the arrows lying in the selected slice planes
    if x == sliceX:
        axX.quiver(y, z, v, w)
    if y == sliceY:
        axY.quiver(x, z, u, w)
    if z == sliceZ:
        axZ.quiver(x, y, u, v)
    #, cmap = colorMap
ax3D.view_init(azim=50, elev=25)
ax3D.set_xlabel('X')
ax3D.set_ylabel('Y')
ax3D.set_zlabel('Z')
ax3D.set_xlim([-275, 300])
ax3D.set_ylim([-275, 450])
ax3D.set_zlim([0, 500])
axX.set_title('Looking from X-axis')
axX.set_xlabel('Y')
axX.set_ylabel('Z')
axX.set_xlim([-275, 425])
axX.set_ylim([0, 500])
axY.set_title('Looking from Y-axis')
axY.set_xlabel('X')
axY.set_ylabel('Z')
axY.set_xlim([-275, 300])
axY.set_ylim([0, 500])
axZ.set_title('Looking from Z-axis')
axZ.set_xlabel('X')
axZ.set_ylabel('Y')
axZ.set_xlim([-275, 300])
axZ.set_ylim([-275, 450])
# plt.savefig('demo', dpi = 1200)
plt.show()
The code isn't optimized or perfect; it is only used as a demo.
However, I am really confused about how I should add one color bar for the three 2D quiver plots.
I have read through some of the matplotlib documentation, but I still don't have a clear idea of how to add a color bar. I tried the same method that I used for a scatter plot, but it didn't work either.
Thank you for your help!

Matplotlib using image for points on plot

I have the following matplotlib script. I want to replace the points on the plot with images. Let's say 'red.png' for the red points and 'blue.png' for the blue points. How can I adjust the following to plot these images instead of the default points?
from scipy import linalg
import numpy as np
import pylab as pl
import matplotlib as mpl
import matplotlib.image as image
from sklearn.qda import QDA
###############################################################################
# load sample dataset
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data[:, 0:2]  # Take only 2 dimensions
y = iris.target
# Keep only classes 1 and 2 (drop class 0) and relabel them as 0/1.
X = X[y > 0]
y = y[y > 0]
y -= 1
target_names = iris.target_names[1:]
###############################################################################
# QDA
qda = QDA()
# fit(X, y) works across scikit-learn versions; fit()'s old store_covariances
# flag is gone in newer releases, and nothing below uses the covariances
# (plot_ellipse is never called).
y_pred = qda.fit(X, y).predict(X)
###############################################################################
# Plot results
# constants
dpi = 72
imageSize = (32, 32)
# read in our png files (loaded, but not drawn anywhere below yet)
im_red = image.imread('red.png')
im_blue = image.imread('blue.png')
def plot_ellipse(splot, mean, cov, color):
    """Draw a filled, half-transparent covariance ellipse on the axes *splot*.

    The ellipse is centred on *mean*. Its full axis lengths are
    2 * sqrt(eigenvalue) of *cov* (i.e. one standard deviation either side of
    the mean along each principal axis), and it is oriented along the first
    eigenvector.
    """
    v, w = linalg.eigh(cov)
    # eigh returns the eigenvectors as *columns* of w; the row w[0] only
    # matches the first eigenvector when w happens to be symmetric.
    u = w[:, 0] / linalg.norm(w[:, 0])
    # arctan2 also handles a vertical eigenvector (u[0] == 0) without dividing by zero.
    angle = np.degrees(np.arctan2(u[1], u[0]))
    # filled gaussian ellipse (angle by keyword: positional use is deprecated)
    ell = mpl.patches.Ellipse(mean, 2 * v[0] ** 0.5, 2 * v[1] ** 0.5,
                              angle=180 + angle, color=color)
    ell.set_clip_box(splot.bbox)
    ell.set_alpha(0.5)
    splot.add_artist(ell)
# Class-1 posterior probability on a 200 x 200 grid covering the data.
grid_x = np.linspace(4, 8.5, 200)
grid_y = np.linspace(1.5, 4.5, 200)
xx, yy = np.meshgrid(grid_x, grid_y)
X_grid = np.c_[xx.ravel(), yy.ravel()]
zz_qda = qda.predict_proba(X_grid)[:, 1].reshape(xx.shape)

pl.figure()
splot = pl.subplot(1, 1, 1)
# Shaded decision regions, one scatter per class, and the 0.5 boundary line.
pl.contourf(xx, yy, zz_qda > 0.5, alpha=0.5)
for cls, colour in ((0, 'b'), (1, 'r')):
    members = X[y == cls]
    pl.scatter(members[:, 0], members[:, 1], c=colour, label=target_names[cls])
pl.contour(xx, yy, zz_qda, [0.5], linewidths=2., colors='k')
print(xx)
pl.axis('tight')
pl.show()
You can plot images instead of markers in a figure using BboxImage as in this tutorial.
from matplotlib import pyplot as plt
from matplotlib.image import BboxImage
from matplotlib.transforms import Bbox, TransformedBbox
# Load the two marker images.
redMarker = plt.imread('red.jpg')
blueMarker = plt.imread('blue.jpg')
# Data: x and y coordinates for each colour.
blueX, blueY = [1, 2, 3, 4], [1, 3, 5, 2]
redX, redY = [1, 2, 3, 4], [3, 2, 3, 4]
# Create a figure holding a single axes.
fig = plt.figure()
ax = fig.add_subplot(111)
# Plots an image at each x and y location.
def plotImage(xData, yData, im, size=1):
    """Draw image *im* as a marker at every (x, y) point on the global ``ax``.

    Each image covers a *size* x *size* square in data units, centred on its
    point (previously the image's lower-left corner sat on the point, so the
    "marker" was offset by half its size).
    """
    half = size / 2
    for x, y in zip(xData, yData):
        bb = Bbox.from_bounds(x - half, y - half, size, size)
        # Tie the box to data coordinates so it follows pans and zooms.
        bb2 = TransformedBbox(bb, ax.transData)
        bbox_image = BboxImage(bb2,
                               norm=None,
                               origin=None,
                               clip_on=False)
        bbox_image.set_data(im)
        ax.add_artist(bbox_image)
for xs, ys, marker_img in ((blueX, blueY, blueMarker), (redX, redY, redMarker)):
    plotImage(xs, ys, marker_img)
# Fix the view so that every image is visible.
ax.set_xlim(0, 6)
ax.set_ylim(0, 6)
plt.show()

Quickly and Efficiently update Matplotlib Axes (Plot)

I've run into some performance issues with a GUI that I'm working on. Specifically, I'm using wxPython as the backend for several matplotlib figures that I've embedded in the canvas. I've gotten the basic functionality working with simple self.axes.clear() and self.axes.plot() calls, but I can't get any sort of reasonable frame rate this way. After some searching, it seems that if I were using a plot (line) object, I could reset its xdata and ydata and then redraw the figure to get a faster refresh rate. Unfortunately, the layout that I'm using precludes the use of the plot object, so I've implemented the code using the axes() object. As far as I can tell, the axes() object does not have any equivalent methods for setting the xdata and ydata (see this post: How to update a plot in matplotlib?). Here's the code I'm using:
import sys,os,csv
import numpy as np
import wx
import matplotlib
import pylab
from matplotlib.figure import Figure
from matplotlib.backends.backend_wxagg import \
FigureCanvasWxAgg as FigCanvas, \
NavigationToolbar2WxAgg as NavigationToolbar
class UltrasoundDemoGUI(wx.Frame):
    """Frame with three embedded matplotlib axes that are refreshed on a timer.

    The Line2D artists are created once in create_main_panel; each timer tick
    only replaces their data (set_ydata / set_data) and redraws the canvas,
    which is much cheaper than clearing and re-plotting every axes.
    """
    title = ' Ultrasound Demo '
    # Timer period in milliseconds. 300 ms by itself caps the display at about
    # 3.3 updates per second; use e.g. 100 for a 10 Hz refresh.
    redraw_interval_ms = 300

    def __init__(self):
        wx.Frame.__init__(self, None, -1, self.title)
        self.create_menu()
        self.statusbar = self.CreateStatusBar()
        # create_main_panel also creates the data arrays and line artists.
        self.create_main_panel()
        self.redraw_timer = wx.Timer(self)
        self.Bind(wx.EVT_TIMER, self.on_redraw_timer, self.redraw_timer)
        self.redraw_timer.Start(self.redraw_interval_ms)

    def create_menu(self):
        """Build the File menu with 'Save plot' and 'Exit' entries."""
        self.menubar = wx.MenuBar()
        menu_file = wx.Menu()
        m_expt = menu_file.Append(-1, "&Save plot\tCtrl-S", "Save plot to file")
        self.Bind(wx.EVT_MENU, self.on_save_plot, m_expt)
        menu_file.AppendSeparator()
        m_exit = menu_file.Append(-1, "E&xit\tCtrl-X", "Exit")
        self.Bind(wx.EVT_MENU, self.on_exit, m_exit)
        self.menubar.Append(menu_file, "&File")
        self.SetMenuBar(self.menubar)

    def create_main_panel(self):
        """Create the canvas, the three axes and their line artists, and lay them out."""
        self.panel = wx.Panel(self)
        self.dpi = 100
        self.figSignals = Figure((12, 5.0), dpi=self.dpi)
        self.canvas = FigCanvas(self.panel, -1, self.figSignals)
        # Axes creation and styling live in init_plot (previously duplicated here).
        self.init_plot()
        self.dataMono = np.random.randn(100, 1)
        self.dataBi = np.random.randn(100, 1)
        self.dataLoc = np.random.randn(100, 1)
        # Keep a reference to each Line2D so draw_plot can update it in place;
        # the colours are the ones draw_plot used when it re-plotted every tick.
        self.plot_dataMono = self.axesMono.plot(self.dataMono, linewidth=1, color='red')[0]
        self.plot_dataBi = self.axesBi.plot(self.dataBi, linewidth=1, color='yellow')[0]
        self.plot_dataLoc = self.axesLoc.plot(self.dataLoc, linewidth=1, color='black')[0]
        self.toolbar = NavigationToolbar(self.canvas)
        self.vbox = wx.BoxSizer(wx.VERTICAL)
        self.vbox.Add(self.canvas, 1, wx.LEFT | wx.TOP | wx.GROW)
        self.vbox.AddSpacer(10)
        self.hbox = wx.BoxSizer(wx.HORIZONTAL)
        self.panel.SetSizer(self.vbox)
        self.vbox.Fit(self)

    def init_plot(self):
        """Add the Mono, Bi and Loc axes to self.figSignals and style them."""
        self.axesMono = self._add_styled_axes((.1, .55, .4, .4))
        self.axesBi = self._add_styled_axes((.1, .05, .4, .4))
        self.axesLoc = self._add_styled_axes((.55, .05, .4, .9))
        self.axesMono.set_title('Raw Ultrasound Signal', size=12)

    def _add_styled_axes(self, rect):
        """Add axes at *rect* (figure fractions) with a white background and 8 pt tick labels."""
        axes = self.figSignals.add_axes(rect)
        # set_axis_bgcolor() was removed in matplotlib 2.2; set_facecolor() replaces it.
        axes.set_facecolor('white')
        # Unlike setp() on the current tick labels, tick_params also applies to
        # ticks that are created later when the view is rescaled.
        axes.tick_params(labelsize=8)
        return axes

    def on_redraw_timer(self, event):
        self.draw_plot()

    def draw_plot(self):
        """Push fresh random data into the existing lines and redraw the canvas."""
        x = np.random.randn(100)
        z = np.random.randn(100)
        self.plot_dataMono.set_ydata(x)
        self.plot_dataBi.set_ydata(x)
        self.plot_dataLoc.set_data(x, z)
        for axes in (self.axesMono, self.axesBi, self.axesLoc):
            # Re-fit each view to the new data, as clear() + plot() used to.
            axes.relim()
            axes.autoscale_view()
        self.canvas.draw()

    def on_save_plot(self, event):
        """Ask for a PNG file name and save the figure there."""
        file_choices = "PNG (*.png)|*.png"
        dlg = wx.FileDialog(
            self,
            message="Save plot as...",
            defaultDir=os.getcwd(),
            defaultFile="plot.png",
            wildcard=file_choices,
            style=wx.FD_SAVE)
        try:
            if dlg.ShowModal() == wx.ID_OK:
                path = dlg.GetPath()
                self.canvas.print_figure(path, dpi=self.dpi)
                self.flash_status_message("Saved to %s" % path)
        finally:
            # Dialogs are not destroyed automatically when they go out of scope.
            dlg.Destroy()

    def flash_status_message(self, msg):
        """Show *msg* in the status bar (this method was called but never defined)."""
        self.statusbar.SetStatusText(msg)

    def on_exit(self, event):
        # Stop the redraw timer first so no further EVT_TIMER events are
        # delivered to the frame while it is being destroyed.
        self.redraw_timer.Stop()
        self.Destroy()
if __name__ == '__main__':
    # wx.PySimpleApp was removed in wxPython 4 (Phoenix); wx.App(False)
    # is the equivalent (no redirection of stdout/stderr).
    app = wx.App(False)
    app.frame = UltrasoundDemoGUI()
    app.frame.Show()
    app.frame.draw_plot()
    app.MainLoop()
    del app
I appear to be limited to a refresh rate of approximately 3 Hz. Ideally I'd like to visualize the data at a frame rate of 10 Hz or higher. Does anyone have an idea how I can efficiently (quickly) update the plots using the axes object?
Thanks for your help,
-B