I am using a package called shap, which has an integrated plot function. However, I want to adjust some things such as the labels, legend, coloring, and size.
According to the developer, that is possible via plt.gcf().
I call the plot like this. It should give me a figure object, but I am not sure how to use it:
fig = shap.summary_plot(shap_values_DT, data_train,color=plt.get_cmap("tab10"), show=False)
ax = plt.subplot()
In the end I got everything adjusted the way I wanted by doing the following:
shap.summary_plot(shap_values_DT, data_train, color=plt.get_cmap("tab10"), show=False)
fig = plt.gcf()
ax = plt.gca()
ax.set_xlabel(r'durchschnittliche SHAP Werte $\vert\sigma_{ij}\vert$', fontsize=16)
ax.set_ylabel('Inputparameter', fontsize=16)
ylabels = string_latexer([tick.get_text() for tick in ax.get_yticklabels()])
leg = ax.legend()
for l in leg.get_texts(): l.set_text(l.get_text().replace('Class', 'Klasse'))
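The same pattern works for any library that draws on the current matplotlib figure. Here is a minimal sketch without shap: library_plot is a hypothetical stand-in for a call such as shap.summary_plot(..., show=False), and the label texts are only illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np


def library_plot():
    """Hypothetical stand-in for a third-party function that draws on the current figure."""
    rng = np.random.default_rng(0)
    for k in range(3):
        plt.barh(np.arange(5), rng.random(5), left=k, label=f"Class {k}")
    plt.xlabel("original x label")
    plt.legend()


library_plot()   # the library draws, but does not show, the plot
fig = plt.gcf()  # grab the figure it drew on
ax = plt.gca()   # and that figure's current axes

# From here on it is ordinary matplotlib.
fig.set_size_inches(8, 5)
ax.set_xlabel("mean |SHAP value|", fontsize=16)
ax.set_ylabel("Input parameter", fontsize=16)

# Rename the existing legend entries in place.
for text in ax.get_legend().get_texts():
    text.set_text(text.get_text().replace("Class", "Klasse"))
```

Call plt.show() or fig.savefig(...) afterwards to display or save the adjusted figure.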

I have not used shap yet, but maybe you can modify in the following way:
shap.summary_plot(shap_values_DT, data_train,color=plt.get_cmap("tab10"), show=False)
plt.title('my custom title')
From the official documentation, I read
import xgboost
import shap
# load JS visualization code to notebook
# train XGBoost model
X,y =
model = xgboost.train({"learning_rate": 0.01}, xgboost.DMatrix(X, label=y), 100)
# explain the model's predictions using SHAP values
# (same syntax works for LightGBM, CatBoost, and scikit-learn models)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
# visualize the first prediction's explanation (use matplotlib=True to avoid Javascript)
shap.force_plot(explainer.expected_value, shap_values[0,:], X.iloc[0,:])
I quickly tried the example and it seems to work, if you add the matplotlib=True option. Nevertheless, not all functions seem to support it...


Matplotlib sliders on multiple figures

I am writing a Python tool that needs several figures open at the same time, each one with its own widgets (sliders, for the most part). I don't need any interactions across the figures here. Each figure is independent of the other ones, with its own plot and its own sliders affecting only itself.
I can get Matplotlib sliders working fine on a single figure, but I can't get them to work on multiple figures concurrently. Only the sliders of the LAST figure to open are working. The other ones are unresponsive.
I recreated my problem with the simple code below, starting from the example in the Matplotlib Slider documentation. If I run it as is, only the slider in the second figure (amplitude) works; the other one doesn't. If I swap the two function calls at the bottom, it's the other way around.
I've had no luck googling solutions or pointers. Any help would be much appreciated.
I'm on Python 3.9.12, btw. I can upload a requirements file if someone tries and cannot reproduce the issue. Thank you!
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
# The parametrized function to be plotted
def f(time, amplitude, frequency):
    return amplitude * np.sin(2 * np.pi * frequency * time)
# Define initial parameters
init_amplitude = 5
init_frequency = 3
t = np.linspace(0, 1, 1000)
def create_first_fig():
# Create the figure and the line that we will manipulate
fig1, ax1 = plt.subplots()
line1, = ax1.plot(t, f(t, init_amplitude, init_frequency), lw=2, color='b')
ax1.title.set_text('First plot - interactive frequency')
ax1.set_xlabel('Time [s]')
# adjust the main plot to make room for the sliders
fig1.subplots_adjust(left=0.25, bottom=0.25)
# Make a horizontal slider to control the frequency.
axfreq = fig1.add_axes([0.25, 0.1, 0.65, 0.03])
freq_slider = Slider(
label='Frequency [Hz]',
# register the update function with each slider
freq_slider.on_changed(lambda val: update_first_fig(val, fig1, line1))
return fig1
# The function to be called anytime a slider's value changes
def update_first_fig(val, fig, line):
line.set_ydata(f(t, init_amplitude, val))
def create_second_fig():
# Create the figure and the line that we will manipulate
fig2, ax2 = plt.subplots()
line2, = ax2.plot(t, f(t, init_amplitude, init_frequency), lw=2, color='r')
ax2.title.set_text('Second plot - interactive amplitude')
ax2.set_xlabel('Time [s]')
# adjust the main plot to make room for the sliders
fig2.subplots_adjust(left=0.25, bottom=0.25)
# Make a vertically oriented slider to control the amplitude
axamp = fig2.add_axes([0.1, 0.25, 0.0225, 0.63])
amp_slider = Slider(
# register the update function with each slider
amp_slider.on_changed(lambda val: update_second_fig(val, fig2, line2))
return fig2
# The function to be called anytime a slider's value changes
def update_second_fig(val, fig, line):
line.set_ydata(f(t, val, init_frequency))
figure1 = create_first_fig()
figure2 = create_second_fig()
I would expect the slider in both figures to work the way it does when I only open the corresponding figure. So far it's only the slider in the figure that's created last that works.
Edit in case someone else looks at this: see Yulia V's answer below. It works perfectly, including in my initial application. The site doesn't let me upvote it because I am too new on here, but it's a perfect solution to my problem. Thanks Yulia V!
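You need to keep references to the sliders in variables to make it work. Matplotlib widgets only stay responsive while something holds a reference to them; once the function that created a slider returns without passing it on, the slider can be garbage-collected and stops reacting.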
Specifically, in your functions, you need to have
return freq_slider, fig1
return amp_slider, fig2
instead of
return fig1
return fig2
and in the main script,
freq_slider, figure1 = create_first_fig()
amp_slider, figure2 = create_second_fig()
instead of
figure1 = create_first_fig()
figure2 = create_second_fig()
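For reference, here is a compact, self-contained sketch of the idea. The helper make_figure and the slider ranges are my own assumptions, not part of the question's code; the point is only that each slider is returned and stored by the caller.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider

t = np.linspace(0, 1, 1000)


def make_figure(title, label, init, valmax, wave):
    """Build one figure with one slider; wave(value) returns the new y data."""
    fig, ax = plt.subplots()
    line, = ax.plot(t, wave(init), lw=2)
    ax.set_title(title)
    fig.subplots_adjust(bottom=0.25)  # make room for the slider

    slider_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    slider = Slider(ax=slider_ax, label=label, valmin=0.1, valmax=valmax, valinit=init)

    def update(val):
        line.set_ydata(wave(val))
        fig.canvas.draw_idle()

    slider.on_changed(update)
    # Return the slider so the caller can keep it alive.
    return fig, line, slider


fig1, line1, freq_slider = make_figure(
    "Interactive frequency", "Frequency [Hz]", 3, 30,
    lambda freq: 5 * np.sin(2 * np.pi * freq * t))
fig2, line2, amp_slider = make_figure(
    "Interactive amplitude", "Amplitude", 5, 10,
    lambda amp: amp * np.sin(2 * np.pi * 3 * t))
```

Call plt.show() afterwards. As long as freq_slider and amp_slider stay in scope, both figures should respond.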
Just to illustrate my comment below @Yulia V's answer, it works too if we store the sliders as an attribute of the figure instead of returning them:
def create_first_fig():
    # ... (figure and slider setup as before) ...
    fig1._slider = freq_slider
    return fig1
def create_second_fig():
    # ... (figure and slider setup as before) ...
    fig2._slider = amp_slider
    return fig2
figure1 = create_first_fig()
figure2 = create_second_fig()

How to make an animation using PyPlot.jl with multiple axes?

I would like to create an animation of two axes. In the simple example illustrated below, I would like to plot two matrices using imshow as a function of time.
Coming from python, I would create an animation using matplotlib.animation similar to this:
using PyCall
@pyimport matplotlib.animation as anim
using PyPlot
import IJulia
A = randn(20,20,20,2)
fig, axes = PyPlot.subplots(nrows=1, ncols=2, figsize=(7, 2.5))
ax1, ax2 = axes
function make_frame(i)
    ax1.imshow(A[:,:,i+1, 1])
    ax2.imshow(A[:,:,i+1, 2])
end
withfig(fig) do
    myanim = anim.FuncAnimation(fig, make_frame, frames=size(A,3), interval=20, blit=false)
    myanim[:save]("test.mp4", bitrate=-1, extra_args=["-vcodec", "libx264", "-pix_fmt", "yuv420p"])
end
This however just creates a blank animation.
Do I need to use an init_func in the FuncAnimation? Do I need to enable blitting? Or can I update the artist using a set_data attribute?
Don't use IJulia for this routine. If you plan to include the routine in a notebook, just run the code and then view the file you create, without using withfig. withfig is causing your animation creation to abort for some reason, probably because it expects something within an IJulia environment to be set differently.
This works:
using PyCall
@pyimport matplotlib.animation as anim
using PyPlot
A = randn(20,20,20,2)
fig, axes = PyPlot.subplots(nrows=1, ncols=2, figsize=(7, 2.5))
ax1, ax2 = axes
function make_frame(i)
    ax1.imshow(A[:,:,i+1, 1])
    ax2.imshow(A[:,:,i+1, 2])
end
myanim = anim.FuncAnimation(fig, make_frame, frames=size(A,3), interval=20, blit=false)
myanim[:save]("test.mp4", bitrate=-1, extra_args=["-vcodec", "libx264", "-pix_fmt", "yuv420p"])
# now you can call your video viewer on "test.mp4"
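On the set_data part of the question: in Python's matplotlib (which PyPlot.jl wraps) you can create the two images once and only swap their data in each frame, instead of calling imshow again every time. A rough sketch in Python, under the assumption that a shared color range (vmin, vmax) is wanted; the same calls should be reachable from Julia through PyPlot.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

A = np.random.randn(20, 20, 20, 2)

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(7, 2.5))

# Create each image artist once; fix the color range so frames are comparable.
vmin, vmax = A.min(), A.max()
im1 = ax1.imshow(A[:, :, 0, 0], vmin=vmin, vmax=vmax)
im2 = ax2.imshow(A[:, :, 0, 1], vmin=vmin, vmax=vmax)


def make_frame(i):
    # Only the pixel data changes; the artists are reused.
    im1.set_data(A[:, :, i, 0])
    im2.set_data(A[:, :, i, 1])
    return im1, im2


myanim = FuncAnimation(fig, make_frame, frames=A.shape[2], interval=20, blit=False)
# myanim.save("test.mp4")  # needs ffmpeg on the PATH
```

No init_func is needed here because the artists already exist before the animation starts.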

Is there a way to make a plot clickable so it will tell me what EEG channel I am looking at?

Note: This is a question relating to mouse EEG data plotting.
I made a plot showing the averaged trial signals for filtered EEG electrode channels. While plotting this I saw a few signals that I want to exclude from my plot, but I don't have a way to tell what channels were plotted. Is there a way to add something that would allow me to click on or hover over one of the plotted lines/channels and have my jupyter notebook tell me what channel I clicked/am hovering over?
This is the plot I am hoping to make clickable:
Here is the code I used to make the plots if that's helpful:
pick_stim = 'opto'
pick_param = '500ms'
pick_sweep = 0
prex = .1 # .2ms before stim to plot
postx = .1 # .6ms after stim to plot
auc_window = [-.04, .1]
fig, axs = plt.subplots(1,2, figsize=(9,5), sharex=True, sharey=True, constrained_layout=True)
run_timex = trial_running[pick_stim][pick_param][pick_sweep][0]
run_trials = trial_running[pick_stim][pick_param][pick_sweep][1] #running speed
ztimex = zscore_traces[pick_stim][pick_param][pick_sweep][0] #need for AUC
zscore_trials_all = zscore_not_mean_traces[pick_stim][pick_param][pick_sweep][1]
# Run trials #
mean_run_zscore = np.mean(zscore_trials_all[:,:,run_trial], axis=2)
run_zscore_inds = np.nonzero((ztimex >= auc_window[0]) & (ztimex <= auc_window[1]))[0]
run_zscore_trace = mean_run_zscore[run_zscore_inds,:]
axs[0].plot(ztimex[run_zscore_inds],run_zscore_trace, color='black', linewidth=0.6, alpha=0.8)
#axs[0].plot(run_timex, run_trials, color='k', linewidth=0.6)
axs[0].axvspan(-.001, .001, color='r', alpha=0.5)
#axs[0].set_xlim([-prex, postx])
axs[0].set_title('Run trials')
# No Run #
mean_no_run_zscore = np.mean(zscore_trials_all[:,:,no_run], axis=2)
no_run_zscore_inds = np.nonzero((ztimex >= auc_window[0]) & (ztimex <= auc_window[1]))[0]
no_run_zscore_trace = mean_no_run_zscore[no_run_zscore_inds,:]
axs[1].plot(ztimex[no_run_zscore_inds],no_run_zscore_trace, color='black', linewidth=0.6, alpha=0.8)
axs[1].axvspan(-.001, .001, color='r', alpha=0.5)
axs[1].set_title('No Run trials')
You can add a label to each of the curves and then use mplcursors to show an annotation while hovering (or when clicking with hover=False).
Note that to have an interactive plot in a Jupyter notebook, %matplotlib notebook (this might depend on how Jupyter is installed) is needed instead of just %matplotlib inline (which generates static images). See Docs.
Here is an example showing the general idea with some test data:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
import mplcursors
x = np.arange(100)
y = np.random.randn(100, 20).cumsum(axis=0)
fig, ax = plt.subplots()
curves = plt.plot(x, y, color='black', alpha=0.3)
for ind, curv in enumerate(curves):
    curv.set_label(f'curve nº{ind}')
cursor = mplcursors.cursor(curves, hover=True)
cursor.connect('add', lambda sel: sel.annotation.set_text(sel.artist.get_label()))
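If installing mplcursors is not an option, matplotlib's built-in pick events can do something similar for clicks. This is a different technique from the one above, shown only as a sketch: the labels "channel N" and the choice to write the result into the axes title are my own assumptions.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(100)
y = np.random.randn(100, 20).cumsum(axis=0)

fig, ax = plt.subplots()
curves = ax.plot(x, y, color="black", alpha=0.3)
for ind, curve in enumerate(curves):
    curve.set_label(f"channel {ind}")
    curve.set_picker(5)  # clickable within 5 points of the line


def on_pick(event):
    """Show the label of the clicked line in the axes title."""
    ax.set_title(f"Selected: {event.artist.get_label()}")
    fig.canvas.draw_idle()


fig.canvas.mpl_connect("pick_event", on_pick)
```

As with mplcursors, this needs an interactive backend in the notebook; with static inline images no click events are delivered.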

How to add a legend for a GeoAxes that adds a Cartopy shapely feature?

I copied the code for adding legend via proxy artists from matplotlib's documentation but it doesn't work. I also tried the rest in matplotlib's legends guide but nothing works. I guess it's because the element is a shapely feature which ax.legend() somehow doesn't recognize.
bounds = [116.9283371, 126.90534668, 4.58693981, 21.07014084]
stamen_terrain = cimgt.Stamen('terrain-background')
fault_line = ShapelyFeature(Reader('faultLines.shp').geometries(), ccrs.epsg(32651),
linewidth=1, edgecolor='black', facecolor='none') # geometry is multilinestring
fig = plt.figure(figsize=(15,10))
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
ax.add_image(stamen_terrain, 8)
a = ax.add_feature(fault_line, zorder=1, label='test')
ax.legend([a], loc='lower left', fancybox=True) #plt.legend() has the same result
When copying the matplotlib example, you omitted the actual "proxy" artist line!
red_patch = mpatches.Patch(color='red', label='The red data')
That red_patch is the proxy artist. You have to create a dummy artist to pass to legend(). Your code as written is still passing the unrecognized Shapely feature.
It's tedious, but the relevant code would be something like:
fault_line = ShapelyFeature(Reader('faultLines.shp').geometries(), ccrs.epsg(32651), linewidth=1, edgecolor='black', facecolor='none')
ax.add_feature(fault_line, zorder=1)
# Now make a dummy object that looks as similar as possible
import matplotlib.patches as mpatches
proxy_artist = mpatches.Rectangle((0, 0), 1, 0.1, linewidth=1, edgecolor='black', facecolor='none')
# And manually add the labels here
ax.legend([proxy_artist], ['test'], loc='lower left', fancybox=True)
Here I just used a Rectangle, but depending on the feature, you can use various supported matplotlib "artists".
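Since the geometry in the question is a multilinestring, a Line2D proxy may look closer to the drawn feature than a rectangle. A minimal sketch in plain matplotlib, with an ordinary unlabeled line standing in for the Cartopy feature (the label text is illustrative):

```python
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

fig, ax = plt.subplots()

# Stand-in for the feature: something drawn without a usable legend handle.
ax.plot([0, 1, 2], [0, 1, 0.5], color="black", linewidth=1)

# Proxy artist: never added to the axes, only handed to legend().
proxy = Line2D([], [], color="black", linewidth=1)
legend = ax.legend([proxy], ["Fault lines"], loc="lower left", fancybox=True)
```

The proxy only has to mimic the style (color, line width) of what was drawn; its data can stay empty.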

matplotlib update figure in loop [duplicate]

I'm having issues with redrawing the figure here. I allow the user to specify the units in the time scale (x-axis) and then I recalculate and call this function plots(). I want the plot to simply update, not append another plot to the figure.
def plots():
global vlgaBuffSorted
result = collections.defaultdict(list)
for d in vlgaBuffSorted:
result_list = result.values()
f = Figure()
graph1 = f.add_subplot(211)
graph2 = f.add_subplot(212,sharex=graph1)
for item in result_list:
tL = []
vgsL = []
vdsL = []
isubL = []
for dict in item:
plotCanvas = FigureCanvasTkAgg(f, pltFrame)
toolbar = NavigationToolbar2TkAgg(plotCanvas, pltFrame)
You essentially have two options:
1. Do exactly what you're currently doing, but call graph1.clear() and graph2.clear() before replotting the data. This is the slowest, but also the simplest and most robust option.
2. Instead of replotting, just update the data of the plot objects. You'll need to make some changes in your code, but this should be much, much faster than replotting things every time. However, the shape of the data that you're plotting can't change, and if the range of your data is changing, you'll need to manually reset the x and y axis limits.
To give an example of the second option:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 6*np.pi, 100)
y = np.sin(x)
# You probably won't need this if you're embedding things in a tkinter plot...
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(111)
line1, = ax.plot(x, y, 'r-') # Returns a tuple of line objects, thus the comma
for phase in np.linspace(0, 10*np.pi, 500):
    line1.set_ydata(np.sin(x + phase))
    fig.canvas.draw()
    fig.canvas.flush_events()
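If the range of the data does change between updates, the axis limits can be recomputed after each set_ydata call. A small sketch of that step; the update function and the amplitude values are just for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 6 * np.pi, 100)
fig, ax = plt.subplots()
line, = ax.plot(x, np.sin(x), "r-")


def update(amplitude):
    line.set_ydata(amplitude * np.sin(x))
    ax.relim()           # recompute the data limits from the current artists
    ax.autoscale_view()  # rescale the view to those limits
    fig.canvas.draw_idle()


update(10)  # the y axis now spans roughly -10..10 instead of -1..1
```

Without relim() and autoscale_view() the line would be updated but clipped to the old limits.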
You can also do it like this.
The loop below draws a random 10x1 array on the plot, 50 times in a row:
import matplotlib.pyplot as plt
import numpy as np
for i in range(50):
    y = np.random.random([10,1])
    plt.clf()          # clear the previous frame
    plt.plot(y)
    plt.pause(0.0001)  # redraw and let the GUI process events
This worked for me. Repeatedly calls a function updating the graph every time.
import matplotlib.pyplot as plt
import matplotlib.animation as anim
def plot_cont(fun, xmax):
    y = []
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    def update(i):
        yi = fun()
        y.append(yi)
        x = range(len(y))
        ax.clear()
        ax.plot(x, y)
        print(i, ': ', yi)
    a = anim.FuncAnimation(fig, update, frames=xmax, repeat=False)
    plt.show()
"fun" is a function that returns an integer.
FuncAnimation will repeatedly call "update", it will do that "xmax" times.
This worked for me:
from matplotlib import pyplot as plt
from IPython.display import clear_output
import numpy as np
for i in range(50):
    clear_output(wait=True)  # replace the previous notebook output
    y = np.random.random([10,1])
    plt.plot(y)
    plt.show()
I have released a package called python-drawnow that provides functionality to let a figure update, typically called within a for loop, similar to Matlab's drawnow.
An example usage:
from pylab import figure, plot, ion, linspace, arange, sin, pi
from drawnow import drawnow
def draw_fig():
    # can be arbitrarily complex; just to draw a figure
    #figure() # don't call!
    plot(t, x)
    #show() # don't call!
N = int(1e3)  # linspace needs an integer number of samples
figure() # call here instead!
ion() # enable interactivity
t = linspace(0, 2*pi, num=N)
for i in arange(100):
    x = sin(2 * pi * i**2 * t / 100.0)
    drawnow(draw_fig)
This package works with any matplotlib figure and provides options to wait after each figure update or drop into the debugger.
In case anyone comes across this article looking for what I was looking for, I found examples at
How to visualize scalar 2D data with Matplotlib?
then modified them to use imshow with an input stack of frames, instead of generating and using contours on the fly.
Starting with a 3D array of images of shape (nBins, nBins, nBins), called frames.
def animate_frames(frames):
nBins = frames.shape[0]
frame = frames[0]
tempCS1 = plt.imshow(frame,
for k in range(nBins):
frame = frames[k]
tempCS1 = plt.imshow(frame,
del tempCS1
#time.sleep(1e-2) #unnecessary, but useful
fig = plt.figure()
ax = fig.add_subplot(111)
win = fig.canvas.manager.window
fig.canvas.manager.window.after(100, animate_frames, frames)
I also found a much simpler way to go about this whole process, albeit less robust:
fig = plt.figure()
for k in range(nBins):
time.sleep(1e-6) #unnecessary, but useful
Note that both of these only seem to work with ipython --pylab=tk, a.k.a. backend = TkAgg.
Thank you for the help with everything.
All of the above might be true, however for me "online-updating" of figures only works with some backends, specifically wx. You just might try to change to this, e.g. by starting ipython/pylab by ipython --pylab=wx! Good luck!
Based on the other answers, I wrapped the figure's update in a python decorator to separate the plot's update mechanism from the actual plot. This way, it is much easier to update any plot.
import matplotlib.pyplot as plt
def plotlive(func):
    plt.ion()
    def new_func(*args, **kwargs):
        # Clear all axes in the current figure.
        axes = plt.gcf().get_axes()
        for axis in axes:
            axis.cla()
        # Call func to plot something
        result = func(*args, **kwargs)
        # Draw the plot
        plt.draw()
        plt.pause(0.01)
        return result
    return new_func
Usage example
And then you can use it like any other decorator.
@plotlive
def plot_something_live(ax, x, y):
    ax.plot(x, y)
    ax.set_ylim([0, 100])
The only constraint is that you have to create the figure before the loop:
fig, ax = plt.subplots()
for i in range(100):
    x = np.arange(100)
    y = np.full([100], fill_value=i)
    plot_something_live(ax, x, y)