I am trying to plot multiple subplots using SHAP, but for some reason all the plots end up in the bottom-right corner when I run plt.show(). Here is the snippet that trains the model and computes the SHAP values:
X = df.loc[:, features]
X = X.fillna(X.median())
y_test = df.loc[:, 'y_test'].values
model = lgbm.LGBMRegressor()
model.fit(X, y_test)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
Here is the snippet to get the graphs:
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(16, 16))
axes_dic = {
    0: axes[0, 0],
    1: axes[0, 1],
    2: axes[0, 2],
    3: axes[1, 0],
    4: axes[1, 1],
    5: axes[1, 2],
    6: axes[2, 0],
    7: axes[2, 1],
    8: axes[2, 2]
}
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(16, 16))
fig.suptitle('Partial Dependence SHAP plots', fontsize=20)
for i in range(len(axes_dic)):
    shap.partial_dependence_plot(
        features[i], model.predict, X, model_expected_value=True,
        feature_expected_value=True, ice=False, show=False, ax=axes_dic[i]
    )
plt.tight_layout()
plt.show()
The code above doesn't work, but the same approach works for the dependence plot:
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(16, 16))
fig.suptitle('Dependence SHAP plots', fontsize=20)
for i in range(len(axes_dic)):
    shap.dependence_plot(
        features[i], shap_values, X, ax=axes_dic[i], show=False,
        interaction_index="uprc"
    )
plt.tight_layout()
plt.show()
Can anyone help me figure out what is going on? Why does dependence_plot work well but not partial_dependence_plot?
print(matplotlib.__version__)
3.6.1
print(shap.__version__)
0.41.0
I tried multiple alternatives to get it to plot on the right axes, but it doesn't do what I would expect.
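One thing worth checking, offered as an assumption and not confirmed against the shap 0.41.0 source: some plotting helpers draw on the current axes (plt.gca()) instead of the ax they are given, and right after plt.subplots the current axes is the last one created, which is the bottom-right one. The sketch below uses only Matplotlib. draw_on_current_axes is a hypothetical stand-in for such a helper, not a shap function. It shows that calling plt.sca(ax) before each call sends the drawing to the intended subplot. Note also that the dictionary above is built from the first plt.subplots call, so it refers to the axes of an earlier figure once plt.subplots is called again; iterating over axes.flat from the call actually in use avoids that.

```python
import matplotlib.pyplot as plt


def draw_on_current_axes(values):
    """Hypothetical helper that ignores any `ax` argument and
    always draws on the current axes (plt.gca())."""
    plt.gca().plot(values)


fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(8, 8))

# axes.flat walks the 3x3 grid row by row, replacing the manual dict.
for i, ax in enumerate(axes.flat):
    plt.sca(ax)  # make this subplot the current axes before the call
    draw_on_current_axes([0, i])

# Each subplot should now hold exactly one line.
lines_per_axes = [len(ax.lines) for ax in axes.flat]
print(lines_per_axes)
```

If the same pattern is tried with shap.partial_dependence_plot, the plt.sca(ax) call would go immediately before it inside the loop; whether that resolves the issue depends on how that function picks its axes.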
Related
I want to replicate seaborn pairplot function but I need increased flexibility to plot extra things.
I have the following minimal example, but I need that the uppermost labels show the range [10, 20] instead of [0, 1] without changing the displayed range of the first plot, and that the grid for the right plots shows horizontal bars.
How can I keep x and y shared and displayed for all plots but exclude the y axis of diagonal plots as in seaborn?
import matplotlib.pyplot as plt
fig, axes = plt.subplots(nrows=2, ncols=2, sharex='col',
sharey='row')
axes[0, 0]._shared_axes['y'].remove(axes[0, 0])
axes[1, 1]._shared_axes['y'].remove(axes[1, 1])
axes[0, 0].plot([0, 1])
axes[0, 1].scatter([0, 1], [10, 20])
axes[1, 0].scatter([0, 1], [100, 200])
axes[1, 1].plot([0, 1])
axes[0, 0].grid(color='lightgray', linestyle='--')
axes[0, 1].grid(color='lightgray', linestyle='--')
axes[1, 0].grid(color='lightgray', linestyle='--')
axes[1, 1].grid(color='lightgray', linestyle='--')
plt.show()
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
df = pd.DataFrame([[10, 100], [20, 200]])
sns.pairplot(df, diag_kind='kde')
plt.show()
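A possible way to avoid touching the private _shared_axes attribute is to create the grid without y sharing and then link only the off-diagonal axes of each row with Axes.sharey. This is a minimal sketch under stated assumptions: the random data, the 3x3 size and the use of hist on the diagonal are illustrative choices, not part of the original example, and Axes.sharey requires a reasonably recent Matplotlib (3.3 or later).

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical sample data: three variables on very different scales.
data = rng.normal(size=(50, 3)) * [1, 10, 100]
n = data.shape[1]

# Share x per column, but leave y unshared at creation time.
fig, axes = plt.subplots(nrows=n, ncols=n, sharex="col", sharey=False)

# Within each row, link the y axis of the off-diagonal axes only.
for i in range(n):
    off_diag = [axes[i, j] for j in range(n) if j != i]
    for ax in off_diag[1:]:
        ax.sharey(off_diag[0])

for i in range(n):
    for j in range(n):
        if i == j:
            axes[i, j].hist(data[:, j])  # diagonal keeps its own y scale
        else:
            axes[i, j].scatter(data[:, j], data[:, i])
        axes[i, j].grid(color="lightgray", linestyle="--")
```

Sharing set up this way does not hide any tick labels, so redundant labels can be switched off per axes with ax.tick_params(labelleft=False) if desired.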
I have the following code:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 4)
for n, ax in enumerate(axs):
    ax.plot([1, 2], [1, 2])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel(n)
plt.show()
...which displays this:
What I want is to hide the black boxes but keep the labels. I've tried adding ax.set_axis_off() but that removes the labels as well:
How can I do this?
Just change the color of spines to None:
fig, axs = plt.subplots(1, 4)
for n, ax in enumerate(axs):
    ax.plot([1, 2], [1, 2])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel(n)
    plt.setp(ax.spines.values(), color=None)
plt.show()
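How color=None is interpreted for spines may differ between Matplotlib versions, so a variant that does not rely on it is to switch the spines' visibility off. This is a sketch of that alternative, not the approach given in the answer above; it relies only on Spine.set_visible, and the x label survives because it is a separate artist.

```python
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 4)
for n, ax in enumerate(axs):
    ax.plot([1, 2], [1, 2])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel(n)
    # Hide the four border lines; the axis label is left untouched.
    for spine in ax.spines.values():
        spine.set_visible(False)

labels = [ax.get_xlabel() for ax in axs]
```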
After subplot.plot([1, 2, 3], [8, 9, 10]), subplot.axis() returns an x-axis that goes roughly from 1 to 3.
But after subplot.plot([1, 2, 3], [None, None, None]), x goes from 0 to 1.
After subplot.plot([1, 2, 3], [8, 9, None]), it goes from 1 to 2.
Is this normal? What should I do to get a correct chart in such a case? (for example, an empty chart spanning 1 to 3 if all None).
#!/usr/bin/env python3
import matplotlib
matplotlib.use("AGG")
import matplotlib.pyplot as plt
def create_plot(ydata):
    fig = plt.figure()
    subplot = fig.add_subplot(1, 1, 1)
    subplot.scatter([1, 2, 3], ydata)
    x1, x2, y1, y2 = subplot.axis()
    print(f"X axis goes from {x1} to {x2} when ydata={ydata}")
create_plot([8, 9, 10])
create_plot([None, None, None])
create_plot([8, 9, None])
Yes, it's normal.
In the case of ax.plot([1, 2, 3], [None, None, None]) there is no data to plot, so the axes does not autoscale and stays at the default [0, 1] view limits.
In the case of ax.plot([1, 2, 3], [8, 9, None]), the only data to show are (1, 8) and (2, 9), so the x-axis range is 1 to 2.
If you want the same x-axis limits with or without data, you can update the data limits yourself:
ax.update_datalim(np.c_[xdata,[0]*len(xdata)], updatey=False)
ax.autoscale()
Complete example:
import numpy as np
import matplotlib.pyplot as plt
def create_plot(ydata):
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    xdata = [1, 2, 3]
    ax.scatter(xdata, ydata)
    ax.update_datalim(np.c_[xdata,[0]*len(xdata)], updatey=False)
    ax.autoscale()
    x1, x2, y1, y2 = ax.axis()
    print(f"X axis goes from {x1} to {x2} when ydata={ydata}")
create_plot([8, 9, 10])
create_plot([None, None, None])
create_plot([8, 9, None])
plt.show()
IIUC, this is normal, as Matplotlib ignores None data. To change the limits, you can set the x-axis limits explicitly:
def create_plot(ydata):
    xvals = [1, 2, 3]
    fig = plt.figure()
    subplot = fig.add_subplot(1, 1, 1)
    subplot.scatter(xvals, ydata)
    # set the limit
    subplot.set_xlim(min(xvals), max(xvals))
    x1, x2, y1, y2 = subplot.axis()
    print(f"X axis goes from {x1} to {x2} when ydata={ydata}")
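Setting the limits to exactly min(xvals) and max(xvals) puts the outermost points on the frame. A small variation, sketched here as one option, pads the limits by a fraction of the x span so that every call gets the same padded range whatever the y data contains. The pad parameter is a hypothetical name introduced for this sketch, and float("nan") stands in for the missing values (None in the question).

```python
import matplotlib.pyplot as plt

nan = float("nan")


def create_plot(ydata, pad=0.05):
    xvals = [1, 2, 3]
    fig, ax = plt.subplots()
    ax.scatter(xvals, ydata)
    # Derive the limits from the x values alone, plus a small margin,
    # so missing y values cannot shrink the range.
    span = max(xvals) - min(xvals)
    ax.set_xlim(min(xvals) - pad * span, max(xvals) + pad * span)
    return ax.get_xlim()


limits = [create_plot(y) for y in ([8, 9, 10], [nan, nan, nan], [8, 9, nan])]
for lo, hi in limits:
    print(f"X axis goes from {lo} to {hi}")
```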
I want to create plots from Python of 3D coordinate transformations. For example, I want to create the following image (generated statically by TikZ):
A bit of searching led me to the following program:
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
class Arrow3D(FancyArrowPatch):
    def __init__(self, xs, ys, zs, *args, **kwargs):
        FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def draw(self, renderer):
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        FancyArrowPatch.draw(self, renderer)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
arrow_prop_dict = dict(mutation_scale=20, arrowstyle='->', shrinkA=0, shrinkB=0)
a = Arrow3D([0, 1], [0, 0], [0, 0], **arrow_prop_dict, color='r')
ax.add_artist(a)
a = Arrow3D([0, 0], [0, 1], [0, 0], **arrow_prop_dict, color='b')
ax.add_artist(a)
a = Arrow3D([0, 0], [0, 0], [0, 1], **arrow_prop_dict, color='g')
ax.add_artist(a)
ax.text(0.0, 0.0, -0.1, r'$o$')
ax.text(1.1, 0, 0, r'$x$')
ax.text(0, 1.1, 0, r'$y$')
ax.text(0, 0, 1.1, r'$z$')
ax.view_init(azim=-90, elev=90)
ax.set_axis_off()
plt.show()
The result doesn't look like what one normally sees in books:
Furthermore, when I include the axes, the origin is not at the intersection of the three planes, which is where I expect it to be.
You can change the view of your 3D axes by changing the following line:
ax.view_init(azim=-90, elev=90)
to
ax.view_init(azim=20, elev=10)
to reproduce your TikZ plot, or change the view angles to any other values as you prefer.
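For comparison, here is a minimal sketch that draws the same three axis arrows with the built-in 3D quiver instead of the custom Arrow3D class, combined with the view angles suggested above. Using quiver is a substitution made for this sketch, not part of the answer; the axis limits and arrow_length_ratio value are illustrative choices.

```python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (needed on older Matplotlib)

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

# One arrow per axis, all starting at the origin.
arrows = [((1, 0, 0), "r", "$x$"),
          ((0, 1, 0), "b", "$y$"),
          ((0, 0, 1), "g", "$z$")]
for (u, v, w), color, label in arrows:
    ax.quiver(0, 0, 0, u, v, w, color=color, arrow_length_ratio=0.1)
    ax.text(1.1 * u, 1.1 * v, 1.1 * w, label)

ax.text(0.0, 0.0, -0.1, "$o$")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_zlim(0, 1)

# Oblique view, similar to typical textbook figures.
ax.view_init(azim=20, elev=10)
ax.set_axis_off()
```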
Given the below code I would expect the x-axis to be between 0 and 3 with some margins added.
Instead it is much larger. I would expect the call to scatter to automatically update x-axis limits.
I could set the xlim and ylim myself, but I would like them to be set automatically. What am I doing wrong?
import matplotlib.pyplot as plt
if __name__ == '__main__':
    fig = plt.figure()
    ax = fig.add_subplot(111)
    x = [0, 4000, 8000, 100000]
    y = [0, 10, 100, 150]
    ax.scatter(x, y)
    x = [0, 1, 2, 3]
    y = x
    ax.clear()
    ax.scatter(x, y)
    plt.show()
You can clear the figure and add a new subplot; then the axes will be adjusted as you wanted.
fig = plt.figure()
ax = fig.add_subplot(111)
x = [0, 4000, 8000, 100000]
y = [0, 10, 100, 150]
ax.scatter(x, y)
plt.clf()
ax = fig.add_subplot(111)
x = [0, 1, 2, 3]
y = x
ax.scatter(x, y)
plt.show()
Edit: In this version figure is not closed, just cleared with the clf function.
It is a feature that scatter does not automatically re-limit the graph, as in many cases that would be undesirable. See Axes.autoscale and Axes.relim:
ax.relim() # might not be needed
ax.autoscale()
should do what you want.
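As a small illustration of the relim-then-rescale pattern, the sketch below replaces the data of an existing artist in place and then recomputes the limits. It is written under an assumption: it uses ax.plot with a marker style instead of scatter, because I am not certain that relim accounts for scatter collections on every Matplotlib version, whereas it does handle lines.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# Marker-only line standing in for the first scatter call.
(line,) = ax.plot([0, 4000, 8000, 100000], [0, 10, 100, 150], "o")

# Swap in the new data without clearing the axes.
line.set_data([0, 1, 2, 3], [0, 1, 2, 3])

ax.relim()           # recompute the data limits from the current artists
ax.autoscale_view()  # apply them (with margins) to the view limits

xmin, xmax = ax.get_xlim()
print(xmin, xmax)
```

Without the relim call, the stored data limits would still reflect the original points and the x axis would keep spanning up to 100000.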
Here is a way of making another scatter plot without needing to clear the figure.
I basically update the offsets of the PathCollection returned by axes.scatter() and add the collection back to the axes. Of course, the axes has to be cleared first. One thing I noticed is that I have to set the margins manually.
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure()
ax = fig.add_subplot(111)
x = [0, 4000, 8000, 100000]
y = [0, 10, 100, 150]
scat = ax.scatter(x, y)
x = [0, 1, 2, 3]
y = x
ax.clear()
corners = (min(x), min(y)), (max(x), max(y))
ax.update_datalim(corners)
ax.margins(0.05, 0.05)
ax.autoscale_view()
scat.set_offsets(np.vstack((x,y)).transpose())
ax.add_collection(scat)
plt.show()
If you comment out the first scatter, everything will be fine.