live plotting using pyserial and matplotlib - matplotlib

I can capture data from a serial device via pyserial. At the moment I can only export the data to a text file, which has three columns in the format shown below:
>21 21 0
>
>41 41 0.5
>
>73 73 1
>
....
>2053 2053 5
>
>2084 2084 5.5
>
>2125 2125 6
Now I want to use matplotlib to generate a live graph with two axes (x, y), where x and y are the second and third columns. The first column, the leading '>' character, and the lines that contain no data can be removed.
thank folks!
============================
Update : today, after follow these guides from
http://www.blendedtechnologies.com/realtime-plot-of-arduino-serial-data-using-python/231
http://eli.thegreenplace.net/2008/08/01/matplotlib-with-wxpython-guis
pyserial - How to read the last line sent from a serial device
Now I can plot live data with threading, but eliben said that this GUI only plots a single value at a time. That is a big limitation for me, because my purpose is to plot 2 or 3 columns. Here is the code, modified from blendedtechnologies.
Here is serial handler :
from threading import Thread
import time
import serial
last_received = ''
def receiving(ser):
    """Continuously read from *ser* and publish the newest complete line.

    Runs forever, so it is meant to be used as a background-thread target.
    Incoming data is accumulated in a local buffer; whenever the buffer
    contains a newline, the last complete line is stored in the module-level
    ``last_received`` and the trailing partial line is kept for the next read.

    ser -- object providing ``inWaiting()`` and ``read(size)``.
    """
    global last_received
    buffer = ''
    while True:
        # Request at least one byte: read(0) returns immediately, which made
        # the original loop spin at full CPU whenever no data was pending.
        chunk = ser.read(ser.inWaiting() or 1)
        if not isinstance(chunk, str):
            # On Python 3 the port yields bytes; the line parsing needs text.
            chunk = chunk.decode('ascii', 'replace')
        buffer = buffer + chunk
        if '\n' in buffer:
            lines = buffer.split('\n')  # Guaranteed to have at least 2 entries
            last_received = lines[-2]
            # If the Arduino sends lots of empty lines, you'll lose the
            # last filled line, so you could make the above statement
            # conditional like so: if lines[-2]: last_received = lines[-2]
            buffer = lines[-1]
class SerialData(object):
    """Source of numeric samples read from a serial port in the background.

    If the port cannot be opened, ``self.ser`` is ``None`` and ``next()``
    returns a fixed placeholder so the consumer can be tested without
    hardware attached.
    """

    def __init__(self, init=50):
        """Open the serial port and start the background reader thread.

        init -- accepted for interface compatibility; not used by this class.
        """
        try:
            self.ser = serial.Serial(
                port='/dev/ttyS0',
                baudrate=9600,
                bytesize=serial.EIGHTBITS,
                parity=serial.PARITY_NONE,
                stopbits=serial.STOPBITS_ONE,
                timeout=0.1,
                xonxoff=0,
                rtscts=0,
                interCharTimeout=None
            )
        except serial.serialutil.SerialException:
            # no serial connection
            self.ser = None
        else:
            # receiving() never returns; as a daemon thread it no longer
            # keeps the interpreter alive after the main thread finishes.
            reader = Thread(target=receiving, args=(self.ser,))
            reader.daemon = True
            reader.start()

    def next(self):
        """Return the first column of the most recently received line.

        Returns 100 when no serial port is open. Otherwise tries up to 40
        times (5 ms apart) to parse the first space-separated field of
        ``last_received`` (after dropping its first character) as a float,
        and returns 0.0 if every attempt fails.
        """
        if not self.ser:
            return 100  # return anything so we can test when Arduino isn't connected
        # return a float value or try a few times until we get one
        for _ in range(40):
            # Skip the leading '>' marker, keep the first field.
            raw_line = last_received[1:].split(' ')[0]
            try:
                return float(raw_line.strip())
            except ValueError:
                # Single-argument form prints the same text on Python 2 and 3.
                print('bogus data %s' % raw_line)
                time.sleep(.005)
        return 0.

    def __del__(self):
        # getattr guards against __init__ failing before self.ser was set.
        if getattr(self, 'ser', None):
            self.ser.close()
if __name__ == '__main__':
    # Smoke test: print 500 readings, pausing 15 ms before each one.
    s = SerialData()
    for i in range(500):
        time.sleep(.015)
        # Parenthesised single argument prints identically on Python 2 and 3.
        print(s.next())
For me I modified this segment so it can grab my 1st column data
for i in range(40):
raw_line = last_received[1:].split(' ').pop(0)
try:
return float(raw_line.strip())
except ValueError:
print 'bogus data',raw_line
time.sleep(.005)
return 0.
and generate graph base on these function on the GUI file
from Arduino_Monitor import SerialData as DataGen
def __init__(self):
wx.Frame.__init__(self, None, -1, self.title)
self.datagen = DataGen()
self.data = [self.datagen.next()]
................................................
def init_plot(self):
    """Create the figure and axes and plot ``self.data`` as a line series.

    Sets ``self.dpi``, ``self.fig``, ``self.axes`` and ``self.plot_data``
    (the first line object returned by ``axes.plot``), which is kept so the
    series can be updated later.
    """
    self.dpi = 100
    self.fig = Figure((3.0, 3.0), dpi=self.dpi)
    self.axes = self.fig.add_subplot(111)
    # Newer Matplotlib provides set_facecolor() and no longer has
    # set_axis_bgcolor(); fall back for old installations.
    if hasattr(self.axes, 'set_facecolor'):
        self.axes.set_facecolor('black')
    else:
        self.axes.set_axis_bgcolor('black')
    self.axes.set_title('Arduino Serial Data', size=12)
    pylab.setp(self.axes.get_xticklabels(), fontsize=8)
    pylab.setp(self.axes.get_yticklabels(), fontsize=8)
    # plot the data as a line series, and save the reference
    # to the plotted line series
    self.plot_data = self.axes.plot(
        self.data,
        linewidth=1,
        color=(1, 1, 0),
    )[0]
So my next question is how to realtime grab at least 2 column and passing 2 columns'data to the GUIs that it can generate graph with 2 axis.
self.plot_data.set_xdata(np.arange(len(self.data))) #my 3rd column data
self.plot_data.set_ydata(np.array(self.data)) #my 2nd column data

Well, this reads your string and converts the numbers to floats. I assume you'll be able to adapt this as needed.
import numpy as np
import pylab as plt
# Sample of the captured log: data lines start with '>' followed by three
# space-separated numbers; bare '>' lines carry no data.
# (Named ``raw`` rather than ``str`` so the builtin is not shadowed.)
raw = '''>21 21 0
>
>41 41 0.5
>
>73 73 1
>
>2053 2053 5
>
>2084 2084 5.5
>
>2125 2125 6'''
# Skip the bare '>' lines, drop the leading '>' of the others and convert
# every remaining field to float: one row per sample, one column per field.
rows = [
    [float(field) for field in line[1:].split()]
    for line in raw.splitlines()
    if len(line) > 1
]
nums = np.array(rows)
fig = plt.figure(0)
# Top panel: second column against the first.
ax = plt.subplot(2, 1, 1)
ax.plot(nums[:, 0], nums[:, 1], 'k.')
# Bottom panel: third column against the first.
ax = plt.subplot(2, 1, 2)
ax.plot(nums[:, 0], nums[:, 2], 'r+')
plt.show()

Here you have an Eli Bendersky's example of how plotting data arriving from a serial port

Some time back I had the same problem. I wasted a lot of time writing the same thing over and over again, so I wrote a Python package for it.
https://github.com/girish946/plot-cat
you just have to write the logic for obtaining the data from serial port.
the example is here: https://github.com/girish946/plot-cat/blob/master/examples/test-ser.py

Related

passing panda dataframe data to functions and its not outputting the results

In my code, I am trying to extract data from csv file to use in the function, but it doesnt output anything, and gives no error. My code works because I tried it with just numpy array as inputs. not sure why it doesnt work with panda.
import numpy as np
import pandas as pd
import os
# change the current directory to the directory where the running script file is
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# finding best fit line for y=mx+b by iteration
def gradient_descent(x, y):
    """Fit ``y = m*x + b`` by batch gradient descent on the mean squared error.

    x, y -- equal-length 1-D numeric sequences (lists, NumPy arrays or
            pandas Series); they are converted to float arrays.

    Prints one progress line per update step and returns ``(m, b)``.
    Iteration stops when the fit is exact, when the relative MSE improvement
    drops below 0.001, when the MSE starts to grow (a message is printed),
    or after 10000 iterations.

    Raises ValueError if ``x`` is empty.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    if n == 0:
        raise ValueError('gradient_descent needs at least one data point')

    m_iter = b_iter = 1.0  # starting point
    iteration = 10000
    learning_rate = 0.05
    # Start from +inf: with the old fixed value of 10000 the loop stopped
    # silently at iteration 0 whenever the initial MSE was larger than that.
    last_mse = float('inf')
    # take baby steps to reach global minima
    for i in range(iteration):
        y_predicted = m_iter * x + b_iter
        residual = y - y_predicted
        mse = (residual ** 2).sum() / n  # cost function to minimize
        if mse == 0:
            # Exact fit; also avoids dividing by zero below.
            print('m is {}, b is {}, cost is {}, iteration {}'.format(
                m_iter, b_iter, mse, i))
            break
        if mse > last_mse:
            # A growing cost is divergence, not convergence: say so instead
            # of exiting without any output.
            print('cost increased from {} to {} at iteration {}; '
                  'lower the learning rate or rescale x'.format(
                      last_mse, mse, i))
            break
        if (last_mse - mse) / mse < 0.001:
            break
        # recall MSE formula is 1/n*sum((yi-y_predicted)^2),
        # where y_predicted = m*x+b
        # using partial deriv of MSE formula, d/dm and d/db
        dm = -(2.0 / n) * (x * residual).sum()
        db = -(2.0 / n) * residual.sum()
        # use current predicted value to get the next value for prediction
        # by using learning rate
        m_iter = m_iter - learning_rate * dm
        b_iter = b_iter - learning_rate * db
        print('m is {}, b is {}, cost is {}, iteration {}'.format(
            m_iter, b_iter, mse, i))
        last_mse = mse
    return m_iter, b_iter
#x = np.array([1,2,3,4,5])
#y = np.array([5,7,8,10,13])
#gradient_descent(x,y)
# Load the training data from the CSV file; the 'Area' column is used as the
# input and the 'Price' column as the value to fit.
df = pd.read_csv('Linear_Data.csv')
x = df['Area']
y = df['Price']
# Run the fit on the two pandas Series.
gradient_descent(x,y)
My code works because I tried it with just numpy array as inputs. not sure why it doesnt work with panda.
Well no, your code also works with pandas dataframes:
df = pd.DataFrame({'Area': [1,2,3,4,5], 'Price': [5,7,8,10,13]})
x = df['Area']
y = df['Price']
gradient_descent(x,y)
Above will give you the same output as with numpy arrays.
Try to check what's in Linear_Data.csv and/or add some print statements in the gradient_descent function just to check your assumptions. I would suggest to first of all add a print statement before the condition with the break statement:
print(last_mse, mse)
if (last_mse - mse)/mse < 0.001:
break

How to use hover events in mpl_connect in matplotlib

I'm working on line plotting a metric for a course module as well as each of its questions within a Jupyter Notebook using %matplotlib notebook. That part is no problem. A module has typically 20-35 questions, so it results in a lot of lines on a chart. Therefore, I am plotting the metric for each question in a low alpha and I want to change the alpha and display the question name when I hover over the line, then reverse those when no longer hovering over the line.
The thing is, I've tried every test version of interactivity from the matplotlib documentation on event handling, as well as those in this question. It seems like the mpl_connect event is never firing, whether I use click or hover.
Here's a test version with a reduced dataset using the solution to the question linked above. Am I missing something necessary to get events to fire?
def update_annot(ind, line=None):
    """Move the shared annotation to the hovered point and relabel it.

    ind  -- details dict returned by ``line.contains(event)``; ``ind["ind"]``
            lists the indices of the data points under the cursor.
    line -- the hovered line. When omitted, a module-level ``line`` is used,
            which keeps the original one-argument call working.

    The annotation text is the hovered point indices followed by the line's
    label (the question id passed as ``label=`` when plotting).
    """
    if line is None:
        line = globals()['line']
    x, y = line.get_data()
    first = ind["ind"][0]
    annot.xy = (x[first], y[first])
    text = "{}, {}".format(" ".join(map(str, ind["ind"])), line.get_label())
    annot.set_text(text)
    annot.get_bbox_patch().set_alpha(0.4)


def hover(event):
    """Show the annotation for the line under the cursor, hide it otherwise.

    The original version tested a single global ``line`` that the script
    never defines, so the callback raised NameError on every event. Every
    line drawn on ``ax`` is checked instead.
    """
    if event.inaxes != ax:
        return
    for candidate in ax.get_lines():
        cont, ind = candidate.contains(event)
        if cont:
            update_annot(ind, candidate)
            annot.set_visible(True)
            fig.canvas.draw_idle()
            return
    # Cursor is inside the axes but not on any line.
    if annot.get_visible():
        annot.set_visible(False)
        fig.canvas.draw_idle()
module = 'bd2bc472-ee0d-466f-8557-788cc6de3018'
module_metrics[module] = {
'q_count': 31,
'sequence_pks': [0.5274546300604932,0.5262044653349001,0.5360993905297703,0.5292329279700655,0.5268691588785047,0.5319099014547161,0.5305164319248826,0.5268235294117647,0.573648805381582,0.5647933116581514,0.5669839795681448,0.5646591970121382,0.5663157894736842,0.5646976090014064,0.5659005628517824,0.5693634879925391,0.5728268468888371,0.5668834184858337,0.5687237026647967,0.5795640965549567,0.5877684407096172,0.585690904839841,0.5766899766899767,0.5971341320178529,0.6059972105997211,0.6055516678329834,0.6209865053513262,0.6203121360354065,0.6153666510976179,0.6236909471724459,0.6387654898293196],
'q_pks': {
'0da04f02-4aad-4ac8-91a5-214862b5c0d0': [0.6686046511627907,0.6282051282051282,0.76,0.6746987951807228,0.7092198581560284,0.71875,0.6585365853658537,0.7070063694267515,0.7171052631578947,0.7346938775510204,0.7737226277372263,0.7380952380952381,0.6774193548387096,0.7142857142857143,0.7,0.6962962962962963,0.723404255319149,0.6737588652482269,0.7232704402515723,0.7142857142857143,0.7164179104477612,0.7317073170731707,0.6333333333333333,0.75,0.7217391304347827,0.7017543859649122,0.7333333333333333,0.7641509433962265,0.6869565217391305,0.75,0.794392523364486],
'10bd29aa-3a26-49e6-bc2c-50fd503d7ab5': [0.64375,0.6014492753623188,0.5968992248062015,0.5059523809523809,0.5637583892617449,0.5389221556886228,0.5576923076923077,0.51875,0.4931506849315068,0.5579710144927537,0.577922077922078,0.5467625899280576,0.5362318840579711,0.6095890410958904,0.5793103448275863,0.5159235668789809,0.6196319018404908,0.6143790849673203,0.5035971223021583,0.5897435897435898,0.5857142857142857,0.5851851851851851,0.6164383561643836,0.6054421768707483,0.5714285714285714,0.627906976744186,0.5826771653543307,0.6504065040650406,0.5864661654135338,0.6333333333333333,0.6851851851851852]
}}
# Font sizes used throughout the chart.
suptitle_size = 24
title_size = 18
tick_size = 12
axis_label_size = 15
legend_size = 14
fig, ax = plt.subplots(figsize=(15, 8))
fig.suptitle('PK by Sequence Order', fontsize=suptitle_size)
module_name = 'Test'
q_count = module_metrics[module]['q_count']
y_ticks = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# Tick positions are 0-based; the labels shown are 1-based.
x_ticks = np.arange(q_count)
x_labels = x_ticks + 1
# Plot it
ax.set_title(module_name, fontsize=title_size)
ax.set_xticks(x_ticks)
ax.set_yticks(y_ticks)
ax.set_xticklabels(x_labels, fontsize=tick_size)
ax.set_yticklabels(y_ticks, fontsize=tick_size)
ax.set_xlabel('Sequence', fontsize=axis_label_size)
ax.set_xlim(-0.5, q_count - 0.5)
ax.set_ylim(0, 1)
ax.grid(which='major', axis='y')
# Output module PK by sequence
ax.plot(module_metrics[module]['sequence_pks'])
# Output PK by sequence for each question, faded and labelled with its id
for qid, pks in module_metrics[module]['q_pks'].items():
    ax.plot(pks, alpha=0.15, label=qid)
# One shared annotation, hidden until the hover callback shows it.
annot = ax.annotate("", xy=(0, 0), xytext=(-20, 20), textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->"))
annot.set_visible(False)
mpl_id = fig.canvas.mpl_connect('motion_notify_event', hover)
Since there are dozens of modules, I created an ipywidgets dropdown to select the module, which then runs a function to output the chart. Nonetheless, whether running it hardcoded as here or from within the function, mpl_connect never seems to fire.
Here's what this one looks like when run

dask how to define a custom (time fold) function that operates in parallel and returns a dataframe with a different shape

I am trying to implement a time fold function to be 'map'ed to various partitions of a dask dataframe which in turn changes the shape of the dataframe in question (or alternatively produces a new dataframe with the altered shape). This is how far I have gotten. The result 'res' returned on compute is a list of 3 delayed objects. When I try to compute each of them in a loop (last tow lines of code) this results in a "TypeError: 'DataFrame' object is not callable" After going through the examples for map_partitions, I also tried altering the input DF (inplace) in the function with no return value which causes a similar TypeError with NoneType. What am I missing?
Also, looking at the visualization (attached) I feel like there is a need for reducing the individually computed (folded) partitions into a single DF. How do I do this?
#! /usr/bin/env python
# Start dask scheduler and workers
# dask-scheduler &
# dask-worker --nthreads 1 --nprocs 6 --memory-limit 3GB localhost:8786 --local-directory /dev/shm &
from dask.distributed import Client
from dask.delayed import delayed
import pandas as pd
import numpy as np
import dask.dataframe as dd
import math
# Width in seconds of one fold bucket; the number of folds is derived from it.
foldbucketsecs=30
# Spacing in seconds between consecutive samples generated by load().
periodicitysecs=15
secsinday=24 * 60 * 60
# Span in seconds covered by a single load() call (one partition): 1 minute.
chunksizesecs=60 # 1 minute
# Number of time-series columns generated per chunk.
numts = 5
# Epoch seconds for 2018-05-01 00:00:00 UTC.
start = 1525132800 # 01/05
end = 1525132800 + (3 * 60) # 3 minutes after start
# Connect to the dask scheduler expected to be running locally on port 8786.
c = Client('127.0.0.1:8786')
def fold(df, start, bucket):
    """Placeholder for the per-partition fold: returns *df* unchanged.

    ``start`` and ``bucket`` are accepted but not used yet.
    """
    return df
def reduce_folds(df):
    """Placeholder for the reduction step: returns *df* unchanged."""
    return df
def load(epoch):
    """Build one chunk of random time-series data starting at *epoch*.

    epoch -- start of the chunk in seconds since the Unix epoch.

    Returns a float64 DataFrame with one row every ``periodicitysecs``
    seconds across ``chunksizesecs`` seconds, ``numts`` columns named
    ``ts_0`` .. ``ts_<numts-1>``, and a datetime index. As a side effect the
    column names are appended to the module-level list ``gts``.
    """
    idx = [epoch + offset for offset in range(0, chunksizesecs, periodicitysecs)]
    # Size the data from the index itself: ``chunksizesecs/periodicitysecs``
    # is a float on Python 3 (rejected by np.random.rand) and would not match
    # the index length when the chunk is not a whole number of periods.
    d = np.random.rand(len(idx), numts)
    ts = ["ts_%s" % (i) for i in range(0, numts)]
    gts.extend(ts)
    res = pd.DataFrame(index=idx, data=d, columns=ts, dtype=np.float64)
    res.index = pd.to_datetime(arg=res.index, unit='s')
    return res
gts = []
# Run load() once locally so gts is filled with the column names.
load(start)
cols = len(gts)
# NOTE(review): the start=/end=/freq= form of pd.DatetimeIndex is not accepted
# by recent pandas releases - confirm against the installed version.
idx1 = pd.DatetimeIndex(start=start, freq=('%sS' % periodicitysecs), end=start+periodicitysecs, dtype='datetime64[s]')
# Empty frame describing the columns/dtypes of the loaded partitions.
meta = pd.DataFrame(index=idx1[:0], data=[], columns=gts, dtype=np.float64)
# One delayed load() per chunk between start and end.
dfs = [delayed(load)(fn) for fn in range(start, end, chunksizesecs)]
from_delayed = dd.from_delayed(dfs, meta, 'sorted')
# True division so ceil() actually rounds up partial buckets.
nfolds = int(math.ceil(float(end - start) / foldbucketsecs))
cprime = nfolds * cols
# Floor division keeps the fold number an integer on Python 3 as well.
gtsnew = ["ts_%s,fold=%s" % (i % cols, i // cols) for i in range(0, cprime)]
idx2 = pd.DatetimeIndex(start=start, freq=('%sS' % periodicitysecs), end=start+foldbucketsecs, dtype='datetime64[s]')
# Empty frame describing the columns/dtypes expected after folding.
meta = pd.DataFrame(index=idx2[:0], data=[], columns=gtsnew, dtype=np.float64)
# map_partitions takes the function itself plus its extra arguments; wrapping
# the call in delayed() handed it a Delayed object instead of a callable,
# which is what produced "'DataFrame' object is not callable".
folded_df = from_delayed.map_partitions(fold, start, foldbucketsecs, meta=meta)
result = c.submit(reduce_folds, folded_df)
c.gather(result).visualize(filename='/usr/share/nginx/html/svg/df4.svg')
# A single compute() materialises the folded frame; the former loop calling
# .compute() on each element of the result is no longer applicable.
res = c.gather(result).compute()
Never mind! It was my fault, instead of wrapping my function in delayed I simply passed it to the map_partitions call like so and it worked.
folded_df = from_delayed.map_partitions(fold, start, foldbucketsecs, nfolds, meta=meta)

pandas histogram plot error: ValueError: num must be 1 <= num <= 0, not 1

I am drawing a histogram of a column from pandas data frame:
%matplotlib notebook
import matplotlib.pyplot as plt
import matplotlib
df.hist(column='column_A', bins = 100)
but got the following errors:
62 raise ValueError(
63 "num must be 1 <= num <= {maxn}, not {num}".format(
---> 64 maxn=rows*cols, num=num))
65 self._subplotspec = GridSpec(rows, cols)[int(num) - 1]
66 # num - 1 for converting from MATLAB to python indexing
ValueError: num must be 1 <= num <= 0, not 1
Does anyone know what this error mean? Thanks!
Problem
The problem you encounter arises when column_A does not contain numeric data. As you can see in the excerpt from pandas.plotting._core below, the numeric data is essential to make the function hist_frame (which you call by DataFrame.hist()) work correctly.
def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False,
sharey=False, figsize=None, layout=None, bins=10, **kwds):
# skipping part of the code
# ...
if column is not None:
if not isinstance(column, (list, np.ndarray, Index)):
column = [column]
data = data[column]
data = data._get_numeric_data() # there is no numeric data in the column
naxes = len(data.columns) # so the number of axes becomes 0
# naxes is passed to the subplot generating function as 0 and later determines the number of columns as 0
fig, axes = _subplots(naxes=naxes, ax=ax, squeeze=False,
sharex=sharex, sharey=sharey, figsize=figsize,
layout=layout)
# skipping the rest of the code
# ...
Solution
If your problem is to represent numeric data (but not of numeric dtype yet) with a histogram, you need to cast your data to numeric, either with pd.to_numeric or df.astype(a_selected_numeric_dtype), e.g. 'float64', and then proceed with your code.
If your problem is to represent non-numeric data in one column with a histogram, you can call the function hist_series with the following line: df['column_A'].hist(bins=100).
If your problem is to represent non-numeric data in many columns with a histogram, you may resort to a handful options:
Use matplotlib and create subplots and histograms directly
Update pandas at least to version 0.25
usually is 0
mta['penn'] = [mta_bystation[mta_bystation.STATION == "34 ST-PENN STA"], 'Penn Station']
mta['grdcntrl'] = [mta_bystation[mta_bystation.STATION == "GRD CNTRL-42 ST"], 'Grand Central']
mta['heraldsq'] = [mta_bystation[mta_bystation.STATION == "34 ST-HERALD SQ"], 'Herald Sq']
mta['23rd'] = [mta_bystation[mta_bystation.STATION == "23 ST"], '23rd St']
#mta['portauth'] = [mta_bystation[mta_bystation.STATION == "42 ST-PORT AUTH"], 'Port Auth']
#mta['unionsq'] = [mta_bystation[mta_bystation.STATION == "14 ST-UNION SQ"], 'Union Sq']
mta['timessq'] = [mta_bystation[mta_bystation.STATION == "TIMES SQ-42 ST"], 'Ti

Exporting a 3D numpy to a VTK file for viewing in Paraview/Mayavi

For those that want to export a simple 3D numpy array (along with axes) to a .vtk (or .vtr) file for post-processing and display in Paraview or Mayavi there's a little module called PyEVTK that does exactly that. The module supports structured and unstructured data etc..
Unfortunately, even though the code works fine in unix-based systems I couldn't make it work (keeps crashing) on any windows installation which simply makes things complicated. Ive contacted the developer but his suggestions did not work
Therefore my question is:
How can one use the from vtk.util import numpy_support function to export a 3D array (the function itself doesn't support 3D arrays) to a .vtk file? Is there a simple way to do it without creating vtkDatasets etc etc?
Thanks a lot!
It's been forever and I had entirely forgotten asking this question but I ended up figuring it out. I've written a post about it in my blog (PyScience) providing a tutorial on how to convert between NumPy and VTK. Do take a look if interested:
pyscience.wordpress.com/2014/09/06/numpy-to-vtk-converting-your-numpy-arrays-to-vtk-arrays-and-files/
It's not a direct answer to your question, but if you have tvtk (if you have mayavi, you should have it), you can use it to write your data to vtk format. (See: http://code.enthought.com/projects/files/ETS3_API/enthought.tvtk.misc.html )
It doesn't use PyEVTK, and it supports a broad range of data sources (more than just structured and unstructured grids), so it will probably work where other things aren't.
As a quick example (Mayavi's mlab interface can make this much less verbose, especially if you're already using it.):
import numpy as np
from enthought.tvtk.api import tvtk, write_data
# Random 10x10x10 volume used as sample data.
data = np.random.random((10, 10, 10))
grid = tvtk.ImageData(spacing=(10, 5, -10), origin=(100, 350, 200),
                      dimensions=data.shape)
# Flatten the array itself in Fortran order (first axis varies fastest);
# the original called np.ravel() without passing the array.
grid.point_data.scalars = data.ravel(order='F')
grid.point_data.scalars.name = 'Test Data'
# Writes legacy ".vtk" format if filename ends with "vtk", otherwise
# this will write data using the newer xml-based format.
write_data(grid, 'test.vtk')
And a portion of the output file:
# vtk DataFile Version 3.0
vtk output
ASCII
DATASET STRUCTURED_POINTS
DIMENSIONS 10 10 10
SPACING 10 5 -10
ORIGIN 100 350 200
POINT_DATA 1000
SCALARS Test%20Data double
LOOKUP_TABLE default
0.598189 0.228948 0.346975 0.948916 0.0109774 0.30281 0.643976 0.17398 0.374673
0.295613 0.664072 0.307974 0.802966 0.836823 0.827732 0.895217 0.104437 0.292796
0.604939 0.96141 0.0837524 0.498616 0.608173 0.446545 0.364019 0.222914 0.514992
...
...
TVTK, which ships with Mayavi, has a beautiful way of writing VTK files. Here is a test example I wrote for myself, following @Joe's answer and the tvtk documentation. Its advantage over evtk is that it supports both the legacy ASCII and the XML formats. I hope it helps other people.
from tvtk.api import tvtk, write_data
import numpy as np
#data = np.random.random((3, 3, 3))
#
#i = tvtk.ImageData(spacing=(1, 1, 1), origin=(0, 0, 0))
#i.point_data.scalars = data.ravel()
#i.point_data.scalars.name = 'scalars'
#i.dimensions = data.shape
#
#w = tvtk.XMLImageDataWriter(input=i, file_name='spoints3d.vti')
#w.write()
# Four corner points of a unit square in the z=0 plane.
points = np.array([[0,0,0], [1,0,0], [1,1,0], [0,1,0]], 'f')
(n1, n2) = points.shape
# One polygon connecting the four points in order.
poly_edge = np.array([[0,1,2,3]])
# Single-string form prints the same text on Python 2 and Python 3.
print('%s %s' % (n1, n2))
## Scalar Data
#temperature = np.array([10., 20., 30., 40.])
#pressure = np.random.rand(n1)
#
## Vector Data
#velocity = np.random.rand(n1,n2)
#force = np.random.rand(n1,n2)
#
##Tensor Data with
comp = 5
stress = np.random.rand(n1,comp)
#
#print stress.shape
## The TVTK dataset.
mesh = tvtk.PolyData(points=points, polys=poly_edge)
#
## Data 0 # scalar data
#mesh.point_data.scalars = temperature
#mesh.point_data.scalars.name = 'Temperature'
#
## Data 1 # additional scalar data
#mesh.point_data.add_array(pressure)
#mesh.point_data.get_array(1).name = 'Pressure'
#mesh.update()
#
## Data 2 # Vector data
#mesh.point_data.vectors = velocity
#mesh.point_data.vectors.name = 'Velocity'
#mesh.update()
#
## Data 3 additional vector data
#mesh.point_data.add_array( force)
#mesh.point_data.get_array(3).name = 'Force'
#mesh.update()
# Attach the per-point random values as the mesh's tensor data.
mesh.point_data.tensors = stress
mesh.point_data.tensors.name = 'Stress'
# Data 4 additional tensor Data
#mesh.point_data.add_array(stress)
#mesh.point_data.get_array(4).name = 'Stress'
#mesh.update()
write_data(mesh, 'polydata.vtk')
# XML format
# Method 1
#write_data(mesh, 'polydata')
# Method 2
#w = tvtk.XMLPolyDataWriter(input=mesh, file_name='polydata.vtk')
#w.write()
I know it is a bit late and I do love your tutorials #somada141. This should work too.
def numpy2VTK(img, spacing=(1.0, 1.0, 1.0)):
    """Wrap a 3-D NumPy array in a configured ``vtkImageImport``.

    Evolved from code by Stou S., http://www.siafoo.net/snippet/314

    img     -- 3-D array; values are cast to uint8. Axis 0 is mapped to the
               third (z) extent, axis 1 to the second and axis 2 to the first.
    spacing -- three spacing values passed to ``SetDataSpacing``.

    Returns the importer with data, scalar type, extents, spacing and a
    (0, 0, 0) origin set.
    """
    importer = vtk.vtkImageImport()
    img_data = img.astype('uint8')
    # tobytes() replaces tostring(), which newer NumPy no longer provides.
    img_bytes = img_data.tobytes()  # raw unsigned 8-bit samples
    dim = img.shape
    importer.CopyImportVoidPointer(img_bytes, len(img_bytes))
    # The constant lives in the vtk module; the bare name was undefined here.
    importer.SetDataScalarType(vtk.VTK_UNSIGNED_CHAR)
    importer.SetNumberOfScalarComponents(1)
    extent = importer.GetDataExtent()
    # Same bounds for the data extent and the whole extent.
    bounds = (extent[0], extent[0] + dim[2] - 1,
              extent[2], extent[2] + dim[1] - 1,
              extent[4], extent[4] + dim[0] - 1)
    importer.SetDataExtent(*bounds)
    importer.SetWholeExtent(*bounds)
    importer.SetDataSpacing(spacing[0], spacing[1], spacing[2])
    importer.SetDataOrigin(0, 0, 0)
    return importer
Hope it helps!
Here's a SimpleITK version with the function load_itk taken from here:
import SimpleITK as sitk
import numpy as np
import sys  # sys.argv / sys.stderr / sys.exit are used below but sys was never imported

# Expect two command-line arguments: the input file and the output file.
if len(sys.argv) < 3:
    print('Wrong number of arguments.', file=sys.stderr)
    print('Usage: ' + __file__ + ' input_sitk_file' + ' output_sitk_file', file=sys.stderr)
    sys.exit(1)


def quick_read(filename):
    """Print the size, spacing and meta-data of *filename*."""
    # Read image information without reading the bulk data.
    file_reader = sitk.ImageFileReader()
    file_reader.SetFileName(filename)
    file_reader.ReadImageInformation()
    print('image size: {0}\nimage spacing: {1}'.format(file_reader.GetSize(), file_reader.GetSpacing()))
    # Some files have a rich meta-data dictionary (e.g. DICOM)
    for key in file_reader.GetMetaDataKeys():
        print(key + ': ' + file_reader.GetMetaData(key))


def load_itk(filename):
    """Read *filename* and return ``(data, origin, spacing)``.

    ``data`` is the image as a NumPy array; ``origin`` and ``spacing`` are
    the image's origin and spacing with their component order reversed.
    """
    # Reads the image using SimpleITK
    itkimage = sitk.ReadImage(filename)
    # Convert the image to a numpy array first and then shuffle the dimensions to get axis in the order z,y,x
    data = sitk.GetArrayFromImage(itkimage)
    # Read the origin of the ct_scan, will be used to convert the coordinates from world to voxel and vice versa.
    origin = np.array(list(reversed(itkimage.GetOrigin())))
    # Read the spacing along each dimension
    spacing = np.array(list(reversed(itkimage.GetSpacing())))
    return data, origin, spacing


def convert(data, output_filename):
    """Write the array *data* as an image to *output_filename*.

    NOTE(review): only the pixel data is written; the origin and spacing
    returned by load_itk() are not applied to the output image.
    """
    image = sitk.GetImageFromArray(data)
    writer = sitk.ImageFileWriter()
    writer.SetFileName(output_filename)
    writer.Execute(image)


def wait():
    """Block until the user presses Enter."""
    print('Press Enter to load & convert or exit using Ctrl+C')
    input()


quick_read(sys.argv[1])
print('-'*20)
wait()
data, origin, spacing = load_itk(sys.argv[1])
# convert() takes the array first; the original call passed only the filename.
convert(data, sys.argv[2])