Simulating a system of batching jobs with interrupting set-up/switch-on times using Simpy - interrupt

I am new to SimPy and have a problem combining batching jobs with interruptible set-up times, so could you please help me?
I would like to create a system with servers that need time to set up before being ready to serve.
A server starts setting up whenever there are enough customers, M (2, 3, ...), in the queue. If the number of customers in the system has reached the maximum K (50), an arriving customer balks.
When a batch (a group of M customers) leaves the system, we check whether there are M more customers (another batch) waiting to be served. If so, we keep the server ON; otherwise, we turn the server off immediately.
I found code for a very similar problem in the SimPy Google group (a Covid-test simulation that uses Store resources):
https://groups.google.com/g/python-simpy/c/iFYaDlL4fq0
and an answer by Michael R. Gibbs on interrupting a set-up time with Container resources:
Interrupt an earlier timeout event in Simpy
I tried to combine the two pieces of code, but it didn't work.
For example, with M = 2 and K = 50:
Customer 1 arrives and waits.
Customer 2 arrives; now there are 2 customers, so a server is requested.
Server 1 starts SETUP, which takes t1 secs.
Customer 3 arrives and waits.
Customer 4 arrives; now there are 2 more customers, so another server is requested.
Server 2 starts SETUP, which takes t1 secs.
Server 1 is ON.
Customers 1 and 2 occupy server 1.
Customers 1 and 2 complete service and leave the system.
Customers 3 and 4 occupy server 1 (because when server 1 finishes,
server 2 is still in the setup process).
Server 2 (still in SETUP mode) is turned off...
... Customer 100 arrives, sees that the system has 50 customers, and balks.

I broke customer arrivals into two parts: a first queue where a customer waits until there are enough customers to make a batch. When I have enough customers to make a batch, I do so, popping the batched customers from the batching queue and putting the batch into a processing queue. I count the customers in both queues to decide whether an arriving customer aborts entry.
When a batch is put into the processing queue, I also start up a server. This means that the number of batches in the processing queue will equal the number of servers starting up, and that when a server finishes starting up, there will be a batch for it to process. Since a server never has to wait for a batch, I use a simple list for my queue.
When a server finishes starting up, it grabs a batch and removes itself from the list of starting servers. After the server finishes processing a batch, it checks whether there is another batch in the processing queue. If so, it grabs that batch and keeps processing, but also kills the server that was starting up to handle it. If there are no batches in the processing queue, it shuts down.
Here is the code. In the log you should see the queues max out and customers abort, but also see servers start to shut down towards the end.
"""
Simulation of servers processing batches
Customers enter a queue where they wait for
enough customers to make a batch
If the there are too many customers in the queues
the arriving customer will abort
When a batch is made, it is put into a second
processing queue where the batch waits to be processed.
When a batch is put into the processing queue, it
starts a server. The server has a start up delay
then loops by seizing a batch, process the batch, release
the batch, checking if another batch is in the
processing queue. If there is another batch, stop a server
that is starting up and process the batch, else end loop
and shutdown server
Programmer: Michael R. Gibbs
"""
import simpy
import random
max_q_size = 50
batch_size = 2
server_start_time = 55
processing_time = lambda : random.triangular(5,20,10)
arrival_gap = lambda : random.triangular(1,1,1)
# there is no wating so normal lists are good enough
batching_q = list()
processing_q = list()
server_q = list() # servers that are still starting up
class Server():
"""
Server that process batches
Has two states: starting up, and batch processing
"""
def __init__(self, id, env, processing_q, server_q):
self.id = id
self.env = env
self.processing_q = processing_q
self.server_q = server_q
self.start_process = self.env.process(self.start_up())
def start_up(self):
"""
starts up the server, then start processing batches
start up can be interrupted, stoping the server
"""
# start up
try:
print(f'{self.env.now} server {self.id} starting up')
yield self.env.timeout(server_start_time)
print(f'{self.env.now} server {self.id} started')
self.env.process(self.process())
except simpy.Interrupt:
print(f'{env.now} server {self.id} has been interupted')
def process(self):
"""
process batches
keeps going as long as there are batches in queue
If starts second batch, also interupts starting up server
"""
while True:
print(f'{self.env.now} server {self.id} starting batch process')
b = processing_q.pop(0)
yield self.env.timeout(processing_time())
print(f'{self.env.now} server {self.id} finish batch process')
if len(self.processing_q) > 0:
# more processes to do,
# steal batch from starting up server
s = self.server_q.pop() # lifo
s.stop()
else:
print(f'{env.now} server {self.id} no more batches, shutting down')
break
def stop(self):
"""
Interrupts server start up, stoping server
"""
try:
self.start_process.interrupt()
except:
pass
def gen_arrivals(env, batching_q, processing_q, server_q):
"""
Generate arring customers
If queues are too big customer will abort
If have enough customers, create a batch and start a server
"""
id = 1
while True:
yield env.timeout(arrival_gap())
q_size = len(batching_q) + (batch_size * len(processing_q))
if q_size >= max_q_size:
print(f'{env.now} customer arrived and aborted, q len: {q_size}')
else:
print(f'{env.now} customer has arrived, q len: {q_size}')
customer = object()
batching_q.append(customer)
# check if a batch can be creatd
while len(batching_q) >= batch_size:
batch = list()
while len(batch) < batch_size:
batch.append(batching_q.pop(0))
# put batch in processing q
processing_q.append(batch)
# start server
server = Server(id, env, processing_q, server_q)
id += 1
server_q.append(server)
# boot up sim
env = simpy.Environment()
env.process(gen_arrivals(env, batching_q, processing_q, server_q))
env.run(100)

When I add a condition to limit the number of servers, it works until a server is interrupted or shut down. After that, those servers seem to have disappeared and are no longer active.
Sorry for asking so much. Here is my code:
import simpy
import random
import numpy as np


class param:
    def __init__(self, x):
        #self.FILE = 'Setup_time.csv'
        self.MEAN_INTERARRIVAL = x  # arrival_gap
        self.MEAN_SERVICE_TIME = 2  # processing_time
        self.MEAN_SWITCH_TIME = 3  # server_start_time
        self.NUM_OF_SERVER = 4  # maximum number of servers
        self.MAX_SYS_SIZE = 10  # maximum number of customers in the system
        self.BATCH_SIZE = 2
        self.RANDOM_SEED = 0


# there is no waiting, so plain lists are good enough
class Server():
    """
    Server that processes batches

    Has two states: starting up, and batch processing
    """

    def __init__(self, id, env, processing_q, server_q, param):
        self.id = id
        self.env = env
        self.processing_q = processing_q
        self.server_q = server_q

        self.start_process = self.env.process(self.start_up(param))

    def start_up(self, param):
        """
        Starts up the server, then starts processing batches.
        Start-up can be interrupted, stopping the server.
        """
        global num_servers
        # start up
        if self.id <= param.NUM_OF_SERVER:  # I add the condition to limit the number of servers
            try:
                num_servers += 1
                print(f'{self.env.now} server {self.id} starting up')
                yield self.env.timeout(param.MEAN_SWITCH_TIME)
                #yield env.timeout(np.random.exponential(1/param.MEAN_SWITCH_TIME))
                print(f'{self.env.now} server {self.id} started')

                self.env.process(self.process(param))
            except simpy.Interrupt:
                print(f'{env.now} server {self.id} has been interrupted-------------------')

    def process(self, param):
        """
        Processes batches.
        Keeps going as long as there are batches in the queue.
        If it starts a second batch, it also interrupts a starting-up server.
        """
        global num_servers, num_active_server
        while True:
            num_active_server += 1
            b = processing_q.pop(0)
            print(f'{self.env.now} server {self.id} starting batch process')
            yield self.env.timeout(param.MEAN_SERVICE_TIME)
            #yield env.timeout(np.random.exponential(1/param.MEAN_SERVICE_TIME))
            num_servers -= 1
            num_active_server -= 1
            print(f'{self.env.now} server {self.id} finish batch process')

            if len(self.processing_q) > 0:
                # more batches to do,
                # steal a batch from a starting-up server
                #if self.server_q:
                    #s = self.server_q.pop(0) # Do these lines work for the FIFO rule?
                    #s.stop()
                s = self.server_q.pop()  # lifo
                s.stop()
            else:
                print(f'{env.now} server {self.id} no more batches, shutting down')
                break

    def stop(self):
        """
        Interrupts server start-up, stopping the server.
        """
        try:
            self.start_process.interrupt()
        except:
            pass


def gen_arrivals(env, batching_q, processing_q, server_q, param):
    """
    Generates arriving customers.
    If the queues are too big, the customer will abort.
    If there are enough customers, creates a batch and starts a server.
    """
    global num_servers, num_balk, num_cumulative_customer, num_active_server
    id = 1
    while True:
        yield env.timeout(param.MEAN_INTERARRIVAL)
        #yield env.timeout(np.random.exponential(1/param.MEAN_INTERARRIVAL))
        num_cumulative_customer += 1
        customer = object()
        batching_q.append(customer)

        q_size = len(batching_q) + (param.BATCH_SIZE * len(processing_q))
        sys_size = q_size + (num_active_server * param.BATCH_SIZE)

        #if q_size > max_q_size:
        if sys_size > param.MAX_SYS_SIZE:  # I check the limit on the number of customers in the system instead of the number of customers in the queue
            num_balk += 1
            batching_q.pop(-1)  # I added this statement
            print(f'{env.now} customer arrived and aborted, sys len: {sys_size}')
        else:
            #customer = object() # I moved these 2 lines above to update the system size before using the if statement
            #batching_q.append(customer)
            print(f'{env.now} customer has arrived, q len: {q_size}, sys len: {sys_size}')

            # check if a batch can be created
            while len(batching_q) >= param.BATCH_SIZE:
                batch = list()
                while len(batch) < param.BATCH_SIZE:
                    batch.append(batching_q.pop(0))

                # put batch in processing q
                processing_q.append(batch)

                # start server
                server = Server(id, env, processing_q, server_q, param)
                id += 1
                server_q.append(server)

        # Calculate balking probability
        prob_balk = num_balk / num_cumulative_customer
        #print(f'{env.now} prob_balk {prob_balk}')
        list_prob_balk.append(prob_balk)


# boot up sim
trial = 0
Pb = []  # balking probability
global customer_balk_number
for x in range(1, 3):
    trial += 1
    print('trial:', trial)
    batching_q = list()
    processing_q = list()
    server_q = list()  # servers that are still starting up
    num_servers = 0  # number of servers in the system (both starting and serving servers)
    num_active_server = 0  # number of servers serving customers
    num_balk = 0  # number of balking customers
    num_cumulative_customer = 0  # total arriving customers
    list_prob_balk = []  # list of balking probabilities for each trial
    paramtest1 = param(x)
    random.seed(paramtest1.RANDOM_SEED)

    # create and start the model
    env = simpy.Environment()
    env.process(gen_arrivals(env, batching_q, processing_q, server_q, paramtest1))
    env.run(30)
    Pb.append(list_prob_balk[-1])
#print('List of balk prob', Pb)

Related

Getting Tensorflow To Run Faster

I have developed a machine learning Python script (let's call it classify_obj, written with Python 3.6) that imports TensorFlow. It was developed initially for bulk analysis, but now I need to run this script repeatedly on smaller datasets to cater for more real-time usage. I am doing this on Linux RHEL 7.
Process flow:
1. The master tool (written in Java) calls classify_obj with the object input to categorize.
2. classify_obj generates the classification result as a CSV (takes about 7-10 s).
3. The master tool reads the result from #2.
4. The master tool proceeds to do other logic.
5. Repeat #1 with the next object input.
To break down the time taken, I switched off the main logic and did only the module imports without performing any other action. I found that the imports take about 4-5 s of the 7-10 s run time on the small dataset, and the classification takes about 2 s. I am also looking at other ways to reduce the run time in other areas, but the bulk seems to come from the imports.
Import time: 4-6 s
Classify time: 1 s
Read, write and other logic time: 0.2 s
I am wondering what options there are to reduce the import time.
One idea I had was to turn classify_obj into a "stay alive" process. The master tool, after completing all its activity, would stop this process/service. The intent (not sure if this would be the case) is that all the required libraries are loaded once at process start, so when the master tool calls that process/service, it only incurs the classification time instead of importing the libraries repeatedly.
What do you think about this? Also, how can I set this up on Linux RHEL 7.4? Some reference links would be greatly appreciated.
Other suggestions would be greatly appreciated too.
Thanks and have a great day!
This is the solution I designed to achieve the above.
Reference: https://realpython.com/python-sockets/
I had to create 2 scripts:
1. Client Python script: passes the raw data to be classified to the server Python script using socket programming.
2. Server Python script: loads the keras (tensorflow) lib and model at launch, and stays alive until a 'stop' request from the client (to exit the while loop). When the client script sends data to the server script, the server script processes the incoming data and returns an ok/not-ok output back to the client script.
In the end, the classification time is reduced to 0.1-0.3 s.
Client Script
import socket
import argparse
from argparse import ArgumentParser


def main():
    parser = ArgumentParser(description='XXXXX')
    parser.add_argument('-i', '--input', default='NA', help='Input txt file path')
    parser.add_argument('-o', '--output', default='NA', help='Output csv path with class')
    parser.add_argument('-stop', '--stop', default='no', help='Stop the server script')
    args = parser.parse_args()

    str = args.input + ',' + args.output + ',' + args.stop

    HOST = '127.0.0.1'  # The server's hostname or IP address
    PORT = 65432        # The port used by the server
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((HOST, PORT))
    bytedata = str.encode()
    sock.send(bytedata)
    data = sock.recv(1024)
    print('Received', data)


if __name__ == "__main__":
    main()
Server Script
import socket


def main():
    HOST = '127.0.0.1'  # Standard loopback interface address (localhost)
    PORT = 65432        # Port to listen on (non-privileged ports are > 1023)
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind((HOST, PORT))
    sock.listen(5)

    stop_process = 'no'
    while stop_process == 'no':
        # print('Waiting for connection')
        conn, addr = sock.accept()
        data = ''
        try:
            # print('Connected by', addr)
            while True:
                data = conn.recv(1024)
                if data:
                    # process_input processes the incoming data. If the client sends 'yes'
                    # for the stop argument, it sets stop_process to 'yes'.
                    stop_process = process_input(data)
                    byte_reply = stop_process.encode()
                    conn.sendall(byte_reply)  # send reply back to client
                else:
                    break
            conn.close()
            # print('Closing connection', addr)
        finally:
            conn.close()


if __name__ == "__main__":
    main()
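The process_input function isn't shown above; as a rough, hypothetical sketch of its shape (the CSV/model handling is an assumption on my part, only the stop-flag behaviour is described in the text):

def process_input(data):
    # Hypothetical sketch: the client sends 'input_path,output_path,stop_flag'
    input_path, output_path, stop_flag = data.decode().split(',')
    if stop_flag == 'yes':
        return 'yes'  # tells the server loop to exit
    # ... run the already-loaded keras model on input_path and write the csv to output_path ...
    return 'no'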

Processing a huge file (>30GB) in Python

I need to process a huge file of around 30 GB containing hundreds of millions of rows. More precisely, I want to perform the following three steps:
Reading the file in chunks: given the size of the file, I don't have the memory to read it in one go;
Computing stuff on the chunks before aggregating each of them to a more manageable size;
Concatenating the aggregated chunks into a final dataset containing the results of my analysis.
So far, I have coded two threads:
One thread in charge of reading the file in chunks and storing the chunks in a queue (step 1);
One thread in charge of performing the analysis (step 2) on the chunks.
Here is the spirit of my code so far, with dummy data:
import queue
import threading
import concurrent.futures
import os
import random
import pandas as pd
import time


def process_chunk(df):
    return df.groupby(["Category"])["Value"].sum().reset_index(drop=False)


def producer(queue, event):
    print("Producer: Reading the file by chunks")
    reader = pd.read_table(full_path, sep=";", chunksize=10000, names=["Row", "Category", "Value"])
    for index, chunk in enumerate(reader):
        print(f"Producer: Adding chunk #{index} to the queue")
        queue.put((index, chunk))
        time.sleep(0.2)
    print("Producer: Finished putting chunks")
    event.set()
    print("Producer: Event set")


def consumer(queue, event, result_list):
    # The consumer stops iff the queue is empty AND the event is set
    # <=> The consumer keeps going iff the queue is not empty OR the event is not set
    while not queue.empty() or not event.is_set():
        try:
            index, chunk = queue.get(timeout=1)
        except queue.Empty:
            continue
        print(f"Consumer: Retrieved chunk #{index}")
        print(f"Consumer: Queue size {queue.qsize()}")
        result_list.append(process_chunk(chunk))
        time.sleep(0.1)
    print("Consumer: Finished retrieving chunks")


if __name__ == "__main__":
    # Record the execution time
    start = time.perf_counter()

    # Generate a fake file in the current directory if necessary
    path = os.path.dirname(os.path.realpath(__file__))
    filename = "fake_file.txt"
    full_path = os.path.join(path, filename)
    if not os.path.exists(full_path):
        print("Main: Generate a dummy dataset")
        with open(full_path, "w", encoding="utf-8") as f:
            for i in range(100000):
                value = random.randint(1, 101)
                category = i % 2
                f.write(f"{i+1};{value};{category}\n")

    # Defining a queue that will store the chunks of the file read by the Producer
    queue = queue.Queue(maxsize=5)

    # Defining an event that will be set by the Producer when he is done
    event = threading.Event()

    # Defining a list storing the chunks processed by the Consumer
    result_list = list()

    # Launch the threads Producer and Consumer
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        executor.submit(producer, queue, event)
        executor.submit(consumer, queue, event, result_list)

    # Display that the program is finished
    print("Main: Consumer & Producer have finished!")
    print(f"Main: Number of processed chunks = {len(result_list)}")
    print(f"Main: Execution time = {time.perf_counter()-start} seconds")
I know that each iteration of step 1 takes more time than each iteration of step 2, i.e. the Consumer will always be waiting for the Producer.
How can I speed up the process of reading my file in chunks (step 1)?

Why can a single process achieve 100% usage of multiple CPUs on Windows Subsystem for Linux (WSL), but not on Ubuntu on a server?

I want to achieve parallel computing with Python's multiprocessing module, so I implemented a simulated calculation to test whether I can use multiple CPU cores. I found something very strange: a single process can drive 8 CPUs to 100% usage on Windows Subsystem for Linux (WSL) on my desktop, but only one CPU to 100% usage on Ubuntu on the lab's server.
Like this:
And this is the contrast:
Furthermore, I found that using multiple processes does not reduce the time cost on WSL on my desktop, but it does largely reduce the time cost on Ubuntu on the lab's server.
Like this:
(Here I run 6 processes, and running a single process on the lab's server needs about 440 s.)
And this is the contrast:
(Here I run 3 processes, and running a single process on my desktop needs about 29 s.)
Here is my Python source code:
import numpy as np
import time
import os
import multiprocessing as mp

PROCESS_MAX = 1
LOOPS = 1
process_list = []


def simulated_calculation():
    x = np.random.rand(100, 100)
    y = np.random.rand(100, 100)
    z = np.outer(x, y)
    determinant = np.linalg.det(z)


def child_process(name):
    for i in range(LOOPS):
        print("The child process[%s] starts at %s and its PID is %s" % (str(name), time.ctime(), os.getpid()))
        simulated_calculation()
        print("The child process[%s] stops at %s and its PID is %s" % (str(name), time.ctime(), os.getpid()))


def main():
    print("All start at %s" % time.ctime())
    print("The parent process starts at %s and its PID is %s" % (time.ctime(), os.getpid()))
    start_wall_time = time.time()
    for i in range(PROCESS_MAX):
        p = mp.Process(target=child_process, args=(i + 1,))
        process_list.append(p)
        p.daemon = True
        p.start()
    for i in process_list:
        i.join()
    stop_wall_time = time.time()
    print("All stop at %s" % time.ctime())
    print("The whole runtime is %ss" % str(stop_wall_time - start_wall_time))


if __name__ == "__main__":
    main()
I hope someone can help me. Thanks!
WSL1 has a virtual layer through which the Windows device drivers are passed. WSL2, on the other hand, has more access because a real Linux kernel is in place. However, direct access to hardware is unavailable in WSL1, except for USB. Hardware such as USB devices and GPUs is currently not available to WSL2, but support is being worked on.

Multiprocessing code works using numpy but deadlocked using pytorch

I'm hitting what appears to be a deadlock when trying to use multiprocessing with pytorch. The equivalent numpy code works as I expect it to.
I've made a simplified version of my code: a pool of 4 workers executing an array-wide broadcast operation 1,000 times (so ~250 per worker). The array in question is 100,000 x 3, and the broadcast operation is subtraction of a single 1 x 3 row array from all rows. The large array is a shared/global array, and the row array is different at each iteration.
The code works exactly as expected using numpy, with the pooled workers showing a roughly 4x speedup over the equivalent for loop.
The pytorch code, however, hits a deadlock (I assume): none of the workers completes the array broadcast operation even once.
The numpy code below prints the following:
Finished for loop over my_subtractor: took 8.1504 seconds.
Finished pool over my_subtractor: took 2.2247 seconds.
The pytorch code, on the other hand, prints this then stalls:
Finished for loop over my_subtractor: took 3.1082 seconds.
BLA
BLA
BLA
BLA
"BLA" print statements are just to show that each worker is stuck in -- apparently -- a deadlock state. There are exactly 4 of these: one per worker entering -- and getting stuck in -- an iteration.
If you feel ambitious enough to reproduce, note that it doesn't work on Windows because it's not wrapped around if __name__ == '__main__': (I read somewhere that you need this because of the way Windows handles launching processes). Also you will need to create an empty file called my_globals.py.
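For reference, a minimal, self-contained sketch of that guard (not my actual code, just the layout one would need on Windows, with the pool creation moved into a main() function):

from multiprocessing import Pool

# Minimal sketch: all pool creation happens inside the __main__ guard so that
# Windows child processes can safely re-import this module.
def square(x):
    return x * x

def main():
    with Pool(4) as pool:
        print(pool.map(square, range(8)))  # [0, 1, 4, 9, 16, 25, 36, 49]

if __name__ == '__main__':
    main()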
Here is the numpy code
from time import time
import numpy as np
import my_globals
from multiprocessing import Pool as ThreadPool

# shared memory by virtue of being global
my_globals.minuend = np.random.rand(100000, 3)

# array to be iterated over in for loop / pool of workers
subtrahends = np.random.rand(10000, 3)

# function called at each iteration (broadcast operation)
def my_subtractor(subtrahend):
    my_globals.minuend - subtrahend
    return 0

# launch for loop
ts = time()
for idx, subtrahend in enumerate(subtrahends):
    my_subtractor(subtrahend)
te = time()
print('Finished for loop over my_subtractor: took %2.4f seconds.' % (te - ts))

# launch equivalent pool of workers
ts = time()
pool = ThreadPool(4)
pool.map(my_subtractor, subtrahends)
pool.close()
pool.join()
te = time()
print('Finished pool over my_subtractor: took %2.4f seconds.' % (te - ts))
Here is the equivalent pytorch code:
from time import time
import torch
import my_globals
from torch.multiprocessing import Pool as ThreadPool

# necessary on my system because it has low limits for the number of file descriptors;
# not recommended for most systems,
# see: https://pytorch.org/docs/stable/multiprocessing.html#file-descriptor-file-descriptor
torch.multiprocessing.set_sharing_strategy('file_system')

# shared memory by virtue of being global
my_globals.minuend = torch.rand(100000, 3)

# array to be iterated over in for loop / pool of workers
subtrahends = torch.rand(10000, 3)

# function called at each iteration (broadcast operation)
def my_subtractor(subtrahend, verbose=True):
    if verbose:
        print("BLA")  # -- prints for every worker in the pool (so 4 times total)
    my_globals.minuend - subtrahend
    if verbose:
        print("ALB")  # -- doesn't print for any worker
    return 0

# launch for loop
ts = time()
for idx, subtrahend in enumerate(subtrahends):
    my_subtractor(subtrahend, verbose=False)
te = time()
print('Finished for loop over my_subtractor: took %2.4f seconds.' % (te - ts))

# launch equivalent pool of workers
ts = time()
pool = ThreadPool(4)
pool.map(my_subtractor, subtrahends)
pool.close()
pool.join()
te = time()
print('Finished pool over my_subtractor: took %2.4f seconds.' % (te - ts))
You can try setting the OMP_NUM_THREADS=1 environment variable as a quick attempt to fix this. It helped me with a DataLoader + OpenCV deadlock.
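A small sketch of how that might be applied (the assumption here is that the variable needs to be set before torch is imported, or exported in the shell before launching the script):

import os
os.environ['OMP_NUM_THREADS'] = '1'  # limit OpenMP to one thread per process

import torch  # import torch only after the variable is set

# or, from the shell:  OMP_NUM_THREADS=1 python your_script.py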

how to use more than one ps in distributed tensorflow?

I am trying to run distributed TensorFlow, but I have run into some trouble.
Firstly, it can process 35 images/sec on a single GPU (GTX TITAN X) on a single host (Intel E5-2630 v3); however, running the distributed code, each process can only handle 26 images/sec on 4 GPUs on a single host. Moreover, it can only process 8.5 images/sec on 2 hosts, each with 4 GPUs. So the performance of this distributed version seems very poor. Could anybody give me some suggestions as to why I got such a poor result?
Secondly, I wonder whether more ps servers can improve the performance. So I tried to use 2 ps servers, and the program blocked with this log info:
CreateSession still waiting for response from worker: /job:ps/replica:0/task:1
I ran the program on a Slurm system, so I used the Python multiprocessing module to start the ps server.
def get_slurm_env():
    node_list = expand_hostlist(os.environ['SLURM_NODELIST'])
    node_id = int(os.environ['SLURM_NODEID'])
    tasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])
    # It is difficult to assign the port and gpu id in a slurm env.
    # The gpu assigned on different hosts is not always the same, and you never know
    # which gpu is assigned on another host.
    # Different slurm jobs may run on the same machine, so the port num may conflict as well
    task_id = int(os.environ['SLURM_PROCID'])
    task_num = int(os.environ['SLURM_NTASKS'])
    visible_gpu_ids = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
    visible_gpu_ids = [int(gpu) for gpu in visible_gpu_ids]
    worker_port_list = [FLAGS.worker_port_start + incr for incr in range(len(visible_gpu_ids))]

    FLAGS.worker_hosts = ["%s:%d" % (name, port) for name in node_list for port in worker_port_list]
    assert len(FLAGS.worker_hosts) == task_num, 'Job count is not equal %d : %d' % (len(FLAGS.worker_hosts), task_num)
    FLAGS.worker_hosts = ','.join(FLAGS.worker_hosts)

    FLAGS.ps_hosts = ["%s:%d" % (name, FLAGS.ps_port_start) for name in node_list]
    FLAGS.ps_hosts = ','.join(FLAGS.ps_hosts)

    FLAGS.job_name = "worker"
    FLAGS.task_id = task_id
    os.environ['CUDA_VISIBLE_DEVICES'] = str(visible_gpu_ids[task_id % tasks_per_node])


def ps_runner(cluster, task_id):
    tf.logging.info('Setup ps process, id: %d' % FLAGS.task_id)
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    server = tf.train.Server(cluster, job_name="ps", task_index=task_id)
    server.join()
    tf.logging.info('Stop ps process, id: %d' % FLAGS.task_id)


def main(unused_args):
    get_slurm_env()

    # Extract all the hostnames for the ps and worker jobs to construct the
    # cluster spec.
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    tf.logging.info('PS hosts are: %s' % ps_hosts)
    tf.logging.info('Worker hosts are: %s' % worker_hosts)
    cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,
                                         'worker': worker_hosts})

    if FLAGS.task_id == 0:
        p = multiprocessing.Process(target=ps_runner, args=({'ps': ps_hosts, 'worker': worker_hosts}, 0))
        p.start()

    server = tf.train.Server(
        {'ps': ps_hosts,
         'worker': worker_hosts},
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_id)

    # `worker` jobs will actually do the work.
    dataset = ImagenetData(subset=FLAGS.subset)
    assert dataset.data_files()

    # Only the chief checks for or creates train_dir.
    if FLAGS.task_id == 0:
        if not tf.gfile.Exists(FLAGS.train_dir):
            tf.gfile.MakeDirs(FLAGS.train_dir)

    tf.logging.info('Setup worker process, id: %d' % FLAGS.task_id)
    inception_distributed_train.train(server.target, dataset, cluster_spec)
Are you willing to consider MPI-based solutions that do not require distributed-memory-specific changes to your code for distributed TensorFlow? We have recently developed a version of user-transparent distributed TensorFlow using MaTEx: https://github.com/matex-org/matex
We will be able to help you, should you face any problems.