How to limit scrolling outside bounds in pyqtgraph? - pyqt5

I've set up a chart for candlestick display using PyQtGraph.
I've made a simplified example below.
I'm trying to figure out how to limit the viewable/scrollable range on the chart for the y1 and y2 axis.
I want to limit them to equal the ymin and ymax settings in the boundingRect() function.
If I run the chart it starts off with the bounds set correctly but it allows you to manually scroll around the chart outside of the bounds that are set within the boundingRect()
I want to prevent the ability to scroll beyond what the boundingRect() allows.
I want to be able to scroll along the X axis without issue but I want the Y axis to dynamically limit the bounds to the Y axis of the candlesticks that are currently viewable.
For starters I can't figure out how to force limits on scrollable bounds in a way that is compatible with what I have written below.
QPainterPath or QRectF doesn't seem to have a function to limit the scrollable view that I can find. Or at the very least I can't figure out the proper syntax.
Then I need to figure out how to return the axis range of the currently viewable candles in order to dynamically set the scrollable/viewable limits. Haven't gotten that far yet.
Any help is appreciated.
import pyqtgraph as pg
from pyqtgraph import QtCore, QtGui, QtWidgets
import numpy as np
# Sample candle data: one row per candle.
data = np.array([ ## fields are (time, open, close); this simplified example has no min/max columns.
(1., 10, 13),
(2., 13, 17),
(3., 17, 14),
(4., 14, 15),
(5., 15, 9),
(6., 9, 15),
(7., 15, 5),
(8., 5, 7),
(9., 7, 3),
(10., 3, 10),
(11., 10, 15),
(12., 15, 25),
(13., 25, 20),
(14., 20, 17),
(15., 17, 30),
(16., 30, 32),
(17., 32, 35),
(18., 35, 28),
(19., 28, 27),
(20., 27, 25),
(21., 25, 29),
(22., 29, 35),
(23., 35, 40),
(24., 40, 45),
])
class CandlestickItem(pg.GraphicsObject):
    """Graphics item that paints one open/close body per data row.

    The rectangle reported by ``boundingRect()`` is computed from the
    module-level ``data`` array, not from the painted picture.
    """

    # Bounding rectangle of the painted bodies; refreshed by generatePicture().
    _boundingRect = QtCore.QRectF()

    def __init__(self):
        super().__init__()
        # paint() draws nothing until set_data() flips this flag.
        self.flagHasData = False

    def set_data(self, data):
        """Store ``data``, rebuild the cached picture and notify the view."""
        self.data = data
        self.flagHasData = True
        self.generatePicture()
        self.informViewBoundsChanged()

    def generatePicture(self):
        """Pre-render every candle body into ``self.picture``."""
        self.picture = QtGui.QPicture()
        outline = QtGui.QPainterPath()
        painter = QtGui.QPainter(self.picture)
        painter.setPen(pg.mkPen('w'))
        # Half-width of a body: a third of the spacing of the first two rows.
        half_width = (self.data[1][0] - self.data[0][0]) / 3.
        for (t, open_, close_) in self.data:
            body = QtCore.QRectF(t - half_width, open_, half_width * 2, close_ - open_)
            outline.addRect(body)
            # Red when the candle closes below its open, green otherwise.
            painter.setBrush(pg.mkBrush('r' if open_ > close_ else 'g'))
            painter.drawRect(body)
        painter.end()
        self._boundingRect = outline.boundingRect()

    def paint(self, p, *args):
        """Replay the cached picture once data has been supplied."""
        if not self.flagHasData:
            return
        p.drawPicture(0, 0, self.picture)

    def boundingRect(self):
        """Return a rect built from the global ``data`` array.

        Horizontally it spans the 5 units before the largest value of
        column 0; vertically it spans the min/max of column 2.
        """
        right = np.nanmax(data[:, 0])
        left = right - 5
        low = np.nanmin(data[:, 2])
        high = np.nanmax(data[:, 2])
        return QtCore.QRectF(left, low, right - left, high - low)
# Build the plot window, attach the candlestick item, then hand it the data.
item = CandlestickItem()
plt = pg.plot()
plt.addItem(item)
plt.setWindowTitle('pyqtgraph example: customGraphicsItem')
item.set_data(data)

if __name__ == '__main__':
    import sys

    # Run the Qt event loop unless Python is in interactive mode and
    # QtCore exposes PYQT_VERSION.
    interactive_pyqt = sys.flags.interactive == 1 and hasattr(QtCore, 'PYQT_VERSION')
    if not interactive_pyqt:
        QtWidgets.QApplication.instance().exec_()

If I understood you correctly, you want to prevent the user from zooming or scrolling outside the boundingRect of your candlestick plot. That can be achieved by setting limits on the ViewBox.
EDIT: y view is now limited to min and max value within current view specified by x_start and x_end.
Here is modified code, that does that:
import math
import numpy as np
import pyqtgraph as pg
from pyqtgraph import QtCore, QtGui, QtWidgets
# Sample candle data: one row per candle.
data = np.array([ ## fields are (time, open, close); this simplified example has no min/max columns.
(1., 10, 13),
(2., 13, 17),
(3., 17, 14),
(4., 14, 15),
(5., 15, 9),
(6., 9, 15),
(7., 15, 5),
(8., 5, 7),
(9., 7, 3),
(10., 3, 10),
(11., 10, 15),
(12., 15, 25),
(13., 25, 20),
(14., 20, 17),
(15., 17, 30),
(16., 30, 32),
(17., 32, 35),
(18., 35, 28),
(19., 28, 27),
(20., 27, 25),
(21., 25, 29),
(22., 29, 35),
(23., 35, 40),
(24., 40, 45),
])
class CandlestickItem(pg.GraphicsObject):
    """Candlestick item that clamps its ViewBox to the candles in view.

    Rows of the data array are ``(time, open, close)``.  Whenever the view
    transform changes, the x limits of the ViewBox are set to the painted
    extent and the y limits to the open/close range of the candles that
    are currently visible.
    """

    # Bounding rectangle of the painted bodies; refreshed by generatePicture().
    _boundingRect = QtCore.QRectF()

    def __init__(self):
        pg.GraphicsObject.__init__(self)
        self.picture = QtGui.QPicture()
        self.data = None
        # Half-width of one candle body, in data units.
        self._half_width = 0.0
        self.flagHasData = False

    def set_data(self, data):
        """Store ``data``, rebuild the cached picture and notify the view.

        Args:
            data: array-like of ``(time, open, close)`` rows.
        """
        self.data = np.asarray(data, dtype=float)
        self.flagHasData = True
        self.generatePicture()
        self.informViewBoundsChanged()

    def generatePicture(self):
        """Pre-render every candle body into ``self.picture``."""
        self.picture = QtGui.QPicture()
        path = QtGui.QPainterPath()
        times = self.data[:, 0]
        # A third of the spacing between the first two candles; with a
        # single candle there is no spacing, so use a fixed fallback.
        if len(times) > 1:
            self._half_width = (times[1] - times[0]) / 3.
        else:
            self._half_width = 1. / 3.
        w = self._half_width
        p = QtGui.QPainter(self.picture)
        p.setPen(pg.mkPen('w'))
        for (t, open_, close) in self.data:
            rect = QtCore.QRectF(t - w, open_, w * 2, close - open_)
            path.addRect(rect)
            # Red when the candle closes below its open, green otherwise.
            p.setBrush(pg.mkBrush('r' if open_ > close else 'g'))
            p.drawRect(rect)
        p.end()
        self._boundingRect = path.boundingRect()

    def paint(self, p, *args):
        """Replay the cached picture once data has been supplied."""
        if self.flagHasData:
            p.drawPicture(0, 0, self.picture)

    def boundingRect(self):
        """Return the bounding rectangle of the cached picture."""
        return QtCore.QRectF(self.picture.boundingRect())

    def viewTransformChanged(self):
        """Re-clamp the ViewBox limits to the candles currently in view.

        Returns:
            The item's bounding rectangle (kept for backward compatibility
            with the previous implementation, which returned it).
        """
        super(CandlestickItem, self).viewTransformChanged()
        br = self.boundingRect()
        view_box = self.getViewBox()
        visible_rect = self.viewRect()
        # Nothing to clamp before data is set or while the item is not
        # attached to a view that supports limits.
        if (not self.flagHasData or visible_rect is None
                or view_box is None or not hasattr(view_box, 'setLimits')):
            return br
        mapped_view = self.mapRectToView(visible_rect)
        if mapped_view is None:
            return br
        # Select candles by their time value rather than by assuming that
        # row i sits at x == i + 1; a body counts as visible when any part
        # of it overlaps the view.
        times = self.data[:, 0]
        margin = self._half_width
        in_view = ((times >= mapped_view.left() - margin)
                   & (times <= mapped_view.right() + margin))
        # Bodies span open..close, so both columns bound the y range.
        bodies = self.data[in_view, 1:3]
        if bodies.size == 0 or np.all(np.isnan(bodies)):
            return br
        ymin = float(np.nanmin(bodies))
        ymax = float(np.nanmax(bodies))
        if ymin != ymax:
            # xMax must be the right edge (x + width), not the width itself.
            view_box.setLimits(xMin=br.left(), xMax=br.right(),
                               yMin=ymin, yMax=ymax)
        return br
# Create the item and plot window, then load the sample data into the item.
item = CandlestickItem()
plt = pg.plot()
plt.addItem(item)
plt.setWindowTitle('pyqtgraph example: customGraphicsItem')
item.set_data(data)

if __name__ == '__main__':
    import sys

    # The event loop is started when Python is not in interactive mode, or
    # when QtCore has no PYQT_VERSION attribute.
    needs_event_loop = sys.flags.interactive != 1 or not hasattr(QtCore, 'PYQT_VERSION')
    if needs_event_loop:
        QtWidgets.QApplication.instance().exec_()

Related

How to plot ROC-AUC figures without using scikit-learn

I have the following list containing multiple tuples of (TP, FP, FN):
[(12, 0, 0), (5, 2, 2), (10, 0, 1), (7, 1, 1), (13, 0, 0), (7, 2, 2), (11, 0, 2)]
each tuple represents the scores for a single image. This means I have 7 images and I have calculated the scores for a object detection task. Now I calculate precision and recall for each image(tuple) using the following function:
def calculate_recall_precision(data, zero_division=0.0):
    """Compute per-image precision and recall from (TP, FP, FN) counts.

    Args:
        data: iterable of ``(tp, fp, fn)`` tuples, one per image.
        zero_division: value reported when a ratio is undefined, i.e.
            precision with no predictions (``tp + fp == 0``) or recall
            with no ground truth (``tp + fn == 0``).

    Returns:
        A tuple ``(precisions, recalls)`` of two lists. Note the order:
        precisions come first despite the function's name.
    """
    precisions_bundle = []
    recalls_bundle = []
    for tp, fp, fn in data:
        predicted = tp + fp
        actual = tp + fn
        # Guard the denominators instead of raising ZeroDivisionError.
        precisions_bundle.append(tp / predicted if predicted else zero_division)
        recalls_bundle.append(tp / actual if actual else zero_division)
    return (precisions_bundle, recalls_bundle)
This function returns a tuple which contains two lists. The first one is precision values for each image and the second one is recall values for each image.
Now my main goal is to plot ROC and AUC curves using only matplotlib. Please note that I do not want to use scikit-learn library.
You can simply use matplotlib.pyplot.plot method. For example:
import numpy as np
import matplotlib.pyplot as plt
def plot_PR(precision_bundle, recall_bundle, save_path=None):
    """Plot a precision/recall curve on the current matplotlib figure.

    Args:
        precision_bundle: sequence of precision values (y axis).
        recall_bundle: sequence of recall values (x axis).
        save_path: optional directory (``str`` or ``pathlib.Path``). When
            given, the figure is saved there as
            ``PR_curve_thresh_opt.png``.

    Returns:
        The object returned by the ``plt.xlabel`` call.
        NOTE(review): the original rebound ``line`` on every plotting call,
        so this is what it always returned; kept for compatibility.
    """
    # Imported locally because the surrounding snippet never imports
    # pathlib; the original ``save_path: Path`` annotation raised NameError.
    from pathlib import Path

    plt.plot(recall_bundle, precision_bundle, linewidth=2, markersize=6)
    plt.title('Precision/Recall curve', size=18, weight='bold')
    plt.ylabel('Precision', size=15)
    line = plt.xlabel('Recall', size=15)
    # Dashed reference diagonal from (0, 1) to (1, 0).
    random_classifier_line_x = np.linspace(0, 1, 10)
    random_classifier_line_y = np.linspace(1, 0, 10)
    plt.plot(random_classifier_line_x, random_classifier_line_y,
             color='firebrick', linestyle='--')
    if save_path:
        # Path() also accepts plain strings, which do not support "/".
        outname = Path(save_path) / 'PR_curve_thresh_opt.png'
        plt.savefig(outname, dpi=100, bbox_inches='tight')
    return line
and then just use it as plot_PR(precision_bundle, recall_bundle).
Note: here I also added a dashed line for a random classifier and the option to save the figure, in case you want to.

Setting alpha in matplotlib.axes.Axes.table?

I'm trying to understand how I can set the alpha level in a matplotlib table. I tried setting it with a global rcParams, but not quite sure how to do that? (I want to change the transparency in the header color). In general I'm not sure this can be done globally, if not, how do i pass the parameter to table? Thx in advance.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from cycler import cycler
import six
# Global font family and property cycle (the asker's attempt at a global alpha).
plt.rcParams["font.family"] = "DejaVu Sans"
plt.rcParams['axes.prop_cycle'] = cycler(alpha=[0.5])

# One list of six values per series; the order matches the metric names below.
raw_data = {
    'Simulation': [42, 39, 86, 15, 23, 57],
    'SP500': [52, 41, 79, 80, 34, 47],
    'NASDAQ': [62, 37, 84, 51, 67, 32],
    'Benchmark': [72, 43, 36, 26, 53, 88],
}
metric_index = pd.Index(
    ['Sharpe Ratio', 'Sortino Ratio', 'Calmars Ratio', 'VaR', 'CVaR', 'Max DD'],
    name='Metric',
)
series_index = pd.Index(['Simulation', 'SP500', 'NASDAQ', 'Benchmark'], name='Series')
df = pd.DataFrame(raw_data, index=metric_index, columns=series_index)
def create_table(data, col_width=None, row_height=None, font_size=None,
                 header_color='#000080', row_colors=None, edge_color='w',
                 header_columns=0, ax=None, bbox=None):
    """Render a DataFrame as a styled matplotlib table.

    Args:
        data: DataFrame whose values, columns and index fill the table.
        col_width: width of each column; ``None`` leaves sizing to matplotlib.
        row_height: row height used for figure sizing and the header row;
            ``None`` leaves sizing to matplotlib.
        font_size: cell font size; ``None`` keeps the default size.
        header_color: face colour of header cells.
        row_colors: colours cycled over body rows (default grey/white).
        edge_color: edge colour of every cell.
        header_columns: cells whose column index is below this value are
            styled as headers (row-label cells have column index -1).
        ax: Axes to draw into; a new figure is created when ``None``.
        bbox: table bounding box in axes coordinates (default full axes).

    Returns:
        The Axes containing the table.
    """
    if row_colors is None:
        row_colors = ['#D8D8D8', 'w']
    if bbox is None:
        bbox = [0, 0, 1, 1]
    if ax is None:
        if col_width is None or row_height is None:
            # No explicit cell size: fall back to the default figure size
            # instead of failing on arithmetic with None.
            _, ax = plt.subplots()
        else:
            size = (np.array(data.shape[::-1]) + np.array([0, 1])) * np.array([col_width, row_height])
            _, ax = plt.subplots(figsize=size)
        # NOTE(review): the post lost its indentation; both axis calls are
        # assumed to belong to the "create a new Axes" branch -- confirm.
        ax.axis('off')
        ax.axis([0, 1, data.shape[0], -1])
    col_widths = None if col_width is None else [col_width] * len(data.columns)
    data_table = ax.table(cellText=data.values, colLabels=data.columns, rowLabels=data.index,
                          bbox=bbox, cellLoc='center', rowLoc='left', colLoc='center',
                          colWidths=col_widths)
    cell_map = data_table.get_celld()
    if row_height is not None:
        for i in range(len(data.columns)):
            cell_map[(0, i)].set_height(row_height * 0.20)
    data_table.auto_set_font_size(False)
    if font_size is not None:
        data_table.set_fontsize(font_size)
    # Use the public cell mapping rather than six.iteritems(table._cells).
    for (row, col), cell in cell_map.items():
        cell.set_edgecolor(edge_color)
        if row == 0 or col < header_columns:
            cell.set_text_props(weight='bold', color='w')
            cell.set_facecolor(header_color)
        else:
            cell.set_facecolor(row_colors[row % len(row_colors)])
    return ax
# Draw the table, title it, then write it to disk and display it.
table_options = dict(col_width=1.5, row_height=0.5, font_size=8)
ax = create_table(df, **table_options)
ax.set_title("Risk Measures", fontweight='bold')
ax.axis('off')
plt.tight_layout()
output_file = 'risk_parameter_table[1].pdf'
plt.savefig(output_file)
plt.show()
You can set_alpha() manually on the table's _cells.
If you only want to change the headers, check if row == 0 or col == -1:
def create_table(...):
...
for (row, col), cell in data_table._cells.items():
if (row == 0) or (col == -1):
cell.set_alpha(0.5)
return ax

PyQt5 Combining QVBoxLayout and QStackedLayout

Can QVBoxLayout contain a QStackedLayout in a widget? I am trying to create a custom widget that looks like the following:
import sys
from PyQt5.QtWidgets import *
from pyqtgraph import PlotWidget, plot
import pyqtgraph as pg
class MyCustomWidget(QWidget):
    """Combo box above a stack of plots (the asker's original attempt).

    As posted, ``add_new_graphics`` passes the QStackedLayout to
    ``QVBoxLayout.addWidget``, which is the call the quoted TypeError
    complains about.
    """

    def __init__(self):
        super().__init__()
        self.my_combo_box = QComboBox(self)
        self.vlayout = QVBoxLayout(self)
        self.slayout = QStackedLayout(self)
        self.vlayout.addWidget(self.my_combo_box)
        # Plot widgets keyed by the name shown in the combo box.
        self.graphWidget = {}
        self.widget = QWidget()

    def add_new_graphics(self, name, x, y):
        """Add a combo entry and a plot of ``y`` against ``x`` under ``name``."""
        self.my_combo_box.addItem(name)
        plot_widget = pg.PlotWidget()
        self.graphWidget[name] = plot_widget
        plot_widget.plot(x, y)
        self.slayout.addWidget(plot_widget)
        self.vlayout.addWidget(self.slayout)
        self.widget.setLayout(self.vlayout)
# Create the application and the widget, add two data sets, then run.
app = QApplication([])
a = MyCustomWidget()

x1 = list(range(1, 11))
y1 = [30, 32, 34, 32, 33, 31, 29, 32, 35, 45]
a.add_new_graphics('data1', x1, y1)

x2 = list(range(11, 21))
y2 = [0, 2, 4, 2, 3, 1, 9, 2, 5, 5]
a.add_new_graphics('data2', x2, y2)

a.show()
exit_code = app.exec_()
sys.exit(exit_code)
I get the following errors:
TypeError: addWidget(self, QWidget, stretch: int = 0, alignment: Union[Qt.Alignment, Qt.AlignmentFlag] = Qt.Alignment()): argument 1 has unexpected type 'QStackedLayout'
sys:1: RuntimeWarning: Visible window deleted. To prevent this, store a reference to the window object.
You can add any layout to a QVBoxLayout including a QStackedLayout, but you need to use layout.addLayout for this, not layout.addWidget. There are a few other things off in the definition of MyCustomWidget. Here is a version that should work:
class MyCustomWidget(QWidget):
    """Combo box on top of a QStackedLayout holding one plot per entry."""

    def __init__(self):
        super().__init__()
        self.my_combo_box = QComboBox(self)
        self.vlayout = QVBoxLayout(self)
        # Created without a parent: it is nested into self.vlayout below.
        self.slayout = QStackedLayout()
        self.vlayout.addWidget(self.my_combo_box)
        # A layout is nested with addLayout (not addWidget); once is enough.
        self.vlayout.addLayout(self.slayout)
        # Plot widgets keyed by the name shown in the combo box.
        self.graphWidget = {}
        # Changing the combo box index switches the stacked page shown.
        self.my_combo_box.currentIndexChanged.connect(self.slayout.setCurrentIndex)

    def add_new_graphics(self, name, x, y):
        """Add a combo entry ``name`` and a stacked plot of ``y`` against ``x``."""
        self.my_combo_box.addItem(name)
        plot_widget = pg.PlotWidget()
        self.graphWidget[name] = plot_widget
        plot_widget.plot(x, y)
        self.slayout.addWidget(plot_widget)

How to set more margins

I have a pyplot code.
Since I want to group multiple bars, I am trying to write text in the graph using plt.annotate.
However, as you can see in the picture, the word 'Something' in left bottom gets cropped. Does anyone know How I can fix this?
Here is my code
#!/usr/bin/python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import operator as o
import numpy as np
# Grouped bar chart: two bars (men/women) for each of the five groups.
n_groups = 5
means_men = (20, 35, 30, 35, 27)
std_men = (2, 3, 4, 1, 2)
means_women = (25, 32, 34, 20, 25)
std_women = (3, 5, 2, 3, 3)

fig, ax = plt.subplots()
index = np.arange(n_groups)
bar_width = 0.35
opacity = 0.4
error_config = dict(ecolor='0.3')

rects1 = plt.bar(
    index, means_men, bar_width,
    alpha=opacity, color='b', yerr=std_men,
    error_kw=error_config, label='Men',
)
rects2 = plt.bar(
    index + bar_width, means_women, bar_width,
    alpha=opacity, color='r', yerr=std_women,
    error_kw=error_config, label='Women',
)

plt.ylabel('Scores')
plt.title('Scores by group and gender')
# Two captions below the axes, placed by point offsets from the axes origin.
for text_offset in ((50, -40), (200, -20)):
    plt.annotate('Something', (0, 0), text_offset, xycoords='axes fraction',
                 textcoords='offset points', va='top')
plt.xticks(index + bar_width, ('A', 'B', 'C', 'D', 'E'))
plt.legend()
plt.savefig('barchart_3.png')
For some reason, matplotlib sometimes clips too aggressively. If you pass bbox_inches='tight' to savefig, the saved image should include the whole figure:
plt.savefig('barchart_3.png', bbox_inches='tight')
More generally, you can adjust your main figure with something like,
plt.subplots_adjust(bottom=0.1)

Matplotlib and numpy histogram2d axis issue

I'm struggling to get the axis right:
I've got the x and y values, and want to plot them in a 2d histogram (to examine correlation). Why do I get a histogram with limits from 0-9 on each axis? How do I get it to show the actual value ranges?
This is a minimal example and I would expect to see the red "star" at (3, 3):
import numpy as np
import matplotlib.pyplot as plt
# Minimal example: 2-D histogram of two identical samples shown as an image.
x = (1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 3)
y = (1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 3)
# These edges are never passed to histogram2d; both names are rebound by
# its return values on the next statement.
xedges = range(5)
yedges = range(5)
H, xedges, yedges = np.histogram2d(y, x)
# imshow accepts only 'upper' or 'lower' for origin; 'low' is rejected.
im = plt.imshow(H, origin='lower')
plt.show()
I think the problem is twofold:
Firstly you should have 5 bins in your histogram (it's set to 10 as default):
H, xedges, yedges = np.histogram2d(y, x,bins=5)
Secondly, to set the axis values, you can use the extent parameter, as per the histogram2d man pages:
# Map the image onto the histogram's bin range via extent.
# origin must be 'upper' or 'lower' ('low' is not an accepted value).
im = plt.imshow(H, interpolation=None, origin='lower',
                extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])
If I understand correctly, you just need to set interpolation='none'
import numpy as np
import matplotlib.pyplot as plt
# Same example as in the question, drawn without interpolation.
x = (1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 3)
y = (1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 3)
# These edges are never passed to histogram2d; both names are rebound by
# its return values on the next statement.
xedges = range(5)
yedges = range(5)
H, xedges, yedges = np.histogram2d(y, x)
# origin must be 'upper' or 'lower' ('low' is not an accepted value).
im = plt.imshow(H, origin='lower', interpolation='none')
Does that look right?