This Sankey diagram has two inputs (K and S), three outputs (H, F and Sp) and a rest flow (x).
The Inputs shall come from the left Side, the Outputs go to the right Side.
The Rest shall go to the Top.
from matplotlib.sankey import Sankey
import matplotlib.pyplot as plt
# Question code: one Sankey diagram drawn on an explicitly created Axes.
fig, ax = plt.subplots(figsize=[10, 10])
ax.set(yticklabels=[], xticklabels=[])
ax.text(-10, 10, "xxx")

# Flow values, their labels and the side each arrow attaches to.
# NOTE(review): 'x' is described as the rest flow but is given as +800 here,
# and scale=1 / margin=100 are the values the asker is struggling with.
sankey = Sankey(
    ax=ax,
    flows=[20400, 3000, -19900, -400, -2300, 800],
    labels=['K', 'S', 'H', 'F', 'Sp', 'x'],
    orientations=[1, -1, 1, 0, -1, -1],
    scale=1,
    margin=100,
    trunklength=1.0,
)
sankey.finish()

plt.tight_layout()
plt.show()
I played a lot with the orientations, but nothing works or looks nice.
And is there a way to set a different color for every arrow?
The scale of the Sankey should be chosen such that input flow times scale is about 1.0 and output flow times scale is about -1.0 (see the docs). Therefore, about 1/25000 is a good starting point for experimentation. The margin should be a small number, maybe around 1, or it can be left out. I think the only way to get individual colors is to chain multiple Sankeys together (with add), but that's probably not what you want. Use plt.axis("off") to suppress the axes completely.
My test code:
from matplotlib.sankey import Sankey
import matplotlib.pyplot as plt
# Answer code: same diagram, with the flows scaled so that the total input
# times `scale` is on the order of 1, and every outgoing flow negative.
fig, ax = plt.subplots(figsize=[10, 10])

sankey = Sankey(
    ax=ax,
    flows=[20400, 3000, -19900, -400, -2300, -800],
    labels=['K', 'S', 'H', 'F', 'Sp', 'x'],
    orientations=[1, -1, 1, 0, -1, -1],
    scale=1 / 25000,
    trunklength=1,
    edgecolor='#099368',
    facecolor='#099368',
)
sankey.finish()

# Hide the axes frame and ticks entirely.
plt.axis("off")
plt.show()
Generated Sankey:
With different Colors
from matplotlib.sankey import Sankey
import matplotlib.pyplot as plt
from matplotlib import rcParams
# Global font setup: 10 pt serif, with Linux Libertine as the serif face.
plt.rc('font', family='serif')
plt.rcParams['font.size'] = 10
plt.rcParams['font.serif'] = "Linux Libertine"

fig = plt.figure(figsize=[6, 4], dpi=330)
ax = fig.add_subplot(1, 1, 1)

# One Sankey object; every add() call contributes a sub-diagram with its own
# face colour. `prior`/`connect` chain each sub-diagram to the previous one.
s = Sankey(ax=ax, scale=1 / 40000, unit='kg', gap=.4, shoulder=0.05)

# Sub-diagram 0: the two inputs merged into one trunk.
s.add(
    flows=[3000, 20700, -23700],
    orientations=[1, 1, 0],
    labels=["S Value", "K Value", None],
    trunklength=1,
    pathlengths=0.4,
    edgecolor='#000000',
    facecolor='darkgreen',
    lw=0.5,
)

# Sub-diagram 1: trunk continues, two flows branch off (up and down).
s.add(
    flows=[23700, -800, -2400, -20500],
    orientations=[0, 1, -1, 0],
    labels=[None, "U Value", "Sp Value", None],
    trunklength=1.5,
    pathlengths=0.5,
    edgecolor='#000000',
    facecolor='grey',
    prior=0,
    connect=(2, 0),
    lw=0.5,
)

# Sub-diagram 2: the remaining trunk is split into the last two outputs.
s.add(
    flows=[20500, -20000, -500],
    orientations=[0, -1, -1],
    labels=[None, "H Value", "F Value"],
    trunklength=1,
    pathlengths=0.5,
    edgecolor='#000000',
    facecolor='darkred',
    prior=1,
    connect=(3, 0),
    lw=0.5,
)

diagrams = s.finish()

# Left-align every label before placing the labelled ones by hand.
for diagram in diagrams:
    for text in diagram.texts:
        text.set_horizontalalignment('left')

# (diagram index, text index) -> manual label position.
# Tip: print(diagrams[i].texts[j]) shows a label's current position and text.
label_positions = {
    (0, 0): [-0.58, 0.9],   # S
    (0, 1): [-1.5, 0.9],    # K
    (2, 1): [2.35, -1.2],   # H
    (2, 2): [1.75, -1.2],   # F
    (1, 2): [0.7, -1.2],    # Sp
    (1, 1): [0.7, 0.9],     # U
}
for (diagram_idx, text_idx), xy in label_positions.items():
    diagrams[diagram_idx].texts[text_idx].set_position(xy=xy)

plt.axis("off")
plt.show()
Related
Background: I compared the performance of 13 models by using each of them for prediction over four data sets. Now I have 4 * 13 R-Squared values which indicate the goodness of fit. The problem is that some large negative R-Squared values exist, making the visualization not so good.
The positive R-Squared values are hard to differentiate because of negative values such as -11 or -9.7. How can I extend the positive range and squeeze the negative range by customizing the color bar? The code and data are as follows.
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
# Single figure/axes pair that the heatmap is drawn into.
fig, ax = plt.subplots()
# R-squared scores as a 4 x 14 array (4 rows of 14 values each).
# NOTE(review): the exact-zero columns (1, 8, 10, 13) look like placeholders;
# the FuncNorm version of comp_heatmap masks exactly those columns -- confirm.
data = np.array([[ 0.9848, 0. , 0.9504, -0.8198, 0.9501, 0.9071,
0.8598, 0.9348, 0. , 0.713 , 0. , 0.669 ,
0.6184, 0. ],
[ 0.9733, 0. , 0.0566, -9.654 , 0.1291, -0.0926,
-0.0661, -2.3085, 0. , -10.63 , 0. , -3.797 ,
-7.592 , 0. ],
[ 0.9676, 0. , 0.9331, 0.9177, 0.9401, 0.9352,
0.9251, 0.7987, 0. , 0.5635, 0. , 0.5924,
0.2456, 0. ],
[ 0.9759, 0. , -0.114 , 0.1566, 0.0412, 0.3588,
0.2605, -0.5471, 0. , 0.2534, 0. , 0.5216,
0.3784, 0. ]])
def comp_heatmap(ax):
    """Draw the annotated heatmap of the module-level ``data`` on *ax*.

    The colour range is clipped at ``vmax=.3``. Besides the colorbar that
    ``sns.heatmap`` draws by default, a second one is attached explicitly via
    ``ax.figure.colorbar``. Relies on the module-level ``font_text`` and
    ``font_formula`` for the axis-label fonts. Returns ``None``.
    """
    heatmap_options = dict(
        ax=ax,
        vmax=.3,
        annot=True,
        xticklabels=np.arange(14),
        yticklabels=np.arange(4),
    )
    with sns.axes_style('white'):
        ax = sns.heatmap(data, **heatmap_options)

    ax.set_xlabel('Model', fontdict=font_text)
    ax.set_ylabel(r'$R^2$', fontproperties=font_formula, labelpad=5)
    ax.figure.colorbar(ax.collections[0])

    # Pin the tick locations, then relabel them.
    x_locs = ax.get_xticks()
    ax.set_xticks(x_locs)
    ax.set_xticklabels(x_locs.astype(int))

    y_locs = ax.get_yticks()
    ax.set_yticks(y_locs)
    ax.set_yticklabels(['lnr, fit', 'lg, fit', 'lnr, test', 'lg, test'])


comp_heatmap(ax)
I've used a FuncNorm method to resolve it.
from matplotlib import pyplot as plt, font_manager as fm, colors
def forward(x):
    """Forward transform for ``FuncNorm``: map *x* to ``base ** x - 1``.

    ``base`` is read from module scope at call time.
    """
    return base ** x - 1
def inverse(x):
    """Inverse transform for ``FuncNorm``: map *x* to ``log_base(x + 1)``.

    Undoes ``forward``; ``base`` is read from module scope at call time.
    """
    log_of_shifted = np.log(x + 1)
    return log_of_shifted / np.log(base)
def comp_heatmap(ax):
    """Draw the R^2 heatmap of the module-level ``data`` with a non-linear colour scale.

    A ``FuncNorm`` built from ``forward``/``inverse`` stretches the positive
    range and compresses the large negative values. Columns 1, 8, 10 and 13
    are masked out. Relies on the module-level ``data``, ``font_annot`` and
    ``font_tick``.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.

    Returns
    -------
    The Axes returned by ``sns.heatmap``.
    """
    # Note: these two calls change global rc settings / the current figure.
    plt.rc('font', family='Times New Roman', size=15)
    plt.subplots_adjust(left=0.05, right=1)

    norm = colors.FuncNorm((forward, inverse), vmin=-11, vmax=1)

    # Boolean mask of cells to hide. Built with the builtin ``bool`` dtype:
    # the ``np.bool`` alias was removed in NumPy 1.24.
    mask = np.zeros(data.shape, dtype=bool)
    mask[:, [1, 8, 10, 13]] = True

    with sns.axes_style('white'):
        ax = sns.heatmap(
            data, ax=ax, vmax=.3,
            mask=mask,
            annot=True, fmt='.4',
            annot_kws=font_annot,
            norm=norm,
            xticklabels=np.arange(14),
            yticklabels=np.arange(4),
            cbar=False,
            cmap='rainbow',
        )

    # Seaborn's own colorbar is disabled above; attach one manually so the
    # tick positions can be chosen explicitly.
    cbar = ax.figure.colorbar(ax.collections[0])
    cbar.set_ticks([-11, -0.5, 0, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])

    # set tick labels
    xticks = ax.get_xticks()
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks.astype(int), **font_tick)
    yticks = ax.get_yticks()
    ax.set_yticks(yticks)
    ax.set_yticklabels(['', '', '', ''])
    return ax
# Base of the exponential used by forward()/inverse().
base = 5

# Font settings shared by the heatmap helpers.
font_formula = fm.FontProperties(math_fontfamily='cm', size=22)
font_text = dict(size=22, fontfamily='Times New Roman')
font_annot = dict(size=17, fontfamily='Times New Roman')
font_tick = dict(size=18, fontfamily='Times New Roman')

# Draw the heatmap into a fresh figure.
fig, axes = plt.subplots()
ax = comp_heatmap(axes)
I am creating a figure like this:
# Two stacked regions (the upper one twice as tall), each subdivided into
# its own grid of panels: 4x3 on top, 2x3 below.
fig = plt.figure(figsize=(7, 8))
outer_grid = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[2, 1])
inner_grid1 = gridspec.GridSpecFromSubplotSpec(nrows=4, ncols=3, subplot_spec=outer_grid[0])
inner_grid2 = gridspec.GridSpecFromSubplotSpec(nrows=2, ncols=3, subplot_spec=outer_grid[1])
Now I would like to have one legend for all plots in inner_grid1 and a separate legend for all plots in inner_grid2. And I would like those legends to be placed nicely, even though they are higher than a single plot, and cannot have more than one column to not make the figure too wide.
Here is an example where I tried to align the legends with trial and error with method 2 below, however this took ages to make.
So I see three options to achieve this, none of which work:
Place the legend as part of an Axes object, but manually move it outside of the actual plot using axes.legend([...], bbox_to_anchor=(x, y)). This does not work when the legend is taller than a single plot, because it rescales the plots to fit the legend into its grid cell.
Place the legend globally on the Figure object. This works, but makes the correct placement really hard. I cannot use loc = "center right", since it centers it for the full figure instead of just the inner_grid1 or inner_grid2 plots.
Place the legend locally on the GridSpecFromSubplotSpec object. This would be perfect. However there is no method to create a legend on a GridSpecFromSubplotSpec or related classes, and the pyplot.legend method misses parameters to restrict the loc to parts of a grid.
Is there a way to place a legend as described?
As requested, a small code example generating something similar as desired.
This example uses method 2:
#!/usr/bin/env python3
import pandas as pd, seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
GENOMES = ["spneumoniae", "ecoliK12", "scerevisiae", "celegans", "bmori", "hg38"]

fig = plt.figure(figsize=(7, 8))
outer_grid = gridspec.GridSpec(2, 1, height_ratios=[2, 1])
inner_grid1 = gridspec.GridSpecFromSubplotSpec(4, 3, subplot_spec=outer_grid[0])
inner_grid2 = gridspec.GridSpecFromSubplotSpec(2, 3, subplot_spec=outer_grid[1])


def _draw_panel(spec, frame):
    """Create an Axes at *spec*, draw the demo line plot of *frame*, clear its axis labels."""
    panel = plt.Subplot(fig, spec)
    panel = sns.lineplot(data=frame, x="x", y="y", hue="hue", ax=panel)
    panel.set_xlabel("")
    panel.set_ylabel("")
    return panel


# plots are in sets of six, 2 rows by 3 columns each
for index, genome in enumerate(GENOMES):
    data = pd.DataFrame({
        "x": [0, 1, 2, 3, 0, 1, 2, 3],
        "y": [1, 0, 3, 2, 1, 0, 3, 2],
        "hue": ["a", "a", "a", "a", "b", "b", "b", "b"],
    })

    # first set of six
    ax1 = _draw_panel(inner_grid1[index], data)
    if index == 2:
        # Take the handles from one panel and place a figure-level legend
        # by hand next to the upper block of plots.
        ax1.legend()
        handles, labels = ax1.get_legend_handles_labels()
        fig.legend(handles, labels, loc="center left", title="",
                   bbox_to_anchor=(0.9, 2/3 - 0.03))
    # An empty legend carrying only a title serves as the in-panel caption.
    ax1.legend([], [], loc="lower center", title=f"{genome}")
    fig.add_subplot(ax1)

    # second set of six
    ax2 = _draw_panel(inner_grid1[index + 6], data)
    ax2.legend([], [], loc="upper center", title=f"{genome}")
    fig.add_subplot(ax2)

    # third set of six
    ax3 = _draw_panel(inner_grid2[index], data)
    if index == 2:
        ax3.legend(["#unitigs", "avg. unitig len."])
        handles, labels = ax3.get_legend_handles_labels()
        fig.legend(handles, labels, loc="center left", title="",
                   bbox_to_anchor=(0.9, 1/6 + 0.05))
    ax3.legend([], [], loc="upper center", title=f"{genome}")
    fig.add_subplot(ax3)

plt.savefig("stackoverflow_test.pdf", bbox_inches="tight")
I want to have 5x4 subplots, one for each group. I wrote the following code:
# (row, col) index of each of the 20 subplots, in row-major order.
# (The original line ended in a stray backtick, which is a syntax error.)
axeng = [[row, col] for row in range(5) for col in range(4)]

# Tick positions for the heatmap axes.
yy = (0.5, 4.5, 9.5, 14.5, 19.5, 24.5)
xx = np.arange(0.5, 10)

f, axes = plt.subplots(5, 4, figsize=(50, 50), sharex=True, sharey=True)
# Dedicated axes for the single shared colorbar: [left, bottom, width, height].
cbar_ax = f.add_axes([.92, .3, .03, .4])

for i in range(20):
    paxesrow, paxescol = axeng[i]
    draw_cbar = (i == 3)  # only one heatmap draws the shared colorbar
    # gnuplot, jet, YlGnBu, GnBu_r
    g = sns.heatmap(heat[i], cmap="viridis", vmin=0.1, vmax=1,
                    ax=axes[paxesrow, paxescol], linewidth=.1,
                    cbar=draw_cbar,
                    cbar_ax=cbar_ax if draw_cbar else None,
                    square=False)
    g.set_yticks(yy)
    g.set_xticks(xx)
    g.set_yticklabels([' ', '25', ' ', '15', ' ', '5'], fontsize=33)
    g.set_xticklabels([' ', '2', ' ', '4', ' ', '6', ' ', '8', ' ', '10'],
                      fontsize=33, rotation=0)

# NOTE(review): rect is (left, bottom, right, top) in figure fractions, so
# [1, 1, 1, 1] describes a zero-size area -- confirm this is intended.
f.tight_layout(rect=[1, 1, 1, 1])
f.suptitle('Behavior of all subgroups', fontsize=70, y=.93)

# Enlarge the tick labels of the colorbar created by subplot 3.
cbar = axes[tuple(axeng[3])].collections[0].colorbar
cbar.ax.tick_params(labelsize=35)
plt.show()
As you can see in the image, the last subplot is scaled, but I have no idea why that would be the case.
Thanks in advance.
I'm doing a small project which requires to resolve a bug in matplotlib in order to fix zorders of some ax.patches and ax.collections. More exactly, ax.patches are symbols rotatable in space and ax.collections are sides of ax.voxels (so text must be placed on them). I know so far, that a bug is hidden in draw method of mpl_toolkits.mplot3d.Axes3D: zorder are recalculated each time I move my diagram in an undesired way. So I decided to change definition of draw method in these lines:
# Excerpt from inside a draw(self, renderer) method: `self`, `renderer` and
# `zorder_offset` come from the enclosing method, which is not shown here.
# Collections are sorted by the value do_3d_projection() returns (largest
# first), then each gets a zorder offset by its rank in that order.
for i, col in enumerate(
        sorted(self.collections,
               key=lambda col: col.do_3d_projection(renderer),
               reverse=True)):
    #col.zorder = zorder_offset + i #comment this line
    col.zorder = col.stable_zorder + i #add this extra line
# Same treatment for the patches.
for i, patch in enumerate(
        sorted(self.patches,
               key=lambda patch: patch.do_3d_projection(renderer),
               reverse=True)):
    #patch.zorder = zorder_offset + i #comment this line
    patch.zorder = patch.stable_zorder + i #add this extra line
It's assumed that every object in ax.collections and ax.patches has a stable_zorder attribute, which is assigned manually in my project. So every time I run my project, I must make sure that the mpl_toolkits.mplot3d.Axes3D.draw method has been changed manually (outside my project). How can I avoid this manual change and override the method from inside my project instead?
This is MWE of my project:
import matplotlib.pyplot as plt
import numpy as np
#from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.art3d as art3d
from matplotlib.text import TextPath
from matplotlib.transforms import Affine2D
from matplotlib.patches import PathPatch
class VisualArray:
    """Show a NumPy array (1-D to 3-D) as a block of voxels with its values written on the faces.

    Every voxel collection and every text patch gets a ``stable_zorder``
    attribute. NOTE(review): that attribute is only honoured by a customised
    ``Axes3D.draw`` (stock Matplotlib ignores it) -- confirm the patch is active.
    """

    def __init__(self, arr, fig=None, ax=None):
        """Store *arr* (promoted to 3-D) and set up the figure and 3-D axes.

        Parameters
        ----------
        arr : array with 1, 2 or 3 dimensions; 1-D/2-D input gets leading
            singleton axes prepended.
        fig : figure to draw into; a new one is created when ``None``.
        ax : 3-D axes to draw into; when ``None`` the figure's current 3-D
            axes is reused, or a new one is added.

        Raises
        ------
        NotImplementedError
            If *arr* has more than 3 dimensions.
        """
        if len(arr.shape) == 1:
            arr = arr[None, None, :]
        elif len(arr.shape) == 2:
            arr = arr[None, :, :]
        elif len(arr.shape) > 3:
            raise NotImplementedError('More than 3 dimensions is not supported')
        self.arr = arr
        if fig is None:
            self.fig = plt.figure()
        else:
            self.fig = fig
        if ax is None:
            # Figure.gca() no longer accepts keyword arguments (removed in
            # Matplotlib 3.6). Reuse the current axes when it is already a
            # 3-D one, otherwise add a new 3-D subplot.
            if self.fig.axes and getattr(self.fig.gca(), 'name', None) == '3d':
                self.ax = self.fig.gca()
            else:
                self.ax = self.fig.add_subplot(projection='3d')
        else:
            self.ax = ax
        self.ax.azim, self.ax.elev = -120, 30
        self.colors = None

    def text3d(self, xyz, s, zdir="z", zorder=1, size=None, angle=0, usetex=False, **kwargs):
        """Add string *s* as a flat text patch in the plane normal to *zdir*.

        Parameters
        ----------
        xyz : position of the text in data coordinates.
        s : text to render.
        zdir : one of "x", "y", "z" or their "-"-prefixed variants; the "-"
            variants additionally flip the text via the matrices below.
        zorder : value stored on the patch as ``stable_zorder``.
        size, usetex : forwarded to ``TextPath``.
        angle : in-plane rotation, passed to ``Affine2D.rotate``.
        **kwargs : forwarded to ``PathPatch``.

        Returns
        -------
        The ``PathPatch`` that was added to the axes.
        """
        # Extra flip matrices applied for the "negative" directions.
        flips = {'-x': np.array([[-1.0, 0.0, 0], [0.0, 1.0, 0.0], [0, 0.0, -1]]),
                 '-y': np.array([[0.0, 1.0, 0], [-1.0, 0.0, 0.0], [0, 0.0, 1]]),
                 '-z': np.array([[1.0, 0.0, 0], [0.0, -1.0, 0.0], [0, 0.0, -1]])}
        x, y, z = xyz
        # Reorder so that (x, y) are the in-plane coordinates and z is the
        # offset along zdir; nothing to do for "z"/"-z".
        if "y" in zdir:
            x, y, z = x, z, y
        elif "x" in zdir:
            x, y, z = y, z, x
        text_path = TextPath((-0.5, -0.5), s, size=size, usetex=usetex)
        trans = Affine2D().rotate(angle)
        # apply additional rotation of text_paths if side is dark
        if '-' in zdir:
            # Build a new transform from the product through the public
            # constructor instead of overwriting the private _mtx attribute.
            trans = Affine2D(np.dot(flips[zdir], trans.get_matrix()))
        trans = trans.translate(x, y)
        p = PathPatch(trans.transform_path(text_path), **kwargs)
        self.ax.add_patch(p)
        art3d.pathpatch_2d_to_3d(p, z=z, zdir=zdir)
        p.stable_zorder = zorder
        return p

    def on_rotation(self, event):
        """Mouse-motion callback: reassign ``stable_zorder`` of the face labels.

        The top/bottom faces (side1, side4) depend on the sign of the current
        elevation; the four lateral faces depend on the azimuth quadrant.
        *event* is not used.
        """
        vrot_idx = [self.ax.elev > 0, True].index(True)
        v_zorders = 10000 * np.array([(1, -1), (-1, 1)])[vrot_idx]
        for side, zorder in zip((self.side1, self.side4), v_zorders):
            for patch in side:
                patch.stable_zorder = zorder
        # Index of the first azimuth threshold the current view falls under.
        hrot_idx = [self.ax.azim < -90, self.ax.azim < 0, self.ax.azim < 90, True].index(True)
        h_zorders = 10000 * np.array([(1, 1, -1, -1), (-1, 1, 1, -1),
                                      (-1, -1, 1, 1), (1, -1, -1, 1)])[hrot_idx]
        sides = (self.side3, self.side2, self.side6, self.side5)
        for side, zorder in zip(sides, h_zorders):
            for patch in side:
                patch.stable_zorder = zorder

    def voxelize(self):
        """Draw one filled voxel per array element and record each collection's zorder."""
        shape = self.arr.shape[::-1]
        x, y, z = np.indices(shape)
        # All-True boolean grid with the (reversed) array shape.
        arr = (x < shape[0]) & (y < shape[1]) & (z < shape[2])
        self.ax.voxels(arr, facecolors=self.colors, edgecolor='k')
        for col in self.ax.collections:
            col.stable_zorder = col.zorder

    def _label_side(self, positions, values, zdir, zorder):
        """Place one math-mode label per (position, value) pair; return the patches."""
        return [
            self.text3d(xyz, f'${n}$', zdir=zdir, zorder=zorder, size=1,
                        usetex=True, ec="none", fc="k")
            for xyz, n in zip(positions, values)
        ]

    def labelize(self):
        """Write the array values on all six faces and hook up ``on_rotation``.

        The patches are stored per face in ``self.side1`` ... ``self.side6``.
        Labels are rendered with ``usetex=True``.
        """
        self.fig.canvas.mpl_connect('motion_notify_event', self.on_rotation)
        s = self.arr.shape
        self.side1, self.side2, self.side3, self.side4, self.side5, self.side6 = [], [], [], [], [], []
        # labelling surfaces of side1 and side4
        surf = np.indices((s[2], s[1])).T[::-1].reshape(-1, 2) + 0.5
        surf_pos1 = np.insert(surf, 2, self.arr.shape[0], axis=1)
        surf_pos2 = np.insert(surf, 2, 0, axis=1)
        labels1 = (self.arr[0]).flatten()
        labels2 = (self.arr[-1]).flatten()
        self.side1 = self._label_side(surf_pos1, labels1, "z", 10000)
        self.side4 = self._label_side(surf_pos2, labels2, "-z", -10000)
        # labelling surfaces of side2 and side5
        surf = np.indices((s[2], s[0])).T[::-1].reshape(-1, 2) + 0.5
        surf_pos1 = np.insert(surf, 1, 0, axis=1)
        surf = np.indices((s[0], s[2])).T[::-1].reshape(-1, 2) + 0.5
        surf_pos2 = np.insert(surf, 1, self.arr.shape[1], axis=1)
        labels1 = (self.arr[:, -1]).flatten()
        labels2 = (self.arr[::-1, 0].T[::-1]).flatten()
        self.side2 = self._label_side(surf_pos1, labels1, "y", 10000)
        self.side5 = self._label_side(surf_pos2, labels2, "-y", -10000)
        # labelling surfaces of side3 and side6
        surf = np.indices((s[1], s[0])).T[::-1].reshape(-1, 2) + 0.5
        surf_pos1 = np.insert(surf, 0, self.arr.shape[2], axis=1)
        surf_pos2 = np.insert(surf, 0, 0, axis=1)
        labels1 = (self.arr[:, ::-1, -1]).flatten()
        labels2 = (self.arr[:, ::-1, 0]).flatten()
        self.side6 = self._label_side(surf_pos1, labels1, "x", -10000)
        self.side3 = self._label_side(surf_pos2, labels2, "-x", 10000)

    def vizualize(self):
        """Draw the voxels, add the face labels and hide the axes."""
        self.voxelize()
        self.labelize()
        plt.axis('off')
# Demo: a 2 x 6 x 5 array holding the values 0..59.
arr = np.arange(60).reshape((2,6,5))
# Keep the instance bound to a name: labelize() registers the bound method
# va.on_rotation as a canvas callback, and Matplotlib's callback registry
# holds only weak references to bound methods.
va = VisualArray(arr)
va.vizualize()
plt.show()
This is an output I get after external change of ...\mpl_toolkits\mplot3d\axes3d.py file:
This is an output (an unwanted one) I get if no change is done:
What you want to achieve is called Monkey Patching.
It has its downsides and has to be used with some care (there is plenty of information available under this keyword). But one option could look something like this:
from matplotlib import artist
from mpl_toolkits.mplot3d import Axes3D
# Create a new draw function
# The decorator is the reason `artist` has to be imported: it belongs on the
# replacement just as on a regular draw method.
@artist.allow_rasterization
def draw(self, renderer):
    """Replacement for ``Axes3D.draw``; fill in the customised body below."""
    # Your version
    # ...
    # Add Axes3D explicitly to super() calls: this function is defined outside
    # the class body, so the zero-argument form of super() is not available.
    super(Axes3D, self).draw(renderer)


# Overwrite the old draw function
Axes3D.draw = draw
# The rest of your code
# ...
Caveats here are to import artist for the decorator and the explicit call super(Axes3D, self).method() instead of just using super().method().
Depending on your use case and to stay compatible with the rest of your code you could also save the original draw function and use the custom only temporarily:
# Sketch only: the body is a placeholder.
# NOTE(review): a function installed as Axes3D.draw is presumably called as a
# method with (self, renderer), so a real replacement needs those parameters.
def draw_custom():
    ...
# Remember the original method so that it can be restored afterwards.
draw_org = Axes3D.draw
Axes3D.draw = draw_custom
# Do custom stuff
Axes3D.draw = draw_org
# Do normal stuff
Here is my script to plot data from a GeoTIFF file using basemap. The data is categorical and there are 13 categories within this domain. The problem is that some categories get bunched up into one colour and thus some resolution is lost.
Unfortunately, I do not know how to fix this. I read that plt.cm.get_cmap is better for discrete datasets, but I have not gotten it to work.
# Read the raster and its georeferencing.
gtif = 'some_dir'
ds = gdal.Open(gtif)
data = ds.ReadAsArray()
gt = ds.GetGeoTransform()
proj = ds.GetProjection()

# Map extent derived from the geotransform, shrunk by one pixel on each side.
# NOTE(review): presumably gt[0]/gt[3] are the origin and gt[1]/gt[5] the
# pixel sizes (gt[5] negative for north-up rasters) -- confirm.
xres = gt[1]
yres = gt[5]
xmin = gt[0] + xres
xmax = gt[0] + (xres * ds.RasterXSize) - xres
ymin = gt[3] + (yres * ds.RasterYSize) + yres
ymax = gt[3] - yres
xy_source = np.mgrid[xmin:xmax+xres:xres, ymax+yres:ymin:yres]
ds = None  # drop the dataset handle

fig2 = plt.figure(figsize=[12, 11])
ax2 = fig2.add_subplot(111)
ax2.set_title("Land use plot")
bm2 = Basemap(ax=ax2, projection='cyl',
              llcrnrlat=ymin, urcrnrlat=ymax,
              llcrnrlon=xmin, urcrnrlon=xmax,
              resolution='l')
bm2.drawcoastlines(linewidth=0.2)
bm2.drawcountries(linewidth=0.2)

# Fold the value 255 into category 0.
data_new = np.copy(data)
data_new[data_new == 255] = 0

categories = np.unique(data_new)
nbins = categories.size
# NOTE: plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap(name, lut)
# is the equivalent call on newer versions.
cb = plt.cm.get_cmap('jet', nbins+1)
img2 = bm2.imshow(np.flipud(data_new), cmap=cb)
ax2.set_xlim(3, 6)
ax2.set_ylim(50, 53)

# Build the colorbar BEFORE plt.show(): in a plain script show() blocks, so a
# colorbar created afterwards never appears in the displayed figure.
labels = [str(i) for i in categories]
cb2 = bm2.colorbar(img2, "right", size="5%", pad='3%', label='NOAH Land Use Category')
# Fix the tick locations first, then label them, so labels and ticks match.
cb2.set_ticks(categories)
cb2.set_ticklabels(labels)
plt.show()
Here are the categories that are found within the domain (numbered classes):
np.unique(data_new)
array([ 0, 1, 4, 5, 7, 10, 11, 12, 13, 14, 15, 16, 17], dtype=uint8)
Thanks so much for any help here. I have also attached the output image that shows the mismatch. (not working)
First, this colormap problem is independent of the use of basemap. The following is therefore applicable to any matplotlib plot.
The problem here is that creating a colormap from n values distributes those values equally over the colormap range. Some values from the image therefore fall into the same colorrange within the colormap.
To prevent this, one can generate a colormap with the initial number of categories as shown below.
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors
# generate some data
data = np.array([0, 1, 4, 5, 7, 10] * 8)
np.random.shuffle(data)
data = data.reshape((8, 6))

# generate colormap and norm
# One colour per integer in 0..max, so every category value gets its own
# colour bin (values that do not occur simply leave their bin unused).
unique = np.unique(data)
n_levels = int(unique.max()) + 1
# linspace also covers the single-level case without dividing by zero.
vals = np.linspace(0.0, 1.0, n_levels)
cols = plt.cm.jet(vals)
# The second positional parameter of ListedColormap is `name`, not the number
# of colours; the length of `cols` already determines N.
cmap = matplotlib.colors.ListedColormap(cols)
# Bin edges at half-integers centre each integer value in its colour bin.
norm = matplotlib.colors.Normalize(vmin=-0.5, vmax=unique.max() + 0.5)

fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(data, cmap=cmap, norm=norm)
# Write each cell's value on top of it.
for (i, j), value in np.ndenumerate(data):
    ax.text(j, i, value, color="w", ha="center", va="center")

# The colorbar takes cmap and norm from `im`; passing norm again is redundant.
cb = fig.colorbar(im, ax=ax)
cb.set_ticks(unique)
plt.show()
This can be extended to exclude the colors not present in the image as follows:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors
# generate some data
data = np.array([0, 1, 4, 5, 7, 10] * 8)
np.random.shuffle(data)
data = data.reshape((8, 6))

# Replace every value by its index in the sorted list of unique values, so
# the categories become 0..n-1 without gaps.
unique, newdata = np.unique(data, return_inverse=True)
newdata = newdata.reshape(data.shape)

# generate colormap and norm
new_unique = np.unique(newdata)
n_levels = int(new_unique.max()) + 1
# linspace also covers the single-category case without dividing by zero.
vals = np.linspace(0.0, 1.0, n_levels)
cols = plt.cm.jet(vals)
# The second positional parameter of ListedColormap is `name`, not the number
# of colours; the length of `cols` already determines N.
cmap = matplotlib.colors.ListedColormap(cols)
norm = matplotlib.colors.Normalize(vmin=-0.5, vmax=new_unique.max() + 0.5)

fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(newdata, cmap=cmap, norm=norm)
# Annotate with the ORIGINAL category values, not the compacted indices.
for (i, j), value in np.ndenumerate(data):
    ax.text(j, i, value, color="w", ha="center", va="center")

# The colorbar takes cmap and norm from `im`; passing norm again is redundant.
cb = fig.colorbar(im, ax=ax)
# Pin one tick per compacted index before labelling; otherwise the labels
# are attached to whatever ticks the automatic locator happened to choose.
cb.set_ticks(new_unique)
cb.set_ticklabels([str(u) for u in unique])
plt.show()