pulling data out of bins in density map created with matplotlib - matplotlib

I have created a lightning density map using lines of data representing lightning strikes. One line is shown below:
1996-01-17 03:54:35.853 44.9628 -78.9399 -37.9
Now that I have applied these lines of data to the density map and distributed them into their appropriate bins based on Lat/Long, I would like to pull the data back out specific to the bin that it fell into so that I can manipulate that data further.
I have tried to find answers to this online but have failed to find anything that is specific to what I am trying to do. Any and all help is greatly appreciated!
my code:
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.axes as ax
import numpy as np
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from metpy.plots import USCOUNTIES
from matplotlib.axes import Axes
from cartopy.mpl.geoaxes import GeoAxes
GeoAxes._pcolormesh_patched = Axes.pcolormesh
import datetime
fig, ax = plt.subplots(figsize=(15,15),subplot_kw=dict(projection=ccrs.Stereographic(central_longitude=-76, central_latitude=43)))
ax.set_extent([-79, -73, 42, 45],ccrs.Geodetic())
ax.add_feature(USCOUNTIES.with_scale('500k'), edgecolor='gray', linewidth=0.25)
ax.add_feature(cfeature.STATES.with_scale('50m'))
winter = [12, 1, 2]
summer = [6, 7, 8]
seasondata = []
lons=[]
lats=[]
f = open("2007-2016.txt", "r")
for line in f.readlines():
parts = line.split()
dates = parts[0]
charges = float(parts[4])
date = datetime.datetime.strptime(dates, "%Y-%m-%d")
#if date.month in summer:
if date.month in winter:
seasondata.append(line)
if charges <= 0:
seasondata.append(line)
lon = float(parts[3])
lat = float(parts[2])
lons.append(lon)
lats.append(lat)
if charges >= 15:
seasondata.append(line)
lon = float(parts[3])
lat = float(parts[2])
lons.append(lon)
lats.append(lat)
lons=np.array(lons)
lats=np.array(lats)
ax.set_title('2007-2016 Jan, Feb, Dec: Lightning Density', loc ='Left')
xynps = (ax.projection.transform_points(ccrs.Geodetic(), lons, lats))
bins=[300,240]
h2d, xedges, yedges, im = ax.hist2d(xynps[:,0], xynps[:,1], bins=bins, cmap=plt.cm.YlOrRd, zorder=10, alpha=0.4)
lons=[]
lats=[]
f = open("turbine.txt", "r")
for line in f.readlines():
parts = line.split()
lat=float(parts[0])
lon=float(parts[1])
lats.append(lat)
lons.append(lon)
markerSymbol='o'
markerSize=10
ax.scatter(lons, lats, transform=ccrs.PlateCarree(), marker = markerSymbol, s=markerSize, c='b')
cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
cbar.set_label("Flashes per 2km^2")
plt.show()

Related

Creating US map with 50 state density and color bar using basemap

I have a dictionary named density, I am trying to create a US state map as the color shows the density of the state. I am trying to replicate this use Basemap (Python) to plot US with 50 states
however I am getting error.
This is my data:
density = {'NY': 648.0,
'FL': 696.0,
'TX': 833.0,
'CA': 927.0,
'PA': 472.0,
'OH': 721.0,
'NJ': 645.0,
'IL': 607.0,
'MI': 570.0,
'AZ': 616.0,
'GA': 799.0,
'MD': 652.0,
'NC': 720.0,
'LA': 546.0,
'TN': 806.0,
'MO': 564.0,
'SC': 574.0,
'VA': 818.0,
'IN': 780.0,
'AL': 619.0,
'MA': 626.0,
'WA': 749.0,
'KY': 680.0,
'WI': 615.0,
'OK': 633.0,
'MN': 743.0,
'IA': 543.0,
'WV': 599.0,
'MS': 695.0,
'AR': 698.0,
'OR': 878.0,
'CO': 782.0,
'NV': 930.0,
'KS': 637.0,
'CT': 1078.0,
'UT': 580.0,
'NM': 667.0,
'NE': 552.0,
'PR': 698.0,
'ME': 702.0,
'ID': 679.0,
'DE': 845.0,
'NH': 668.0,
'RI': 616.0,
'HI': 1131.0,
'DC': 711.0,
'MT': 653.0,
'SD': 495.0,
'ND': 685.0,
'VT': 754.0,
'AK': 1080.0,
'WY': 1028.0,
'VI': 1261.0,
'GU': 889.0}
Here is my code which I get the error.
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap as Basemap
from matplotlib.colors import rgb2hex
from matplotlib.patches import Polygon
m = Basemap(llcrnrlon=-119,llcrnrlat=22,urcrnrlon=-64,urcrnrlat=49,
projection='lcc',lat_1=33,lat_2=45,lon_0=-95)
shp_info = m.readshapefile('st99_d00','states',drawbounds=True)
colors={}
statenames=[]
cmap = plt.cm.hot # use 'hot' colormap
vmin = 0; vmax = 450 # set range.
for shapedict in m.states_info:
statename = shapedict['NAME']
if statename not in ['District of Columbia','Puerto Rico']:
pop = popdensity[statename]
colors[statename] = cmap(1.-np.sqrt((pop-vmin)/(vmax-vmin)))[:3]
statenames.append(statename)
ax = plt.gca() # get current axes instance
for nshape,seg in enumerate(m.states):
if statenames[nshape] not in ['Puerto Rico', 'District of Columbia']:
if statenames[nshape] == 'Alaska':
seg = list(map(lambda (x,y): (0.35*x + 1100000, 0.35*y-1300000), seg))
if statenames[nshape] == 'Hawaii':
seg = list(map(lambda (x,y): (x + 5100000, y-900000), seg))
color = rgb2hex(colors[statenames[nshape]])
poly = Polygon(seg,facecolor=color,edgecolor=color)
ax.add_patch(poly)
plt.title('******')
plt.show()
I am confused what I need to do to this code work.
I am new to pyhton, any help and feedback is highly appreciated.
TIA!

How can the x-axis dates be formatted without hh:mm:ss using matplotlib DateFormatter?

I am pulling in data on Japanese GDP and graphing a stacked barchart overlayed with a line. I would like for the x-axis to have only yyyy-mm and no timestamp. I read about a compatability issue with pandas and matplotlib epochs. Is that the issue here? When I try to use matplotlib Dateformatter, the returned dates begin with 1970. How can I fix this?
import pandas as pd
import pandas_datareader.data as web
import datetime
import requests
import investpy
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
start1 = '01/01/2013' #dd/mm/yyyy
end1 = '22/04/2022'
# Real GDP growth
# Source: Cabinet Office http://www.esri.cao.go.jp/en/sna/sokuhou/sokuhou_top.html
# Get the data
url = 'https://www.esri.cao.go.jp/jp/sna/data/data_list/sokuhou/files/2021/qe214_2/tables/nritu-jk2142.csv'
url2 = url.replace('nritu','nkiyo') # URL used for GDP growth by component
url3 = url.replace('nritu-j', 'gaku-m')
url4 = url.replace('nritu', 'gaku')
url5 = url.replace('nritu', 'kgaku')
df = pd.read_csv(url2, header=5, encoding='iso-8859-1').loc[49:]
gdpkeep = {
'Unnamed: 0': 'date',
'GDP(Expenditure Approach)': 'GDP',
'PrivateConsumption': 'Consumption',
'PrivateResidentialInvestment': 'inv1',
'Private Non-Resi.Investment': 'inv2',
'Changein PrivateInventories': 'inv3',
'GovernmentConsumption': 'gov1',
'PublicInvestment': 'gov2',
'Changein PublicInventories': 'gov3',
'Goods & Services': 'Net Exports'
}
df = df[list(gdpkeep.keys())].dropna()
df.columns = df.columns.to_series().map(gdpkeep)
# Adjust the date column to make each value a consistent format
dts = df['date'].str.split('-').str[0].str.split('/ ')
for dt in dts:
if len(dt) == 1:
dt.append(dt[0])
dt[0] = None
df['year'] = dts.str[0].fillna(method='ffill')
df['month'] = dts.str[1].str.zfill(2)
df['date2'] = df['year'].str.cat(df['month'], sep='-')
df['date'] = pd.to_datetime(df['date2'], format='%Y-%m')
# Sum up various types of investment and government spending
df['Investment'] = df['inv1'] + df['inv2'] + df['inv3']
df['Government Spending'] = df['gov1'] + df['gov2'] + df['gov3']
df = df.set_index('date')[['GDP', 'Consumption', 'Investment', 'Government Spending', 'Net Exports']]
df.to_csv('G:\\AutomaticDailyBackup\\Python\\MacroEconomics\\Japan\\Data\\gdp.csv', header=True) # csv file created
print(df.tail(8))
# Plot
df['Net Exports'] = df['Net Exports'].astype(float)
ax = df[['Consumption', 'Investment', 'Government Spending', 'Net Exports']]['2013':].plot(label=df.columns, kind='bar', stacked=True, figsize=(10, 10))
ax.plot(range(len(df['2013':])), df['GDP']['2013':], label='Real GDP', marker='o', linestyle='None', color='black')
plt.title('Japan: Real GDP Growth')
plt.legend(frameon=False, loc='upper left')
ax.set_frame_on(False)
ax.set_ylabel('Annual Percent Change')
# dfmt = mdates.DateFormatter("%Y-%m") # proper formatting Year-month
# ax.xaxis.set_major_formatter(dfmt)
plt.savefig('G:\\AutomaticDailyBackup\\Python\\MacroEconomics\\Japan\\Data\\RealGDP.png')
plt.show()```
Don't use DateFormatter as it is causing trouble, rather change format of the dataframe index using df.index = pd.to_datetime(df.index, format = '%m/%d/%Y').strftime('%Y-%m')
Here is what I did with your gdp.csv file
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import matplotlib.dates as mdates
from matplotlib.dates import DateFormatter
import matplotlib.dates
df=pd.read_csv("D:\python\gdp.csv").set_index('date')
df.index = pd.to_datetime(df.index, format = '%m/%d/%Y').strftime('%Y-%m')
# Plot
fig, ax = plt.subplots()
df['Net Exports'] = df['Net Exports'].astype(float)
ax = df[['Consumption', 'Investment', 'Government Spending', 'Net Exports']]['2013':].plot(label=df.columns, kind='bar', stacked=True, figsize=(10, 10))
ax.plot(range(len(df['2013':])), df['GDP']['2013':], label='Real GDP', marker='o', linestyle='None', color='black')
plt.legend(frameon=False, loc='upper left')
ax.set_frame_on(False)
plt.savefig(r'D:\python\RealGDP.png')
plt.show()

Flight Path by shapely LineString is not correct

I want to connect airplanes in origin (lat_1 lon_1) to dest(lat_2 lon_2). I use these data.
callsign
latitude_1
longitude_1
latitude_2
longitude_2
0
HBAL102
-4.82114
-76.3194
-4.5249
-79.0103
1
AUA1028
-33.9635
151.181
48.1174
16.55
2
ABW120
41.9659
-87.8832
55.9835
37.4958
3
CSN461
33.9363
-118.414
50.0357
8.5723
4
ETH3730
25.3864
55.4221
50.6342
5.43903
But unfortunately, I would get an incorrect result when creating LineString with shapely. I used everything like rotate and affine but it didn't correct.
Code:
cols = pd.read_csv("/content/dirct_lines.csv",sep=";")
line = cols[["callsign","latitude_1","longitude_1","latitude_2","longitude_2"]].dropna()
line['geometry'] = line.apply(lambda x: [(x['latitude_1'],
x['longitude_1']),
(x['latitude_2'],
x['longitude_2'])], axis = 1)
geoline = gpd.GeoDataFrame(line,geometry="geometry",
crs="EPSG:4326")
import matplotlib.pyplot as plt
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
ax = world.plot(figsize=(14,9),
color='white', edgecolor='black')
geoline.plot(figsize=(14,9),ax=ax,facecolor = 'lightgrey', linewidth = 1.75,
edgecolor = 'red',
alpha = 2)
plt.show()
Shapely Output:
something that was interesting for me was that when I use Matplotlib to create lines everything is correct.
Code:
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(projection=ccrs.PlateCarree())
ax.stock_img()
org_lon, org_lat = cols["longitude_1"], cols["latitude_1"]
dst_lon, dst_lat = cols["longitude_2"], cols["latitude_2"]
plt.plot([org_lon, dst_lon], [org_lat, dst_lat],
color='black', linewidth=0.5, marker='_',
transform=ccrs.PlateCarree()
)
plt.savefig(f"fight_path.png",dpi=60,facecolor = None, bbox_inches = 'tight', pad_inches = None)
plt.show()
Matplotlib Output:
What is the problem?
why isn't correct by shapely?
it's just the way you are creating the geometry. Below works correctly.
import io
import geopandas as gpd
import pandas as pd
import shapely.geometry
df = pd.read_csv(
io.StringIO(
"""callsign,latitude_1,longitude_1,latitude_2,longitude_2
HBAL102,-4.82114,-76.3194,-4.5249,-79.0103
AUA1028,-33.9635,151.181,48.1174,16.55
ABW120,41.9659,-87.8832,55.9835,37.4958
CSN461,33.9363,-118.414,50.0357,8.5723
ETH3730,25.3864,55.4221,50.6342,5.43903
"""
)
)
geoline = gpd.GeoDataFrame(
geometry=[
shapely.geometry.LineString(points)
for points in zip(
gpd.points_from_xy(df["longitude_1"], df["latitude_1"]),
gpd.points_from_xy(df["longitude_2"], df["latitude_2"]),
)
],
data=df,
)
import matplotlib.pyplot as plt
world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
ax = world.plot(figsize=(14, 9), color="white", edgecolor="black")
geoline.plot(
figsize=(14, 9),
ax=ax,
facecolor="lightgrey",
linewidth=1.75,
edgecolor="red",
)
plt.show()

Percentile Distribution Graph

Does anyone have an idea how to change X axis scale and ticks to display a percentile distribution like the graph below? This image is from MATLAB, but I want to use Python (via Matplotlib or Seaborn) to generate.
From the pointer by #paulh, I'm a lot closer now. This code
import matplotlib
matplotlib.use('Agg')
import numpy as np
import matplotlib.pyplot as plt
import probscale
import seaborn as sns
clear_bkgd = {'axes.facecolor':'none', 'figure.facecolor':'none'}
sns.set(style='ticks', context='notebook', palette="muted", rc=clear_bkgd)
fig, ax = plt.subplots(figsize=(8, 4))
x = [30, 60, 80, 90, 95, 97, 98, 98.5, 98.9, 99.1, 99.2, 99.3, 99.4]
y = np.arange(0, 12.1, 1)
ax.set_xlim(40, 99.5)
ax.set_xscale('prob')
ax.plot(x, y)
sns.despine(fig=fig)
Generates the following plot (notice the re-distributed X-Axis)
Which I find much more useful than a the standard scale:
I contacted the author of the original graph and they gave me some pointers. It is actually a log scale graph, with x axis reversed and values of [100-val], with manual labeling of the x axis ticks. The code below recreates the original image with the same sample data as the other graphs here.
import matplotlib
matplotlib.use('Agg')
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
clear_bkgd = {'axes.facecolor':'none', 'figure.facecolor':'none'}
sns.set(style='ticks', context='notebook', palette="muted", rc=clear_bkgd)
x = [30, 60, 80, 90, 95, 97, 98, 98.5, 98.9, 99.1, 99.2, 99.3, 99.4]
y = np.arange(0, 12.1, 1)
# Number of intervals to display.
# Later calculations add 2 to this number to pad it to align with the reversed axis
num_intervals = 3
x_values = 1.0 - 1.0/10**np.arange(0,num_intervals+2)
# Start with hard-coded lengths for 0,90,99
# Rest of array generated to display correct number of decimal places as precision increases
lengths = [1,2,2] + [int(v)+1 for v in list(np.arange(3,num_intervals+2))]
# Build the label string by trimming on the calculated lengths and appending %
labels = [str(100*v)[0:l] + "%" for v,l in zip(x_values, lengths)]
fig, ax = plt.subplots(figsize=(8, 4))
ax.set_xscale('log')
plt.gca().invert_xaxis()
# Labels have to be reversed because axis is reversed
ax.xaxis.set_ticklabels( labels[::-1] )
ax.plot([100.0 - v for v in x], y)
ax.grid(True, linewidth=0.5, zorder=5)
ax.grid(True, which='minor', linewidth=0.5, linestyle=':')
sns.despine(fig=fig)
plt.savefig("test.png", dpi=300, format='png')
This is the resulting graph:
These type of graphs are popular in the low-latency community for plotting latency distributions. When dealing with latencies most of the interesting information tends to be in the higher percentiles, so a logarithmic view tends to work better. I've first seen these graphs used in https://github.com/giltene/jHiccup and https://github.com/HdrHistogram/.
The cited graph was generated by the following code
n = ceil(log10(length(values)));
p = 1 - 1./10.^(0:0.01:n);
percentiles = prctile(values, p * 100);
semilogx(1./(1-p), percentiles);
The x-axis was labelled with the code below
labels = cell(n+1, 1);
for i = 1:n+1
labels{i} = getPercentileLabel(i-1);
end
set(gca, 'XTick', 10.^(0:n));
set(gca, 'XTickLabel', labels);
% {'0%' '90%' '99%' '99.9%' '99.99%' '99.999%' '99.999%' '99.9999%'}
function label = getPercentileLabel(i)
switch(i)
case 0
label = '0%';
case 1
label = '90%';
case 2
label = '99%';
otherwise
label = '99.';
for k = 1:i-2
label = [label '9'];
end
label = [label '%'];
end
end
The following Python code uses Pandas to read a csv file that contains a list of recorded latency values (in milliseconds), then it records those latency values (as microseconds) in an HdrHistogram, and saves the HdrHistogram to an hgrm file, that will then be used by Seaborn to plot the latency distribution graph.
import pandas as pd
from hdrh.histogram import HdrHistogram
from hdrh.dump import dump
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import sys
import argparse
# Parse the command line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('csv_file')
parser.add_argument('hgrm_file')
parser.add_argument('png_file')
args = parser.parse_args()
csv_file = args.csv_file
hgrm_file = args.hgrm_file
png_file = args.png_file
# Read the csv file into a Pandas data frame and generate an hgrm file.
csv_df = pd.read_csv(csv_file, index_col=False)
USECS_PER_SEC=1000000
MIN_LATENCY_USECS = 1
MAX_LATENCY_USECS = 24 * 60 * 60 * USECS_PER_SEC # 24 hours
# MAX_LATENCY_USECS = int(csv_df['response-time'].max()) * USECS_PER_SEC # 1 hour
LATENCY_SIGNIFICANT_DIGITS = 5
histogram = HdrHistogram(MIN_LATENCY_USECS, MAX_LATENCY_USECS, LATENCY_SIGNIFICANT_DIGITS)
for latency_sec in csv_df['response-time'].tolist():
histogram.record_value(latency_sec*USECS_PER_SEC)
# histogram.record_corrected_value(latency_sec*USECS_PER_SEC, 10)
TICKS_PER_HALF_DISTANCE=5
histogram.output_percentile_distribution(open(hgrm_file, 'wb'), USECS_PER_SEC, TICKS_PER_HALF_DISTANCE)
# Read the generated hgrm file into a Pandas data frame.
hgrm_df = pd.read_csv(hgrm_file, comment='#', skip_blank_lines=True, sep=r"\s+", engine='python', header=0, names=['Latency', 'Percentile'], usecols=[0, 3])
# Plot the latency distribution using Seaborn and save it as a png file.
sns.set_theme()
sns.set_style("dark")
sns.set_context("paper")
sns.set_color_codes("pastel")
fig, ax = plt.subplots(1,1,figsize=(20,15))
fig.suptitle('Latency Results')
sns.lineplot(x='Percentile', y='Latency', data=hgrm_df, ax=ax)
ax.set_title('Latency Distribution')
ax.set_xlabel('Percentile (%)')
ax.set_ylabel('Latency (seconds)')
ax.set_xscale('log')
ax.set_xticks([1, 10, 100, 1000, 10000, 100000, 1000000, 10000000])
ax.set_xticklabels(['0', '90', '99', '99.9', '99.99', '99.999', '99.9999', '99.99999'])
fig.tight_layout()
fig.savefig(png_file)

I want to add a "spheres" to my data cluster

I want to add a kind of "spheres" to my data cluster.
My data cluster is this, which does not have ""spheres".
And this is my code
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
style.use('ggplot')
import pandas as pd
from sklearn.cluster import KMeans
MY_FILE='total_watt.csv'
date = []
consumption = []
df = pd.read_csv(MY_FILE, parse_dates=[0], index_col=[0])
df = df.resample('1D', how='sum')
df = df.dropna()
date = df.index.tolist()
date = [x.strftime('%Y-%m-%d') for x in date]
from sklearn.preprocessing import LabelEncoder
encoder = LabelEncoder()
date_numeric = encoder.fit_transform(date)
consumption = df[df.columns[0]].values
X = np.array([date_numeric, consumption]).T
kmeans = KMeans(n_clusters=3)
kmeans.fit(X)
centroids = kmeans.cluster_centers_
labels = kmeans.labels_
print(centroids)
print(labels)
fig, ax = plt.subplots(figsize=(10,8))
rect = fig.patch
rect.set_facecolor('#2D2B2B')
colors = ["b.","r.","g."]
for i in range(len(X)):
print("coordinate:",encoder.inverse_transform(X[i,0].astype(int)), X[i,1], "label:", labels[i])
ax.plot(X[i][0], X[i][1], colors[labels[i]], markersize = 10)
ax.scatter(centroids[:, 0],centroids[:, 1], marker = "x", s=150, linewidths = 5, zorder = 10)
a = np.arange(0, len(X), 5)
ax.set_xticks(a)
ax.set_xticklabels(encoder.inverse_transform(a.astype(int)))
ax.tick_params(axis='x', colors='lightseagreen')
ax.tick_params(axis='y', colors='lightseagreen')
plt.scatter(centroids[:, 0],centroids[:, 1], marker = "x", s=100, c="black", linewidths = 5, zorder = 10)
ax.set_title('Energy consumptions Clusters (high/medium/low)', color='gold')
ax.set_xlabel('time', color='gold')
ax.set_ylabel('date(year 2011)', color='gold')
plt.show()
"Spheres" is area which surroundings plot(cluster), as this picture.
I tried to google it.
But when I type "matplotlib spheres", I could not get any result..
The sample graph in your post looks like resulting from Generalized Gaussian Mixture where each sphere is a Gaussian 2-d density.
I'll write up a sample code shortly to demonstrate how to use GMM on your dataset and do this kind of plotting.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
style.use('ggplot')
import pandas as pd
# code changes here
# ===========================================
from sklearn.mixture import GMM
# ===========================================
from sklearn.preprocessing import LabelEncoder
# replace it with you file path
MY_FILE='/home/Jian/Downloads/total_watt.csv'
df = pd.read_csv(MY_FILE, parse_dates=[0], index_col=[0])
df = df.resample('1D', how='sum')
df = df.dropna()
date = df.index.tolist()
date = [x.strftime('%Y-%m-%d') for x in date]
encoder = LabelEncoder()
date_numeric = encoder.fit_transform(date)
consumption = df[df.columns[0]].values
X = np.array([date_numeric, consumption]).T
# code changes here
# ===========================================
gmm = GMM(n_components=3, random_state=0)
gmm.fit(X)
y_pred = gmm.predict(X)
# the center is given by mean
gmm.means_
# ===========================================
import matplotlib as mpl
fig, ax = plt.subplots(figsize=(10,8))
for i, color in enumerate('rgb'):
# sphere background
width, height = 2 * 1.96 * np.sqrt(np.diagonal(gmm._get_covars()[i]))
ell = mpl.patches.Ellipse(gmm.means_[i], width, height, color=color)
ell.set_alpha(0.1)
ax.add_artist(ell)
# data points
X_data = X[y_pred == i]
ax.scatter(X_data[:,0], X_data[:,1], color=color)
# center
ax.scatter(gmm.means_[i][0], gmm.means_[i][1], marker='x', s=100, c=color)
ax.set_title('Energy consumptions Clusters (high/medium/low)', color='gold')
ax.set_xlabel('time', color='gold')
ax.set_ylabel('date(year 2011)', color='gold')
a = np.arange(0, len(X), 5)
ax.set_xticks(a)
ax.set_xticklabels(encoder.inverse_transform(a.astype(int)))
ax.tick_params(axis='x', colors='lightseagreen')
ax.tick_params(axis='y', colors='lightseagreen')