User defined job type not available during calculation via SLURM - pyiron

I am trying to set up a pyiron calculation (version 0.3.6). I want to execute a non-Python script on a computer cluster via SLURM. I have written my own OwnProgramJob class, which inherits from the GenericJob class. Everything runs smoothly on my local computer. When running on the cluster, however, my class is not available in pyiron:
...
File "/beegfs-home/users/fufl/.local/project/lib/python3.8/site-packages/pyiron_base/generic/hdfio.py", line 1251, in import_class
return getattr(
AttributeError: module '__main__' has no attribute 'OwnProgramJob'
How can I make my own class available for pyiron on the cluster?
I suppose one way would be to add my own class directly into the pyiron source code and to modify the JOB_CLASS_DICT, following the suggestion from https://github.com/pyiron/pyiron/issues/973#issuecomment-694347111. Is there another way without modifying the pyiron source code?
My source code can be found below for reference.
Thank you very much,
Florian
Jupyter notebook:
import os
from pathlib import Path

import pyiron
from pyiron_base import GenericJob, GenericParameters  # GenericParameters is needed by OwnProgramInput below

pr = pyiron.Project(path=f"{str(Path.home())}/pyiron/projects/example")

class OwnProgramJob(GenericJob):
    def __init__(self, project, job_name):
        super().__init__(project, job_name)
        self.input = OwnProgramInput()
        self.executable = "cat input.in > output.out"

    def write_input(self):
        with open(os.path.join(self.working_directory, "input.in"), 'w') as infile:
            infile.write("asd 100")

    def collect_output(self):
        file = os.path.join(self.working_directory, "output.out")
        with open(file) as f:
            line = f.readlines()[0]
        energy = float(line.split()[1])
        with self.project_hdf5.open("output/generic") as h5out:
            h5out["energy_tot"] = energy

class OwnProgramInput(GenericParameters):
    def __init__(self, input_file_name=None):
        super(OwnProgramInput, self).__init__(
            input_file_name=input_file_name,
            table_name="input")

    def load_default(self):
        self.load_string("input_energy 100")

job = pr.create_job(job_type=OwnProgramJob, job_name="test", delete_existing_job=True)
job.server.queue = 'cpu'
job.run()
pr.job_table()
SLURM jobfile:
#SBATCH --workdir={{working_directory}}
#SBATCH --get-user-env=L
#SBATCH --partition=cpu
{%- if run_time_max %}
#SBATCH --time={{run_time_max // 60}}
{%- endif %}
{%- if memory_max %}
#SBATCH --mem={{memory_max}}
{%- endif %}
#SBATCH --cpus-per-task={{cores}}
{{command}}
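For reference, pyiron fills this Jinja template with the job's server settings when it generates the submission script. A minimal sketch for rendering it by hand to preview the generated job file; the template file name, the values, and the command string below are illustrative placeholders, not values pyiron guarantees:

from jinja2 import Template

template = Template(open("slurm.sh").read())   # "slurm.sh" is a placeholder name
print(template.render(
    working_directory="/scratch/example",
    run_time_max=3600,   # seconds; the template's // 60 converts this to minutes for --time
    memory_max=2000,
    cores=4,
    command="<the command pyiron substitutes for the job>",
))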

For the job class to be available when the job is submitted to the queuing system, it has to be importable, i.e. on the Python path. So I suggest moving the class definition into a separate Python module named ownprogramjob.py:
import os
from pyiron_base import GenericJob, GenericParameters

class OwnProgramJob(GenericJob):
    def __init__(self, project, job_name):
        super().__init__(project, job_name)
        self.input = OwnProgramInput()
        self.executable = "cat input.in > output.out"

    def write_input(self):
        with open(os.path.join(self.working_directory, "input.in"), 'w') as infile:
            infile.write("asd 100")

    def collect_output(self):
        file = os.path.join(self.working_directory, "output.out")
        with open(file) as f:
            line = f.readlines()[0]
        energy = float(line.split()[1])
        with self.project_hdf5.open("output/generic") as h5out:
            h5out["energy_tot"] = energy

class OwnProgramInput(GenericParameters):
    def __init__(self, input_file_name=None):
        super(OwnProgramInput, self).__init__(
            input_file_name=input_file_name,
            table_name="input")

    def load_default(self):
        self.load_string("input_energy 100")
Then you can submit it using:
from pyiron import Project
from ownprogramjob import OwnProgramJob
pr = Project("test")
job = pr.create_job(job_type=OwnProgramJob, job_name="test", delete_existing_job=True)
job.server.queue = 'cpu'
job.run()
pr.job_table()
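Why the notebook version fails: pyiron stores the job class together with its module path and re-imports it on the compute node. A class defined in a notebook lives in __main__, and on the cluster __main__ is pyiron's wrapper process rather than your notebook, which is exactly what the AttributeError above says. A quick check (a minimal illustration, not pyiron code):

# after moving the class into ownprogramjob.py
from ownprogramjob import OwnProgramJob

# a class imported from a module can be re-imported on any machine where
# the module is on the Python path ...
print(OwnProgramJob.__module__)   # -> 'ownprogramjob'

# ... whereas a class defined directly in the notebook reports '__main__',
# which cannot be resolved inside the batch job

Make sure ownprogramjob.py is importable on the cluster as well, for example by keeping it in the project directory or adding its location to PYTHONPATH.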
Best,
Jan

Related

How to write python unittest cases to mock redis connection (redis.StrictRedis) in Django

How can I mock the following function for connecting to Redis?
import redis

class RedisCache:
    redis_instance = None

    @classmethod
    def set_connect(cls):
        redis_instance = redis.StrictRedis(host='0.0.0.0', port=6379, password='xyz',
                                           charset='utf-8', decode_responses=True,
                                           socket_timeout=30)
        return redis_instance

    @classmethod
    def get_conn(cls):
        cls.redis_instance = cls.set_connect()
        return cls.redis_instance
I looked for some solutions, but they basically relied on the fakeredis module. I wanted a simpler way to mock these functions.
Note: the data returned by the function looks like Redis<ConnectionPool<Connection<host=127.0.0.1,port=6379,db=0>>>
You can use the patch() function to mock out the redis.StrictRedis class. See where-to-patch in the mock documentation.
E.g.
redis_cache.py:
import redis

class RedisCache:
    redis_instance = None

    @classmethod
    def set_connect(cls):
        redis_instance = redis.StrictRedis(host='0.0.0.0', port=6379, password='xyz',
                                           charset='utf-8', decode_responses=True,
                                           socket_timeout=30)
        return redis_instance

    @classmethod
    def get_conn(cls):
        cls.redis_instance = cls.set_connect()
        return cls.redis_instance
test_redis_cache.py:
import unittest
from unittest import TestCase
from unittest.mock import patch

from redis_cache import RedisCache

class TestRedisCache(TestCase):
    def test_set_connect(self):
        with patch('redis.StrictRedis') as mock_StrictRedis:
            mock_redis_instance = mock_StrictRedis.return_value
            actual = RedisCache.set_connect()
            self.assertEqual(actual, mock_redis_instance)
            mock_StrictRedis.assert_called_once_with(host='0.0.0.0', port=6379, password='xyz',
                                                     charset='utf-8', decode_responses=True,
                                                     socket_timeout=30)

    @patch('redis.StrictRedis')
    def test_get_conn(self, mock_StrictRedis):
        mock_redis_instance = mock_StrictRedis.return_value
        RedisCache.get_conn()
        self.assertEqual(RedisCache.redis_instance, mock_redis_instance)

if __name__ == '__main__':
    unittest.main()
test result:
..
----------------------------------------------------------------------
Ran 2 tests in 0.004s
OK
Name Stmts Miss Cover Missing
------------------------------------------------------------------------------
src/stackoverflow/70016401/redis_cache.py 11 0 100%
src/stackoverflow/70016401/test_redis_cache.py 18 0 100%
------------------------------------------------------------------------------
TOTAL 29 0 100%
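One note on "where to patch": the tests above patch redis.StrictRedis, and that works because redis_cache.py does import redis and looks the attribute up at call time. If the module instead used from redis import StrictRedis, the name would be bound inside redis_cache at import time and you would have to patch it there. A hypothetical sketch, valid only for a module written with the from-import:

# hypothetical: only applies if redis_cache.py used `from redis import StrictRedis`
from unittest.mock import patch

with patch('redis_cache.StrictRedis') as mock_StrictRedis:
    ...  # RedisCache.set_connect() would now get the mock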

How to add a right click menu on a textBrowser placed on a QDialog window using designer? [duplicate]

I am currently following this tutorial on threading in PyQt (code from here). As it was written in PyQt4 (and Python2), I adapted the code to work with PyQt5 and Python3.
Here is the gui file (newdesign.py):
# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'threading_design.ui'
#
# Created by: PyQt5 UI code generator 5.6
#
# WARNING! All changes made in this file will be lost!

from PyQt5 import QtCore, QtGui, QtWidgets

class Ui_MainWindow(object):
    def setupUi(self, MainWindow):
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(526, 373)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.verticalLayout = QtWidgets.QVBoxLayout(self.centralwidget)
        self.verticalLayout.setObjectName("verticalLayout")
        self.subreddits_input_layout = QtWidgets.QHBoxLayout()
        self.subreddits_input_layout.setObjectName("subreddits_input_layout")
        self.label_subreddits = QtWidgets.QLabel(self.centralwidget)
        self.label_subreddits.setObjectName("label_subreddits")
        self.subreddits_input_layout.addWidget(self.label_subreddits)
        self.edit_subreddits = QtWidgets.QLineEdit(self.centralwidget)
        self.edit_subreddits.setObjectName("edit_subreddits")
        self.subreddits_input_layout.addWidget(self.edit_subreddits)
        self.verticalLayout.addLayout(self.subreddits_input_layout)
        self.label_submissions_list = QtWidgets.QLabel(self.centralwidget)
        self.label_submissions_list.setObjectName("label_submissions_list")
        self.verticalLayout.addWidget(self.label_submissions_list)
        self.list_submissions = QtWidgets.QListWidget(self.centralwidget)
        self.list_submissions.setBatchSize(1)
        self.list_submissions.setObjectName("list_submissions")
        self.verticalLayout.addWidget(self.list_submissions)
        self.progress_bar = QtWidgets.QProgressBar(self.centralwidget)
        self.progress_bar.setProperty("value", 0)
        self.progress_bar.setObjectName("progress_bar")
        self.verticalLayout.addWidget(self.progress_bar)
        self.buttons_layout = QtWidgets.QHBoxLayout()
        self.buttons_layout.setObjectName("buttons_layout")
        self.btn_stop = QtWidgets.QPushButton(self.centralwidget)
        self.btn_stop.setEnabled(False)
        self.btn_stop.setObjectName("btn_stop")
        self.buttons_layout.addWidget(self.btn_stop)
        self.btn_start = QtWidgets.QPushButton(self.centralwidget)
        self.btn_start.setObjectName("btn_start")
        self.buttons_layout.addWidget(self.btn_start)
        self.verticalLayout.addLayout(self.buttons_layout)
        MainWindow.setCentralWidget(self.centralwidget)
        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "Threading Tutorial - nikolak.com "))
        self.label_subreddits.setText(_translate("MainWindow", "Subreddits:"))
        self.edit_subreddits.setPlaceholderText(_translate("MainWindow", "python,programming,linux,etc (comma separated)"))
        self.label_submissions_list.setText(_translate("MainWindow", "Submissions:"))
        self.btn_stop.setText(_translate("MainWindow", "Stop"))
        self.btn_start.setText(_translate("MainWindow", "Start"))
and the main script (main.py):
from PyQt5 import QtWidgets
from PyQt5.QtCore import QThread, pyqtSignal, QObject
import sys
import newdesign
import urllib.request
import json
import time

class getPostsThread(QThread):
    def __init__(self, subreddits):
        """
        Make a new thread instance with the specified
        subreddits as the first argument. The subreddits argument
        will be stored in an instance variable called subreddits
        which then can be accessed by all other class instance functions

        :param subreddits: A list of subreddit names
        :type subreddits: list
        """
        QThread.__init__(self)
        self.subreddits = subreddits

    def __del__(self):
        self.wait()

    def _get_top_post(self, subreddit):
        """
        Return a pre-formatted string with top post title, author,
        and subreddit name from the subreddit passed as the only required
        argument.

        :param subreddit: A valid subreddit name
        :type subreddit: str
        :return: A string with top post title, author,
                 and subreddit name from that subreddit.
        :rtype: str
        """
        url = "https://www.reddit.com/r/{}.json?limit=1".format(subreddit)
        headers = {'User-Agent': 'nikolak@outlook.com tutorial code'}
        request = urllib.request.Request(url, headers=headers)
        response = urllib.request.urlopen(request)
        data = json.load(response)
        top_post = data['data']['children'][0]['data']
        return "'{title}' by {author} in {subreddit}".format(**top_post)

    def run(self):
        """
        Go over every item in the self.subreddits list
        (which was supplied during __init__)
        and for every item assume it's a string with a valid subreddit
        name and fetch the top post using the _get_top_post method
        from reddit. Store the result in a local variable named
        top_post and then emit a pyqtSignal add_post(QString) where
        QString is equal to the top_post variable that was set by the
        _get_top_post function.
        """
        for subreddit in self.subreddits:
            top_post = self._get_top_post(subreddit)
            self.emit(pyqtSignal('add_post(QString)'), top_post)
            self.sleep(2)

class ThreadingTutorial(QtWidgets.QMainWindow, newdesign.Ui_MainWindow):
    """
    How the basic structure of PyQt GUI code looks and behaves like is
    explained in this tutorial
    http://nikolak.com/pyqt-qt-designer-getting-started/
    """
    def __init__(self):
        super(self.__class__, self).__init__()
        self.setupUi(self)
        self.btn_start.clicked.connect(self.start_getting_top_posts)

    def start_getting_top_posts(self):
        # Get the subreddits user entered into an QLineEdit field
        # this will be equal to '' if there is no text entered
        subreddit_list = str(self.edit_subreddits.text()).split(',')
        # since ''.split(',') == [''] we use that to check whether there is
        # anything there to fetch from, and if not show a message and abort
        if subreddit_list == ['']:
            QtWidgets.QMessageBox.critical(self, "No subreddits",
                                           "You didn't enter any subreddits.",
                                           QtWidgets.QMessageBox.Ok)
            return
        # Set the maximum value of progress bar, can be any int and it will
        # be automatically converted to x/100% values
        # e.g. max_value = 3, current_value = 1, the progress bar will show 33%
        self.progress_bar.setMaximum(len(subreddit_list))
        # Setting the value on every run to 0
        self.progress_bar.setValue(0)
        # We have a list of subreddits which we use to create a new getPostsThread
        # instance and we pass that list to the thread
        self.get_thread = getPostsThread(subreddit_list)
        # Next we need to connect the events from that thread to functions we want
        # to be run when those pyqtSignals get fired
        # Adding post will be handled in the add_post method and the pyqtSignal that
        # the thread will emit is pyqtSignal("add_post(QString)")
        # the rest is same as we can use to connect any pyqtSignal
        self.connect(self.get_thread, pyqtSignal("add_post(QString)"), self.add_post)
        # This is pretty self explanatory
        # regardless of whether the thread finishes or the user terminates it
        # we want to show the notification to the user that adding is done
        # and regardless of whether it was terminated or finished by itself
        # the finished pyqtSignal will go off. So we don't need to catch the
        # terminated one specifically, but we could if we wanted.
        self.connect(self.get_thread, pyqtSignal("finished()"), self.done)
        # We have all the events we need connected we can start the thread
        self.get_thread.start()
        # At this point we want to allow user to stop/terminate the thread
        # so we enable that button
        self.btn_stop.setEnabled(True)
        # And we connect the click of that button to the built in
        # terminate method that all QThread instances have
        self.btn_stop.clicked.connect(self.get_thread.terminate)
        # We don't want to enable user to start another thread while this one is
        # running so we disable the start button.
        self.btn_start.setEnabled(False)

    def add_post(self, post_text):
        """
        Add the text that's given to this function to the
        list_submissions QListWidget we have in our GUI and
        increase the current value of progress bar by 1

        :param post_text: text of the item to add to the list
        :type post_text: str
        """
        self.list_submissions.addItem(post_text)
        self.progress_bar.setValue(self.progress_bar.value() + 1)

    def done(self):
        """
        Show the message that fetching posts is done.
        Disable Stop button, enable the Start one and reset progress bar to 0
        """
        self.btn_stop.setEnabled(False)
        self.btn_start.setEnabled(True)
        self.progress_bar.setValue(0)
        QtWidgets.QMessageBox.information(self, "Done!", "Done fetching posts!")

def main():
    app = QtWidgets.QApplication(sys.argv)
    form = ThreadingTutorial()
    form.show()
    app.exec_()

if __name__ == '__main__':
    main()
Now I'm getting the following error:
AttributeError: 'ThreadingTutorial' object has no attribute 'connect'
Can anyone please tell me how to fix this? Any help would be, as always, very much appreciated.
Using QObject.connect() and similar in PyQt4 is known as "old style" signals, which PyQt5 no longer supports; it supports only "new style" signals, which were already the recommended way to connect signals in PyQt4.
In PyQt5 you need to use the connect() and emit() methods of the bound signal directly, e.g. instead of:
self.emit(pyqtSignal('add_post(QString)'), top_post)
...
self.connect(self.get_thread, pyqtSignal("add_post(QString)"), self.add_post)
self.connect(self.get_thread, pyqtSignal("finished()"), self.done)
use:
self.add_post.emit(top_post)
...
self.get_thread.add_post.connect(self.add_post)
self.get_thread.finished.connect(self.done)
However, for this to work you need to explicitly define the add_post signal on your getPostsThread first; otherwise you'll get an AttributeError.
class getPostsThread(QThread):
    add_post = pyqtSignal(str)
    ...
With old-style signals in PyQt4, a signal was defined implicitly the first time it was used; with new-style signals this needs to be done explicitly.
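For reference, here is a minimal, self-contained sketch of the new-style pattern, independent of the tutorial code above:

import sys
from PyQt5.QtCore import QCoreApplication, QThread, pyqtSignal

class Worker(QThread):
    # new-style signals are declared as class attributes
    add_post = pyqtSignal(str)

    def run(self):
        self.add_post.emit("top post fetched")

app = QCoreApplication(sys.argv)
worker = Worker()
worker.add_post.connect(print)       # connect the bound signal to any callable
worker.finished.connect(app.quit)    # QThread's built-in finished signal
worker.start()
sys.exit(app.exec_())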

How does scrapy crawl work: which class is instantiated and which method is called?

Here is a simple Python file, test.py:
import math

class myClass():
    def myFun(self, x):
        return math.sqrt(x)

if __name__ == "__main__":
    myInstance = myClass()
    print(myInstance.myFun(9))
Running python test.py prints 3.0. Let's analyse the running process:
1. myClass is instantiated and assigned to myInstance;
2. myFun is called and its result is printed.
Now it is Scrapy's turn. In the Scrapy 1.4 manual, quotes_spider.py looks like this:
import scrapy

class QuotesSpider(scrapy.Spider):
    name = "quotes"

    def start_requests(self):
        urls = [
            'http://quotes.toscrape.com/page/1/',
            'http://quotes.toscrape.com/page/2/',
        ]
        for url in urls:
            yield scrapy.Request(url=url, callback=self.parse)

    def parse(self, response):
        page = response.url.split("/")[-2]
        filename = 'quotes-%s.html' % page
        with open(filename, 'wb') as f:
            f.write(response.body)
        self.log('Saved file %s' % filename)
Running the spider with scrapy crawl quotes, I am puzzled:
1. Where is the main function or main body of the spider?
2. Which class is instantiated?
3. Which method is called? Something like:
mySpider = QuotesSpider(scrapy.Spider)
mySpider.parse(response)
How does scrapy crawl work, exactly?
So let's start. Assuming you use Linux or macOS, let's first check where scrapy lives:
$ which scrapy
/Users/tarun.lalwani/.virtualenvs/myproject/bin/scrapy
Let's look at the content of this file:
$ cat /Users/tarun.lalwani/.virtualenvs/myproject/bin/scrapy
#!/Users/tarun.lalwani/.virtualenvs/myproject/bin/python3.6
# -*- coding: utf-8 -*-
import re
import sys
from scrapy.cmdline import execute

if __name__ == '__main__':
    sys.argv[0] = re.sub(r'(-script\.pyw?|\.exe)?$', '', sys.argv[0])
    sys.exit(execute())
So this runs the execute method from cmdline.py, and here is your main method.
cmdline.py
from __future__ import print_function
....
....

def execute(argv=None, settings=None):
    if argv is None:
        argv = sys.argv

    # --- backwards compatibility for scrapy.conf.settings singleton ---
    if settings is None and 'scrapy.conf' in sys.modules:
        from scrapy import conf
        if hasattr(conf, 'settings'):
            settings = conf.settings
    # ------------------------------------------------------------------

    if settings is None:
        settings = get_project_settings()
        # set EDITOR from environment if available
        try:
            editor = os.environ['EDITOR']
        except KeyError:
            pass
        else:
            settings['EDITOR'] = editor
    check_deprecated_settings(settings)

    # --- backwards compatibility for scrapy.conf.settings singleton ---
    import warnings
    from scrapy.exceptions import ScrapyDeprecationWarning
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ScrapyDeprecationWarning)
        from scrapy import conf
        conf.settings = settings
    # ------------------------------------------------------------------

    inproject = inside_project()
    cmds = _get_commands_dict(settings, inproject)
    cmdname = _pop_command_name(argv)
    parser = optparse.OptionParser(formatter=optparse.TitledHelpFormatter(),
                                   conflict_handler='resolve')
    if not cmdname:
        _print_commands(settings, inproject)
        sys.exit(0)
    elif cmdname not in cmds:
        _print_unknown_command(settings, cmdname, inproject)
        sys.exit(2)

    cmd = cmds[cmdname]
    parser.usage = "scrapy %s %s" % (cmdname, cmd.syntax())
    parser.description = cmd.long_desc()
    settings.setdict(cmd.default_settings, priority='command')
    cmd.settings = settings
    cmd.add_options(parser)
    opts, args = parser.parse_args(args=argv[1:])
    _run_print_help(parser, cmd.process_options, args, opts)

    cmd.crawler_process = CrawlerProcess(settings)
    _run_print_help(parser, _run_command, cmd, args, opts)
    sys.exit(cmd.exitcode)

if __name__ == '__main__':
    execute()
Now notice that the execute method processes the arguments you passed, crawl quotes in your case. It scans the project for spider classes, checks which one has name defined as "quotes", creates a CrawlerProcess, and that runs the whole show.
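Roughly speaking, what scrapy crawl quotes ends up doing can be reproduced by hand with Scrapy's public API (a sketch, assuming it runs inside the project directory):

from scrapy.crawler import CrawlerProcess
from scrapy.utils.project import get_project_settings

process = CrawlerProcess(get_project_settings())
process.crawl('quotes')   # the spider is looked up by its `name` attribute
process.start()           # starts the Twisted reactor and blocks until done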
Scrapy is built on the Twisted framework, which is a scheduler-based, event-driven framework.
Consider this part of the spider code:

for url in urls:
    yield scrapy.Request(url=url, callback=self.parse)

When the engine executes this function and the first yield is reached, the value is returned to the engine. The engine now looks at the other tasks that are pending and executes them; when they yield in turn, other pending tasks in the queue get a chance to run. So yield is what allows a function's execution to be broken into parts, and that is what makes Scrapy/Twisted work.
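To make that concrete, here is a toy illustration (not Scrapy code) of how a scheduler can interleave generator-based tasks; each yield hands control back to the "engine":

def task(name, steps):
    for i in range(steps):
        print(f"{name}: step {i}")
        yield                      # pause here; the engine resumes us later

engine = [task("spider-1", 2), task("spider-2", 2)]
while engine:
    current = engine.pop(0)        # pick the next pending task
    try:
        next(current)              # run it until its next yield
        engine.append(current)     # re-queue the paused task
    except StopIteration:
        pass                       # the task finished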
You can get a detailed overview of the architecture at the link below:
https://doc.scrapy.org/en/latest/topics/architecture.html

MemoryError when querying database from Process

I am trying to create a program with 3 processes that read from the same database. The code was working before I started introducing processes.
I am getting a MemoryError when performing a select() with peewee, and I suspect something is wrong with how resources are shared across the processes. Minimal example:
models.py
from peewee import Model
from playhouse.pool import PooledSqliteExtDatabase

file_scanner_database = PooledSqliteExtDatabase(
    None,
    max_connections=32,
)

class FileModel(Model):
    class Meta:
        database = file_scanner_database
main.py
from multiprocessing import Process

from file_scanner import FileScanner
from models import file_scanner_database
from models import FileModel

def create_scanner_agent(data):
    scanner = FileScanner(data)
    scanner.start_scanner()

shared_info = {'db_location': '/absolute/path/to/database'}

file_scanner_database.init(shared_info['db_location'])
file_scanner_database.connect()
file_scanner_database.create_tables([FileModel], safe=True)

new_process = Process(
    target=create_scanner_agent,
    args=(shared_info,)
)
new_process.daemon = True
new_process.start()

try:
    new_process.join()
except KeyboardInterrupt:
    pass

new_process.terminate()
file_scanner.py
from models import file_scanner_database
from models import FileModel

class FileScanner:
    def __init__(self, data):
        for k, v in data.items():
            setattr(self, k, v)
        file_scanner_database.init(self.db_location)
        file_scanner_database.connect()

    def start_scanner(self):
        while True:
            # THIS IS WHERE THE PROGRAM CRASHES
            for row in FileModel.select():
                ...
It looks like you're trying to access memory across a fork? Or some such craziness? I think the answer is that you're doing it wrong homie. Try opening your DB connection after the fork.
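A hedged sketch of what "open the connection after the fork" could look like for the example above: the parent initializes the database and creates the tables, but closes its own connection before forking so the child never inherits an open sqlite handle, while each child opens its own connection (which FileScanner.__init__ already does). This is a sketch under those assumptions, not tested code:

# main.py, adjusted
from multiprocessing import Process

from file_scanner import FileScanner
from models import file_scanner_database, FileModel

def create_scanner_agent(data):
    # runs in the child: FileScanner.__init__ calls init()/connect() here,
    # i.e. after the fork
    scanner = FileScanner(data)
    scanner.start_scanner()

shared_info = {'db_location': '/absolute/path/to/database'}

file_scanner_database.init(shared_info['db_location'])
file_scanner_database.connect()
file_scanner_database.create_tables([FileModel], safe=True)
# with a pooled database, close() only returns the connection to the pool;
# manual_close() actually closes it, so nothing is carried across the fork
file_scanner_database.manual_close()

new_process = Process(target=create_scanner_agent, args=(shared_info,), daemon=True)
new_process.start()
new_process.join()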

Is there a way to get tensorflow tf.Print output to appear in Jupyter Notebook output

I'm using the tf.Print op in a Jupyter notebook. It works as required, but will only print the output to the console, without printing in the notebook. Is there any way to get around this?
An example would be the following (in a notebook):
import tensorflow as tf
a = tf.constant(1.0)
a = tf.Print(a, [a], 'hi')
sess = tf.Session()
a.eval(session=sess)
That code will print 'hi[1]' in the console, but nothing in the notebook.
Update Feb 3, 2017: I've wrapped this into the memory_util package. Example usage:
# install memory util
import urllib.request
response = urllib.request.urlopen("https://raw.githubusercontent.com/yaroslavvb/memory_util/master/memory_util.py")
open("memory_util.py", "wb").write(response.read())

import memory_util
import tensorflow as tf

sess = tf.Session()
a = tf.random_uniform((1000,))
b = tf.random_uniform((1000,))
c = a + b

with memory_util.capture_stderr() as stderr:
    sess.run(c.op)

print(stderr.getvalue())
Old stuff:
You could reuse the FD redirector from IPython core (idea from Mark Sandler):
import os
import sys

STDOUT = 1
STDERR = 2

class FDRedirector(object):
    """ Class to redirect output (stdout or stderr) at the OS level using
        file descriptors.
    """

    def __init__(self, fd=STDOUT):
        """ fd is the file descriptor of the output you want to capture.
            It can be STDOUT or STDERR.
        """
        self.fd = fd
        self.started = False
        self.piper = None
        self.pipew = None

    def start(self):
        """ Setup the redirection.
        """
        if not self.started:
            self.oldhandle = os.dup(self.fd)
            self.piper, self.pipew = os.pipe()
            os.dup2(self.pipew, self.fd)
            os.close(self.pipew)
            self.started = True

    def flush(self):
        """ Flush the captured output, similar to the flush method of any
            stream.
        """
        if self.fd == STDOUT:
            sys.stdout.flush()
        elif self.fd == STDERR:
            sys.stderr.flush()

    def stop(self):
        """ Unset the redirection and return the captured output.
        """
        if self.started:
            self.flush()
            os.dup2(self.oldhandle, self.fd)
            os.close(self.oldhandle)
            f = os.fdopen(self.piper, 'r')
            output = f.read()
            f.close()
            self.started = False
            return output
        else:
            return ''

    def getvalue(self):
        """ Return the output captured since the last getvalue, or the
            start of the redirection.
        """
        output = self.stop()
        self.start()
        return output

import tensorflow as tf

x = tf.constant([1, 2, 3])
a = tf.Print(x, [x])

redirect = FDRedirector(STDERR)
sess = tf.InteractiveSession()
redirect.start()
a.eval()
print("Result")
print(redirect.stop())
I ran into the same problem and got around it by using a function like this in my notebooks:
import numpy as np
import tensorflow as tf

def tf_print(tensor, transform=None):
    # Insert a custom python operation into the graph that does nothing but print a tensor's value
    def print_tensor(x):
        # x is typically a numpy array here so you could do anything you want with it,
        # but adding a transformation of some kind usually makes the output more digestible
        print(x if transform is None else transform(x))
        return x

    log_op = tf.py_func(print_tensor, [tensor], [tensor.dtype])[0]
    with tf.control_dependencies([log_op]):
        res = tf.identity(tensor)

    # Return the given tensor
    return res

# Now define a tensor and use the tf_print function much like the tf.identity function
tensor = tf_print(tf.random_normal([100, 100]), transform=lambda x: [np.min(x), np.max(x)])

# This will print the transformed version of the tensor's actual value
# (which was summarized to just the min and max for brevity)
sess = tf.InteractiveSession()
sess.run([tensor])
sess.close()
FYI, using a logger instead of calling "print" in my custom function worked wonders for me, as stdout is often buffered by Jupyter and not shown before "Loss is NaN"-style errors appear, and catching those was the whole point of using that function in the first place in my case.
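A hedged sketch of that logger variant (the logger name and message format are arbitrary choices, not from the original answer):

import logging

# send debug output through a logger instead of print(); stream handlers
# flush per record, so messages show up before a crash
logger = logging.getLogger("tf_debug")
logger.setLevel(logging.INFO)
if not logger.handlers:          # avoid adding duplicate handlers on notebook re-runs
    logger.addHandler(logging.StreamHandler())

def print_tensor(x):
    logger.info("tensor value: %s", x)
    return x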
You can check the terminal where you launched the jupyter notebook to see the message.
import tensorflow as tf
tf.InteractiveSession()
a = tf.constant(1)
b = tf.constant(2)
opt = a + b
opt = tf.Print(opt, [opt], message="1 + 2 = ")
opt.eval()
In the terminal, I can see:
2018-01-02 23:38:07.691808: I tensorflow/core/kernels/logging_ops.cc:79] 1 + 2 = [3]
A simple way; I tried it in regular Python, but not in Jupyter yet:

import os
import sys

# redirect fds 1 and 2 to wherever sys.stdout points; fd 2 (stderr) is where
# tf.Print writes its output
os.dup2(sys.stdout.fileno(), 1)
os.dup2(sys.stdout.fileno(), 2)
Explanation is here: In python, how to capture the stdout from a c++ shared library to a variable
The issue I faced was that one can't run a session inside a TensorFlow graph, as during training or evaluation. That's why using sess.run(opt) or opt.eval() was not a solution for me.
The best thing was to use tf.Print() and redirect the logging to an external file. I did this using a temporary file, which I transferred to a regular file like this:
STDERR = 2

import os
import sys
import tempfile

class captured:
    def __init__(self, fd=STDERR):
        self.fd = fd
        self.prevfd = None

    def __enter__(self):
        t = tempfile.NamedTemporaryFile()
        self.prevfd = os.dup(self.fd)
        os.dup2(t.fileno(), self.fd)
        return t

    def __exit__(self, exc_type, exc_value, traceback):
        os.dup2(self.prevfd, self.fd)

with captured(fd=STDERR) as tmp:
    ...
    classifier.evaluate(input_fn=input_fn, steps=100)

with open('log.txt', 'w') as f:
    print(open(tmp.name).read(), file=f)
And then in my evaluation I do:
a = tf.constant(1)
a = tf.Print(a, [a], message="a: ")