How to prevent stack plots repeating plots? - matplotlib

I have code that should produce two different graphs and place them in one image, but I cannot figure out why it shows the last graph twice. The code is as follows:
import spacepy as sp
from spacepy import pycdf
from pylab import *
from spacepy.toolbox import windowMean, normalize
from spacepy.plot.utils import annotate_xaxis
import pylab
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.mlab as mlab
import matplotlib.cbook as cbook
import matplotlib.ticker as ticker
from matplotlib.colors import LogNorm
from matplotlib.ticker import LogLocator
from matplotlib.dates import DateFormatter
from matplotlib.dates import DayLocator, HourLocator, MinuteLocator
from matplotlib import rc, rcParams
import matplotlib.dates as mdates
import datetime as dt
import bisect as bi
import seaborn as sea
import sys
import os
import multilabel as ml
import pandas as pd
# --- Global plot styling: seaborn theme plus LaTeX text rendering ----------
sea.set_context('poster')
sea.set_style('whitegrid')
sea.set_palette('muted', color_codes=True)
rc('text', usetex=True)
rc('font', family='Mono')
rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}']

# --- Read the epoch, flux and energy variables from the HPCA survey CDF ----
MMS_1_HPCA_SURVEY_ION = pycdf.CDF(
    r'/home/ary/Desktop/Arya/Project/Data/MMS/1/HPCA/Survey/Ion/mms1_hpca_srvy_l2_ion_20151025120000_v1.0.0.cdf'
)
EPOCH_SURVEY_ION_1 = MMS_1_HPCA_SURVEY_ION['Epoch'][...]
H_Flux_SURVEY_ION_1 = MMS_1_HPCA_SURVEY_ION['mms1_hpca_hplus_flux'][...]
O_Flux_SURVEY_ION_1 = MMS_1_HPCA_SURVEY_ION['mms1_hpca_oplus_flux'][...]
Ion_Energy_SURVEY_ION_1 = MMS_1_HPCA_SURVEY_ION['mms1_hpca_ion_energy'][...]

# --- Time interval to plot and the averaging window -------------------------
MMS_SURVEY_ION_1_Start_time = dt.datetime(2015, 10, 25, 12, 0, 0, 908117)
MMS_SURVEY_ION_1_Finish_time = dt.datetime(2015, 10, 25, 16, 22, 24, 403623)
dt_MMS = dt.timedelta(seconds=15)

# --- Figure with two stacked panels sharing the time axis -------------------
plt.close('all')
fig_MMS, axs_MMS = plt.subplots(2, sharex=True)
cmap = plt.get_cmap(cm.jet)
cmap.set_bad('black')

# --- Record indices bracketing the interval ---------------------------------
# The start index is pulled back by 1% of its own value and the stop index is
# pushed forward by 1% of the records remaining after it.
first_hit = bi.bisect_left(EPOCH_SURVEY_ION_1, MMS_SURVEY_ION_1_Start_time)
sidx_MMS_1_SURVEY_ION = int(first_hit - (first_hit / 100))
last_hit = bi.bisect_left(EPOCH_SURVEY_ION_1, MMS_SURVEY_ION_1_Finish_time)
lidx_MMS_1_SURVEY_ION = int(last_hit + ((len(EPOCH_SURVEY_ION_1) - last_hit) / 100))

# --- Title: only repeat the date on the stop time if the interval spans days
same_day = MMS_SURVEY_ION_1_Start_time.date() == MMS_SURVEY_ION_1_Finish_time.date()
stopfmt = '%H:%M' if same_day else '%-m/%-d/%y %H:%M'
title_1 = ' -'.join([
    MMS_SURVEY_ION_1_Start_time.strftime('%m/%d/%y %H:%M'),
    MMS_SURVEY_ION_1_Finish_time.strftime(stopfmt),
])
if dt_MMS.seconds != 0:
    title_1 += ' with ' + str(dt_MMS.seconds) + ' second time averaging'
# H+ flux spectrograms: one pcolormesh per axes in axs_MMS. The loop index j
# selects the slice along the second axis of the flux array for that panel.
epoch_window_1 = EPOCH_SURVEY_ION_1[sidx_MMS_1_SURVEY_ION:lidx_MMS_1_SURVEY_ION]
h_flux_max = np.nanmax(H_Flux_SURVEY_ION_1[sidx_MMS_1_SURVEY_ION:lidx_MMS_1_SURVEY_ION, :, :])
for j, ax in enumerate(axs_MMS.T.flatten()):
    flix_1 = np.array(H_Flux_SURVEY_ION_1[sidx_MMS_1_SURVEY_ION:lidx_MMS_1_SURVEY_ION,
                                          j, :].T)
    if dt_MMS == dt.timedelta(0):
        # No averaging requested: plot the raw samples.
        fluxwin_1 = flix_1
        timewin_1 = epoch_window_1
    else:
        # Records followed by a gap of more than one hour (loop-invariant).
        gap_starts = (np.where(np.diff(epoch_window_1) > dt.timedelta(hours=1))[0]
                      + sidx_MMS_1_SURVEY_ION)
        fluxwin_1 = [None] * len(flix_1)
        for i, flox in enumerate(flix_1):
            averaged, timewin_1 = windowMean(flox, epoch_window_1,
                                             winsize=dt_MMS, overlap=dt.timedelta(0))
            fluxwin_1[i] = np.array(averaged)
            # Zero the averaged bins that fall inside a data gap so that the
            # gap is not filled with interpolated-looking colour.
            for gap in gap_starts:
                lo = bi.bisect_right(timewin_1, EPOCH_SURVEY_ION_1[gap])
                hi = bi.bisect_right(timewin_1, EPOCH_SURVEY_ION_1[gap + 1])
                fluxwin_1[i][lo:hi] = 0
    fluxwin_1 = np.array(fluxwin_1)
    fluxwin_1[fluxwin_1 <= 0] = 0
    x_1 = mdates.date2num(timewin_1)
    # The colour limits live on the norm: matplotlib rejects vmin/vmax passed
    # together with norm= in recent releases. 'flat' replaces the invalid
    # shading value 'turkey', which older releases silently treated as flat.
    # NOTE(review): if x_1/energy have the same shape as fluxwin_1, newer
    # matplotlib needs shading='nearest' (or 'auto') instead - confirm.
    pax_1 = ax.pcolormesh(x_1, Ion_Energy_SURVEY_ION_1, fluxwin_1, shading='flat', cmap=cmap,
                          norm=LogNorm(vmin=1, vmax=h_flux_max))
    # Twin axis used only to frame the panel: hide its labels and ticks.
    sax_1 = ax.twinx()
    plt.setp(sax_1.get_yticklabels(), visible=False)
    sax_1.tick_params(axis='y', right=False)  # the string 'off' is truthy
    ax.set_xlim(MMS_SURVEY_ION_1_Start_time, MMS_SURVEY_ION_1_Finish_time)
    ax.set_yscale('log')
    ax.set_yticks([10, 100, 1000, 10000])
    # Allows non-log formatted values to be used for ticks
    ax.yaxis.set_major_formatter(plt.ScalarFormatter())
axs_MMS[0].set_ylabel('Energy (eV)')
axs_MMS[0].set_title(title_1)
# O+ flux spectrograms, built the same way as the H+ panels.
# NOTE(review): this loop iterates over every axes in axs_MMS, exactly like
# the H+ loop, so each O+ mesh is drawn on top of the H+ mesh already on that
# axes. To show H+ and O+ in separate panels each species must be drawn on
# its own axes only.
epoch_window_2 = EPOCH_SURVEY_ION_1[sidx_MMS_1_SURVEY_ION:lidx_MMS_1_SURVEY_ION]
o_flux_max = np.nanmax(O_Flux_SURVEY_ION_1[sidx_MMS_1_SURVEY_ION:lidx_MMS_1_SURVEY_ION, :, :])
for j, ax in enumerate(axs_MMS.T.flatten()):
    flix_2 = np.array(O_Flux_SURVEY_ION_1[sidx_MMS_1_SURVEY_ION:lidx_MMS_1_SURVEY_ION,
                                          j, :].T)
    if dt_MMS == dt.timedelta(0):
        # No averaging requested: plot the raw samples. Both species are read
        # from the same CDF, so they share EPOCH_SURVEY_ION_1 (the original
        # referenced an undefined EPOCH_SURVEY_ION_2 here).
        fluxwin_2 = flix_2
        timewin_2 = epoch_window_2
    else:
        # Records followed by a gap of more than one hour (loop-invariant).
        gap_starts = (np.where(np.diff(epoch_window_2) > dt.timedelta(hours=1))[0]
                      + sidx_MMS_1_SURVEY_ION)
        fluxwin_2 = [None] * len(flix_2)
        for i, flox in enumerate(flix_2):
            averaged, timewin_2 = windowMean(flox, epoch_window_2,
                                             winsize=dt_MMS, overlap=dt.timedelta(0))
            fluxwin_2[i] = np.array(averaged)
            # Zero the averaged bins that fall inside a data gap. Both bounds
            # use this loop's own gap index (the original mixed in x_1 from
            # the H+ section for the upper bound).
            for x_2 in gap_starts:
                lo = bi.bisect_right(timewin_2, EPOCH_SURVEY_ION_1[x_2])
                hi = bi.bisect_right(timewin_2, EPOCH_SURVEY_ION_1[x_2 + 1])
                fluxwin_2[i][lo:hi] = 0
    fluxwin_2 = np.array(fluxwin_2)
    fluxwin_2[fluxwin_2 <= 0] = 0
    x_2 = mdates.date2num(timewin_2)
    # The colour limits live on the norm: matplotlib rejects vmin/vmax passed
    # together with norm= in recent releases. 'flat' replaces the invalid
    # shading value 'turkey', which older releases silently treated as flat.
    # NOTE(review): if x_2/energy have the same shape as fluxwin_2, newer
    # matplotlib needs shading='nearest' (or 'auto') instead - confirm.
    pax_2 = ax.pcolormesh(x_2, Ion_Energy_SURVEY_ION_1, fluxwin_2, shading='flat', cmap=cmap,
                          norm=LogNorm(vmin=1, vmax=o_flux_max))
    # Twin axis used only to frame the panel: hide its labels and ticks.
    sax_2 = ax.twinx()
    plt.setp(sax_2.get_yticklabels(), visible=False)
    sax_2.tick_params(axis='y', right=False)  # the string 'off' is truthy
    ax.set_xlim(MMS_SURVEY_ION_1_Start_time, MMS_SURVEY_ION_1_Finish_time)
    ax.set_yscale('log')
    ax.set_yticks([10, 100, 1000, 10000])
    # Allows non-log formatted values to be used for ticks
    ax.yaxis.set_major_formatter(plt.ScalarFormatter())
axs_MMS[1].set_ylabel('Energy (eV)')
# Dedicated axes along the right edge of the figure to hold the colorbar.
cbar_ax_1 = fig_MMS.add_axes([0.93, 0.15, 0.02, 0.7])
# pax_1 is the mesh produced by the H+ loop, so this colorbar describes the
# H+ colour scale.
cb_MMS_1 = fig_MMS.colorbar(pax_1, cax=cbar_ax_1)
cb_MMS_1.set_label(r'Counts sec$^{-1}$ ster$^{-1}$ cm$^{-2}$ keV$^{-1}$')
# Set the value range on the mappable rather than on the colorbar:
# Colorbar.set_clim is deprecated/removed in newer matplotlib, while
# ScalarMappable.set_clim works in old and new releases alike.
pax_1.set_clim(1, np.nanmax(H_Flux_SURVEY_ION_1[sidx_MMS_1_SURVEY_ION:lidx_MMS_1_SURVEY_ION,:,:]))
# Rebuild the colorbar from the mappable's updated limits (replaces draw_all).
cb_MMS_1.update_normal(pax_1)
The resulting image looks like this:
enter image description here

Consider the following example which corresponds to your code:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(2, sharex=True, figsize=(4, 2.4))

# First dataset: drawn on EVERY axes of the figure.
x_1 = [[0, 1, 2] for _ in range(3)]
y_1 = [[row] * 3 for row in range(3)]
z_1 = [[6, 5, 4], [2, 3, 4], [6, 5, 4]]
for ax in axs.T.flatten():
    pax_1 = ax.pcolormesh(x_1, y_1, z_1)
axs[0].set_ylabel('Energy (eV)')
axs[0].set_title("Title1")

# Second dataset: again drawn on EVERY axes, so both panels end up holding
# both meshes.
x_2 = [[3, 4, 5] for _ in range(3)]
y_2 = [[row] * 3 for row in range(3)]
z_2 = [[2, 1, 2] for _ in range(3)]
for ax in axs.T.flatten():
    pax_2 = ax.pcolormesh(x_2, y_2, z_2)
axs[1].set_ylabel('Energy (eV)')

plt.show()
Here each dataset is plotted to both axes.
Instead, you need to plot each dataset to its own axes:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(2, sharex=True, figsize=(4, 2.4))
upper_ax, lower_ax = axs

# First dataset goes on the upper axes only.
x_1 = [[0, 1, 2] for _ in range(3)]
y_1 = [[row] * 3 for row in range(3)]
z_1 = [[6, 5, 4], [2, 3, 4], [6, 5, 4]]
pax_1 = upper_ax.pcolormesh(x_1, y_1, z_1)
upper_ax.set_ylabel('Energy (eV)')
upper_ax.set_title("Title1")

# Second dataset goes on the lower axes only.
x_2 = [[3, 4, 5] for _ in range(3)]
y_2 = [[row] * 3 for row in range(3)]
z_2 = [[2, 1, 2] for _ in range(3)]
pax_2 = lower_ax.pcolormesh(x_2, y_2, z_2)
lower_ax.set_ylabel('Energy (eV)')

plt.show()

Related

Flight Path by shapely LineString is not correct

I want to draw a line from each aircraft's origin (latitude_1, longitude_1) to its destination (latitude_2, longitude_2). I use the following data:
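|   | callsign | latitude_1 | longitude_1 | latitude_2 | longitude_2 |
|---|----------|------------|-------------|------------|-------------|
| 0 | HBAL102  | -4.82114   | -76.3194    | -4.5249    | -79.0103    |
| 1 | AUA1028  | -33.9635   | 151.181     | 48.1174    | 16.55       |
| 2 | ABW120   | 41.9659    | -87.8832    | 55.9835    | 37.4958     |
| 3 | CSN461   | 33.9363    | -118.414    | 50.0357    | 8.5723      |
| 4 | ETH3730  | 25.3864    | 55.4221     | 50.6342    | 5.43903     |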
Unfortunately, I get an incorrect result when creating the LineString with shapely. I tried transformations such as rotate and affine, but they did not correct it.
Code:
# Load the flight table (semicolon-separated) and keep only complete rows.
cols = pd.read_csv("/content/dirct_lines.csv",sep=";")
line = cols[["callsign","latitude_1","longitude_1","latitude_2","longitude_2"]].dropna()
# NOTE(review): each "geometry" value built here is a plain Python list of
# (latitude, longitude) tuples, not a shapely geometry object, and the
# coordinates are in (lat, lon) order whereas shapely/geopandas expect
# (x, y) = (longitude, latitude). This is the likely cause of the wrong lines.
line['geometry'] = line.apply(lambda x: [(x['latitude_1'],
x['longitude_1']),
(x['latitude_2'],
x['longitude_2'])], axis = 1)
geoline = gpd.GeoDataFrame(line,geometry="geometry",
crs="EPSG:4326")
import matplotlib.pyplot as plt
# World outline used as the base map; the flight lines are drawn onto its axes.
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
ax = world.plot(figsize=(14,9),
color='white', edgecolor='black')
# NOTE(review): alpha = 2 lies outside matplotlib's 0-1 opacity range -
# confirm the intended value.
geoline.plot(figsize=(14,9),ax=ax,facecolor = 'lightgrey', linewidth = 1.75,
edgecolor = 'red',
alpha = 2)
plt.show()
Shapely Output:
Interestingly, when I use Matplotlib to draw the lines, everything is correct.
Code:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
# Map axes in plain longitude/latitude (Plate Carree) with a stock background.
plate_carree = ccrs.PlateCarree()
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(projection=plate_carree)
ax.stock_img()

# Origin and destination coordinates, one entry per flight.
org_lon = cols["longitude_1"]
org_lat = cols["latitude_1"]
dst_lon = cols["longitude_2"]
dst_lat = cols["latitude_2"]

# One segment per flight: x values are longitudes, y values are latitudes.
ax.plot(
    [org_lon, dst_lon],
    [org_lat, dst_lat],
    color='black',
    linewidth=0.5,
    marker='_',
    transform=plate_carree,
)

fig.savefig("fight_path.png", dpi=60, facecolor=None, bbox_inches='tight', pad_inches=None)
plt.show()
Matplotlib Output:
What is the problem?
Why is the result incorrect with shapely?
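It's just the way you are creating the geometry: each row needs a shapely LineString built from (longitude, latitude) points, not a plain list of (latitude, longitude) tuples. The code below works correctly.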
import io
import geopandas as gpd
import pandas as pd
import shapely.geometry
# Sample flights: one row per aircraft with origin (_1) and destination (_2)
# coordinates in degrees.
df = pd.read_csv(
    io.StringIO(
        """callsign,latitude_1,longitude_1,latitude_2,longitude_2
HBAL102,-4.82114,-76.3194,-4.5249,-79.0103
AUA1028,-33.9635,151.181,48.1174,16.55
ABW120,41.9659,-87.8832,55.9835,37.4958
CSN461,33.9363,-118.414,50.0357,8.5723
ETH3730,25.3864,55.4221,50.6342,5.43903
"""
    )
)

# One LineString per flight. points_from_xy takes x (longitude) first and
# y (latitude) second; zip pairs each origin point with its destination.
# The CRS is declared explicitly so the lines carry the same lon/lat
# reference (EPSG:4326) as the base map instead of being CRS-less.
geoline = gpd.GeoDataFrame(
    data=df,
    geometry=[
        shapely.geometry.LineString(points)
        for points in zip(
            gpd.points_from_xy(df["longitude_1"], df["latitude_1"]),
            gpd.points_from_xy(df["longitude_2"], df["latitude_2"]),
        )
    ],
    crs="EPSG:4326",
)

import matplotlib.pyplot as plt

# NOTE(review): gpd.datasets is not available in every geopandas release -
# confirm against the installed version or load the shapefile directly.
world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
ax = world.plot(figsize=(14, 9), color="white", edgecolor="black")
geoline.plot(
    figsize=(14, 9),
    ax=ax,
    facecolor="lightgrey",
    linewidth=1.75,
    edgecolor="red",
)
plt.show()

Matplotlib clearing the figure/axis for new plot

I am trying to figure out how to clear the axes in readiness for a new plot. I have tried ax.cla() and fig.clf(), but nothing happens. What am I doing wrong? At the moment I am not getting any errors, and I am using Matplotlib version 3.4.3.
from tkinter import *
import matplotlib.pyplot as plt
import numpy as np
import time
import datetime
import mysql.connector
import matplotlib.dates as mdates
# Open the MySQL connection and a cursor shared by the callbacks below.
# NOTE(review): the database password is hard-coded in the source; it should
# come from configuration or the environment instead.
my_connect = mysql.connector.connect(host="localhost", user="Kennedy", passwd="Kennerdol05071994", database="ecg_db", auth_plugin="mysql_native_password")
mycursor = my_connect.cursor()
# Module-level buffers holding the samples that get plotted.
voltage_container = []
time_container = []
def analyze_voltage_time():
    """Fetch one patient's ECG samples from MySQL and plot voltage against time.

    Refills the module-level ``voltage_container`` / ``time_container`` lists
    and publishes the figure and axes through the globals ``fig`` and ``ax``
    so that ``clear_`` can reach them. Does nothing when the query returns
    no rows.
    """
    global ax, fig
    pat_id = 1
    # Let the driver substitute the value instead of concatenating it into
    # the SQL text.
    query = "SELECT voltage, time FROM ecg_data_tbl where patient_id = %s"
    mycursor.execute(query, (pat_id,))
    result = mycursor.fetchall()
    if not result:
        # zip(*[]) cannot be unpacked into two columns; nothing to plot.
        return
    # 'times' rather than 'time' so the imported time module is not shadowed.
    voltage, times = zip(*result)
    # Replace the buffer contents in place; appending would duplicate the
    # samples every time the button is pressed.
    voltage_container[:] = voltage
    time_container[:] = [str(tim) for tim in times]
    fig = plt.figure(1, figsize=(15, 6), dpi=80, constrained_layout=True)
    # Figure 1 is reused across clicks: drop any axes left from a previous
    # run so a fresh subplot is not stacked on top of the old one.
    fig.clf()
    ax = fig.add_subplot()
    ax.plot(time_container, voltage_container)
    for label in ax.get_xticklabels():
        label.set_rotation(40)
        label.set_horizontalalignment('right')
    ax.set_title("Electrocadiogram")
    ax.set_xlabel("Time(hh:mm:ss)")
    ax.set_ylabel("Voltage(mV)")
    # The on/off flag is passed positionally: its keyword was renamed from
    # 'b' to 'visible' in newer matplotlib, the positional form works in both.
    ax.grid(True, which='major', color='#666666', linestyle='-')
    ax.minorticks_on()
    ax.grid(True, which='minor', color='#666666', linestyle='-', alpha=0.2)
    plt.show()
def clear_():
    """Erase the current plot and refresh the figure window.

    Does nothing if no figure has been created yet (the globals ``fig`` and
    ``ax`` only exist after the first plot).
    """
    try:
        figure, axes = fig, ax
    except NameError:
        # "clear" pressed before anything was plotted.
        return
    axes.cla()
    figure.clf()
    # Clearing only changes the figure's contents in memory; the window keeps
    # showing the old image until the canvas is asked to redraw.
    figure.canvas.draw_idle()
# =================================MAIN GUI WINDOW======================================
# Fixed-size 400x200 Tk root window that hosts the two control buttons.
analysis_window = Tk()
analysis_window.configure(background='light blue')
analysis_window.iconbitmap('lardmon_icon.ico')
analysis_window.title("ECG-LArdmon - ANALYZER")
analysis_window.geometry('400x200')
analysis_window.resizable(width=False, height=False)
# ===========================BUTTONS===================================
# "analyze" fetches and plots the data; "clear" wipes the current plot.
analyse_btn = Button(analysis_window, text='analyze', width = 20, command=analyze_voltage_time)
analyse_btn.pack()
clear_btn = Button(analysis_window, text= 'clear', width = 20, command=clear_)
clear_btn.pack()
# Hand control to Tk's event loop; returns only when the window is closed.
analysis_window.mainloop()

Colorbar scientific notation, change e^ to 10^

I am using scientific notation in a colorbar within a 2D plot. I want to write 10^{-3} instead of e-3. I tried to change that (see code below) but it does not work...
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker
# Random 10x10 test data; z is scaled to the 1e-3 range so the colorbar
# needs an exponent.
x = np.random.rand(100)
y = np.random.rand(100)
z = np.random.rand(100)*0.001
x=x.reshape((10,10))
y=y.reshape((10,10))
z=z.reshape((10,10))
fig, ax = plt.subplots(figsize=(8,6))
cs = ax.contourf(x,y,z, 10)
plt.xticks(fontsize=16,rotation=0)
plt.yticks(fontsize=16,rotation=0)
cbar = plt.colorbar(cs,)
cbar.set_label("test",fontsize = 22)
# Force scientific notation with a common offset for all tick labels.
cbar.formatter.set_scientific(True)
cbar.formatter.set_powerlimits((0, 0))
cbar.ax.tick_params(labelsize=16)
cbar.ax.yaxis.get_offset_text().set_fontsize(22)
# NOTE(review): this writes a private attribute on the x-axis formatter,
# while the offset text styled on the previous line belongs to the y-axis -
# likely why the attempt has no visible effect; confirm.
cbar.ax.xaxis.major.formatter._useMathText = True
cbar.update_ticks()
plt.savefig("test.png")
It seems you want a ScalarFormatter with mathtext in use.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker
# Regular 10x10 grid: x varies along columns, y along rows.
grid = np.arange(10)
x, y = np.meshgrid(grid, grid)
# Sorted random values in the 1e-3 range, laid out row by row.
z = np.sort(np.random.rand(100) * 0.001).reshape(10, 10)

fig, ax = plt.subplots(figsize=(8, 6))
cs = ax.contourf(x, y, z, 10)

# ScalarFormatter with mathtext renders the offset as "x 10^-3" rather than
# "1e-3"; power limits (0, 0) force the exponent to always be used.
fmt = matplotlib.ticker.ScalarFormatter(useMathText=True)
fmt.set_powerlimits((0, 0))
cbar = plt.colorbar(cs, format=fmt)
plt.show()

How can I make my matplotlib figure respond to a mouse click event?

I read the matplotlib documentation and wrote the following code. It is supposed to capture mouse clicks and move the grey line to the clicked position. When I run this code in an online Jupyter notebook, the figure stops showing the cursor coordinates as it usually does. What is happening? Can anyone help me?
import pandas as pd
import numpy as np
import matplotlib.colors as mcol
import matplotlib.cm as cm
from scipy import stats
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import scipy.spatial as spatial
# Fixed seed so the four synthetic yearly samples are reproducible.
np.random.seed(12345)
# One row per year (1992-1995), 3650 normally distributed samples each.
df = pd.DataFrame([np.random.normal(33500,150000,3650),
np.random.normal(41000,90000,3650),
np.random.normal(41000,120000,3650),
np.random.normal(48000,55000,3650)],
index=[1992,1993,1994,1995])
fig, ax = plt.subplots()
year_avg = df.mean(axis = 1)
year_std = df.std(axis = 1)
# Half-width of the 95% confidence interval of each yearly mean:
# standard error (std / sqrt(n)) times the two-sided t critical value.
yerr = year_std / np.sqrt(df.shape[1]) * stats.t.ppf(1-0.05/2, df.shape[1]-1)
bars=ax.bar(range(df.shape[0]), year_avg, yerr = yerr, color = 'lightslategrey')
# Reference value drawn as a horizontal grey line.
threshold=42000
line=plt.axhline(y = threshold, color = 'grey', alpha = 0.5)
# Yellow -> orange -> red colormap wrapped in a ScalarMappable so values can
# be converted to RGBA colours and shown in a colorbar.
cm1 = mcol.LinearSegmentedColormap.from_list("CmapName",["yellow", "orange", "red"])
cpick = cm.ScalarMappable(cmap=cm1)
percentages = []
cpick.set_array([])
def setColor(bars, yerr, threshold):
    """Colour each bar by how much of its error interval lies above *threshold*.

    For every bar the interval is ``height - yerr`` to ``height + yerr``; the
    fraction of that interval above *threshold*, clamped to 0..1, is mapped
    to a colour through the module-level ``cpick`` mappable.

    Parameters:
        bars: the bar rectangles to recolour (e.g. the result of ``ax.bar``).
        yerr: one error half-width per bar, in the same order as *bars*.
        threshold: the reference value the intervals are compared against.

    Returns:
        The same *bars* object, recoloured in place.
    """
    fractions = []
    for bar, yerr_ in zip(bars, yerr):
        low = bar.get_height() - yerr_
        high = bar.get_height() + yerr_
        if high == low:
            # Zero-width interval: it is either entirely above the threshold
            # or not; avoids dividing by zero.
            fraction = 1.0 if high > threshold else 0.0
        else:
            fraction = (high - threshold) / (high - low)
        fractions.append(min(1, max(0, fraction)))
    # Replace (not extend) the module-level list, so that repeated calls do
    # not accumulate values left over from earlier thresholds.
    percentages[:] = fractions
    # Recolour the existing rectangles rather than drawing a new set of bars
    # on top of them on every call.
    for bar, rgba in zip(bars, cpick.to_rgba(fractions)):
        bar.set_color(rgba)
    return bars
# Colour the bars against the initial threshold and finish the static parts
# of the chart. The grey threshold line already exists as `line`, so it is
# not drawn a second time here.
setColor(bars, yerr, threshold)
plt.colorbar(cpick, ax=ax, orientation='horizontal')
plt.xticks(range(df.shape[0]), df.index)


def onclick(event):
    """Move the grey threshold line to the height of a mouse click.

    Clicks outside the axes carry no data coordinates (``xdata``/``ydata``
    are None) and are ignored.
    """
    if event.inaxes is None or event.ydata is None:
        return
    print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
          ('double' if event.dblclick else 'single', event.button,
           event.x, event.y, event.xdata, event.ydata))
    # A horizontal line has two end points, so it needs two y values.
    line.set_ydata([event.ydata, event.ydata])
    # Nothing is repainted until the canvas is asked to redraw.
    event.canvas.draw_idle()


# Connect the callback to the figure that holds the bars, and do it before
# the figure is shown. (Creating a new plt.figure() here would rebind `fig`
# to an empty figure and attach the callback to the wrong canvas.)
# NOTE(review): in a Jupyter notebook an interactive backend is presumably
# also required for click events to be delivered - confirm.
cid = fig.canvas.mpl_connect('button_press_event', onclick)
plt.show()

I want to add a "spheres" to my data cluster

I want to add a kind of "spheres" to my data cluster.
This is my current cluster plot, which does not have "spheres".
And this is my code
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
style.use('ggplot')
import pandas as pd
from sklearn.cluster import KMeans
MY_FILE = 'total_watt.csv'

# Daily totals of the first data column, indexed by timestamp.
df = pd.read_csv(MY_FILE, parse_dates=[0], index_col=[0])
# resample(..., how='sum') is no longer accepted by pandas. min_count=1 keeps
# days without any readings as NaN so that dropna() still removes them.
df = df.resample('1D').sum(min_count=1)
df = df.dropna()
date = [x.strftime('%Y-%m-%d') for x in df.index]

# Encode each date string as an integer code so it can serve as a numeric
# clustering feature.
from sklearn.preprocessing import LabelEncoder
encoder = LabelEncoder()
date_numeric = encoder.fit_transform(date)
consumption = df[df.columns[0]].values
X = np.array([date_numeric, consumption]).T

# Three clusters over (date code, consumption).
kmeans = KMeans(n_clusters=3)
kmeans.fit(X)
centroids = kmeans.cluster_centers_
labels = kmeans.labels_
print(centroids)
print(labels)

fig, ax = plt.subplots(figsize=(10,8))
rect = fig.patch
rect.set_facecolor('#2D2B2B')
# One matplotlib format string (colour + point marker) per cluster label.
colors = ["b.","r.","g."]
for i, (day_code, total) in enumerate(X):
    # inverse_transform expects a 1-D sequence, so wrap the single code.
    print("coordinate:", encoder.inverse_transform([int(day_code)])[0], total, "label:", labels[i])
    ax.plot(day_code, total, colors[labels[i]], markersize = 10)
ax.scatter(centroids[:, 0],centroids[:, 1], marker = "x", s=150, linewidths = 5, zorder = 10)
# Label every fifth position on the x axis with its date string.
a = np.arange(0, len(X), 5)
ax.set_xticks(a)
ax.set_xticklabels(encoder.inverse_transform(a.astype(int)))
ax.tick_params(axis='x', colors='lightseagreen')
ax.tick_params(axis='y', colors='lightseagreen')
# Centroids are drawn a second time in black over the larger markers above.
plt.scatter(centroids[:, 0],centroids[:, 1], marker = "x", s=100, c="black", linewidths = 5, zorder = 10)
ax.set_title('Energy consumptions Clusters (high/medium/low)', color='gold')
# NOTE(review): the x axis carries the dates and the y axis the consumption,
# so these two axis labels look swapped or mislabelled - confirm.
ax.set_xlabel('time', color='gold')
ax.set_ylabel('date(year 2011)', color='gold')
plt.show()
By "spheres" I mean the shaded areas that surround each cluster of points, as in this picture.
I tried to google it.
But when I type "matplotlib spheres", I could not get any result..
The sample graph in your post looks like the result of a Gaussian mixture model (GMM), where each "sphere" is a 2-D Gaussian density.
I'll write up a sample code shortly to demonstrate how to use GMM on your dataset and do this kind of plotting.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
style.use('ggplot')
import pandas as pd
# code changes here
# ===========================================
# sklearn.mixture.GMM was removed from scikit-learn; GaussianMixture is its
# replacement.
from sklearn.mixture import GaussianMixture
# ===========================================
from sklearn.preprocessing import LabelEncoder
# replace it with your file path
MY_FILE='/home/Jian/Downloads/total_watt.csv'

# Daily totals of the first data column, indexed by timestamp.
df = pd.read_csv(MY_FILE, parse_dates=[0], index_col=[0])
# resample(..., how='sum') is no longer accepted by pandas. min_count=1 keeps
# days without any readings as NaN so that dropna() still removes them.
df = df.resample('1D').sum(min_count=1)
df = df.dropna()
date = [x.strftime('%Y-%m-%d') for x in df.index]
# Encode each date string as an integer code to use as a numeric feature.
encoder = LabelEncoder()
date_numeric = encoder.fit_transform(date)
consumption = df[df.columns[0]].values
X = np.array([date_numeric, consumption]).T

# code changes here
# ===========================================
# Diagonal covariances give axis-aligned Gaussians: covariances_[i] then
# holds the per-feature variances of component i, and means_[i] its centre.
gmm = GaussianMixture(n_components=3, covariance_type='diag', random_state=0)
gmm.fit(X)
y_pred = gmm.predict(X)
# ===========================================

import matplotlib as mpl
fig, ax = plt.subplots(figsize=(10,8))
for i, color in enumerate('rgb'):
    # sphere background: an ellipse spanning +/- 1.96 standard deviations
    # around the component mean along each axis
    width, height = 2 * 1.96 * np.sqrt(gmm.covariances_[i])
    ell = mpl.patches.Ellipse(gmm.means_[i], width, height, color=color)
    ell.set_alpha(0.1)
    ax.add_artist(ell)
    # data points assigned to this component
    X_data = X[y_pred == i]
    ax.scatter(X_data[:,0], X_data[:,1], color=color)
    # center
    ax.scatter(gmm.means_[i][0], gmm.means_[i][1], marker='x', s=100, c=color)
ax.set_title('Energy consumptions Clusters (high/medium/low)', color='gold')
# NOTE(review): the x axis carries the dates and the y axis the consumption,
# so these two axis labels look swapped or mislabelled - confirm.
ax.set_xlabel('time', color='gold')
ax.set_ylabel('date(year 2011)', color='gold')
# Label every fifth position on the x axis with its date string.
a = np.arange(0, len(X), 5)
ax.set_xticks(a)
ax.set_xticklabels(encoder.inverse_transform(a.astype(int)))
ax.tick_params(axis='x', colors='lightseagreen')
ax.tick_params(axis='y', colors='lightseagreen')