I managed to implement retraining of a frozen graph in TensorFlow 1 by following this wonderfully detailed topic. Basically, the methodology is as follows:
Load frozen model
Replace the constant frozen node with variable node.
The newly replaced variable node then will be redirected to the corresponding output of the frozen node.
This works in tensorflow 1.x by checking the tf.compat.v1.trainable_variables. However, in tensorflow 2.x, it can't work anymore.
Below is the code snippet:
1/ Load frozen model
# Path of the frozen (constants-only) GraphDef file to load
frozen_path = '...'

# Parse the serialized GraphDef and import it, unprefixed, into a dedicated graph
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.io.gfile.GFile(frozen_path, 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
    # name='' keeps the original node names (no "import/" prefix)
    tf.graph_util.import_graph_def(od_graph_def, name='')
2/ Create a clone
# For every frozen constant that has a matching "<name>/read" op, create a
# trainable variable initialised with the constant's value and remember the
# constant-name -> variable-name mapping in const_var_name_pairs.
with detection_graph.as_default():
    const_var_name_pairs = {}
    # A set gives O(1) membership tests for the "/read" lookup below
    available_names = {op.name for op in detection_graph.get_operations()}
    probable_variables = [op for op in detection_graph.get_operations() if op.type == "Const"]
    # One session is reused for all constants instead of opening one per constant
    with tf.compat.v1.Session() as s:
        for op in probable_variables:
            name = op.name
            # Only constants that are read through "<name>/read" are converted
            if name + '/read' not in available_names:
                continue
            tensor = detection_graph.get_tensor_by_name('{}:0'.format(name))
            tensor_as_numpy_array = s.run(tensor)
            var_shape = tensor.get_shape()
            # Give each variable a name that doesn't already exist in the graph
            var_name = '{}_turned_var'.format(name)
            var = tf.Variable(name=var_name, dtype=op.outputs[0].dtype,
                              initial_value=tensor_as_numpy_array,
                              trainable=True, shape=var_shape)
            const_var_name_pairs[name] = var_name
3/ Replace the frozen nodes using Graph Editor
import graph_def_editor as ge

# Build an editable graph from the GraphDef of detection_graph
ge_graph = ge.Graph(detection_graph.as_graph_def())
name_to_op = {n.name: n for n in ge_graph.nodes}

# Swap the outputs of each "<const>/read" node with those of the
# "<var>/Read/ReadVariableOp" node of the variable created for it
for const_name, var_name in const_var_name_pairs.items():
    const_op = name_to_op[const_name+'/read']
    var_reader_op = name_to_op[var_name + '/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

detection_training_graph = ge_graph.to_tf_graph()
with detection_training_graph.as_default():
    # Dump the remapped graph to the 'remap' directory for inspection
    writer = tf.compat.v1.summary.FileWriter('remap', detection_training_graph )
    # close() must actually be called; a bare `writer.close` is a no-op
    writer.close()
The problem was that I built the Graph Editor graph from the tf.GraphDef instead of from the original tf.Graph that contains the Variables.
Quickly solve by fixing step 3
Sol1: Using Graph Editor
ge_graph = ge.Graph(detection_graph)
for const_name, var_name in const_var_name_pairs.items():
const_op = ge_graph._node_name_to_node[const_name+'/read']
var_reader_op = ge_graph._node_name_to_node[var_name+'/Read/ReadVariableOp']
ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
However, this requires disabling eager execution. To work with eager execution enabled, you should attach the MetaGraphDef to the Graph Editor as below
with detection_graph.as_default():
meta_saver = tf.compat.v1.train.Saver()
meta = meta_saver.export_meta_graph()
ge_graph = ge.Graph(detection_graph,collections=ge.graph._extract_collection_defs(meta))
However, this is the trickiest part of making the model trainable in tf2.x
Instead of using the Graph Editor to export the graph directly, we should export it ourselves. The reason is that the Graph Editor makes the Variables' data type "resource". Therefore, we should export the graph as a GraphDef and import the variable definitions into the graph:
test_graph = tf.Graph()
with test_graph.as_default():
tf.import_graph_def(ge_graph.to_graph_def(), name="")
for var_name in ge_graph.variable_names:
var = ge_graph.get_variable_by_name(var_name)
ret = variable_pb2.VariableDef()
ret.variable_name = var._variable_name
ret.initial_value_name = var._initial_value_name
ret.initializer_name = var._initializer_name
ret.snapshot_name = var._snapshot_name
ret.trainable = var._trainable
ret.is_resource = True
tf_var = tf.Variable(variable_def=ret,dtype=tf.float32)
test_graph.add_to_collections(var.collection_names, tf_var)
Sol2: Manually map by Graphdef
with detection_graph.as_default() as graph:
training_graph_def = remap_input_node(detection_graph.as_graph_def(),const_var_name_pairs)
current_var = (tf.compat.v1.trainable_variables())
assert len(current_var)>0, "no training variables"
detection_training_graph = tf.Graph()
with detection_training_graph.as_default():
tf.graph_util.import_graph_def(training_graph_def, name='')
for var in current_var:
ret = variable_pb2.VariableDef()
ret.variable_name = var.name
ret.initial_value_name = var.name[:-2] + '/Initializer/initial_value:0'
ret.initializer_name = var.name[:-2] + '/Assign'
ret.snapshot_name = var.name[:-2] + '/Read/ReadVariableOp:0'
ret.trainable = True
ret.is_resource = True
tf_var = tf.Variable(variable_def=ret,dtype=tf.float32)
detection_training_graph.add_to_collections({'trainable_variables', 'variables'}, tf_var)
current_var = (tf.compat.v1.trainable_variables())
assert len(current_var)>0, "no training variables"
I am trying to do a grid search by calling model.fit recursively for different parameters of my model.
I get a resource exhausted error in tensorflow. In spite of doing del model and tf.keras.backend.clear_session() at the end of the loop. This is my code
def kfoldsplit(FRAME_PATH, MASK_PATH,k):
    """Split the frame/mask files on disk into k training folds plus a held-out set.

    Both directory listings are sorted in natural order, the frames are
    shuffled with a fixed seed (230), the first 80% become the training pool
    and the rest the held-out set. A mask is assigned to a set when the frame
    name derived from it ('image_' + mask[6:16] + 'dcm') is in that set.

    Args:
        FRAME_PATH: directory containing the frame files.
        MASK_PATH: directory containing the mask files.
        k: number of folds to cut the training pool into.

    Returns:
        (kfold, (test_frames, test_masks)) where kfold is a list of k
        (frames, masks) tuples.

    Raises:
        ZeroDivisionError: if k is 0.
    """
    def natural_key(var):
        # Digit runs compare numerically, every other character individually
        return [int(x) if x.isdigit() else x
                for x in re.findall(r'[^0-9]|[0-9]+', var)]

    all_frames = sorted(os.listdir(FRAME_PATH), key=natural_key)
    all_masks = sorted(os.listdir(MASK_PATH), key=natural_key)

    # Seeding the module-level RNG makes the split reproducible
    random.seed(230)
    random.shuffle(all_frames)

    # Generate train and test sets for frames
    train_split = int(0.8 * len(all_frames))
    train_frames = all_frames[:train_split]
    test_frames = all_frames[train_split:]

    # Sets give O(1) membership tests instead of scanning the lists per mask
    train_lookup = set(train_frames)
    test_lookup = set(test_frames)

    # Generate corresponding mask lists for masks
    train_masks = [f for f in all_masks if 'image_' + f[6:16] + 'dcm' in train_lookup]
    test_masks = [f for f in all_masks if 'image_' + f[6:16] + 'dcm' in test_lookup]

    # NOTE(review): train_masks keeps the sorted order of all_masks while
    # train_frames is shuffled, so the frames and masks of one fold are not
    # guaranteed to correspond to each other - confirm this is intended.
    size_of_subset = len(train_masks) // k
    kfold = [
        (train_frames[i*size_of_subset:(i+1)*size_of_subset],
         train_masks[i*size_of_subset:(i+1)*size_of_subset])
        for i in range(k)
    ]
    return kfold, (test_frames,test_masks)
def get_model_name(k):
    """Return the checkpoint file name for fold/run number k ('model_<k>.hdf5')."""
    return f'model_{k!s}.hdf5'
def float_range(start, stop, step):
    """Yield floats from start (inclusive) to stop (exclusive) in increments of step.

    The arithmetic is done on Decimal values built from the *string* form of
    the arguments, so float arguments such as 1e-6 work (adding a Decimal to
    a float raises TypeError) and steps such as 0.1 accumulate without binary
    rounding error.

    Args:
        start: first value to yield.
        stop: exclusive upper bound.
        step: positive increment.

    Raises:
        ValueError: if step is not positive (the loop would never terminate).
    """
    current = decimal.Decimal(str(start))
    limit = decimal.Decimal(str(stop))
    increment = decimal.Decimal(str(step))
    if increment <= 0:
        raise ValueError("step must be positive")
    while current < limit:
        yield float(current)
        current += increment
frames_path = 'C:/Datasets/elderlymen1/2d/images'
masks_path = 'C:/Datasets/elderlymen1/2d/FASCIA_FILLED'
kf = kfoldsplit(frames_path, masks_path, 10)
def crossvalidation(epoch,kf, loops):
    """Grid-search alpha and learning rate, training one model per fold subset.

    For every (alpha, learning-rate) pair a new model is built, trained on
    each subset of kf[0], and evaluated with the weights of its best
    checkpoint. The collected metrics are printed and written to
    'metrics.txt'.

    Args:
        epoch: number of epochs passed to model.fit.
        kf: result of kfoldsplit, i.e. (folds, (held_out_frames, held_out_masks)).
        loops: unused; kept so existing callers keep working.
    """
    import gc  # local import: only needed to force collection between runs

    VALIDATION_ACCURACY = []
    VALIDATION_LOSS = []
    Params=[]
    save_dir = 'C:/saved_models/'
    fold_var = 1
    for _alpha in float_range(0,1,0.1):
        for lrate in float_range(1e-6,1e-3,1e-6):
            Params.append([_alpha,lrate])
            for subset in kf[0]:
                list_IDs = subset[0]
                train_data_generator = DataGenerator2(list_IDs, frames_path, masks_path, to_fit=True, batch_size=2,
                                                      dim=(512, 512), dimy=(512, 512), n_channels=1, n_classes=2, shuffle=True,
                                                      data_gen_args=data_gen_args_dict)
                # The held-out frames of kf[1] serve as validation data for every subset
                list_IDs = kf[1][0]
                valid_data_generator = DataGenerator(list_IDs, frames_path, masks_path, to_fit=True, batch_size=2,
                                                     dim=(512, 512), dimy=(512, 512), n_channels=1, n_classes=2, shuffle=True)
                # CREATE NEW MODEL
                model = unet(pretrained_weights='csa/unet_ThighOuterSurface.hdf5')
                # COMPILE NEW MODEL
                model.compile(optimizer=Adam(lr=lrate), loss=combo_loss(alpha=_alpha, beta=0.4), metrics=[dice_accuracy])
                # CREATE CALLBACKS
                # One path is used for both saving and reloading the best weights
                checkpoint_path = save_dir + get_model_name(fold_var)
                # val_loss must be minimised: mode='max' would keep the worst model
                checkpoint = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                                monitor='val_loss', verbose=1,
                                                                save_best_only=True, mode='min')
                callbacks_list = [checkpoint]
                # FIT THE MODEL
                history = model.fit(train_data_generator, validation_steps=len(valid_data_generator), steps_per_epoch=len(train_data_generator),
                                    epochs=epoch,
                                    callbacks=callbacks_list,
                                    validation_data=valid_data_generator)
                # LOAD BEST MODEL to evaluate the performance of the model
                model.load_weights(checkpoint_path)
                results = model.evaluate(valid_data_generator)
                results = dict(zip(model.metrics_names, results))
                VALIDATION_ACCURACY.append(results['dice_accuracy'])
                VALIDATION_LOSS.append(results['loss'])
                fold_var += 1
                # Drop every reference to the model, reset the Keras state and
                # force a collection so memory is released before the next run
                del model
                tf.keras.backend.clear_session()
                gc.collect()
    print(VALIDATION_ACCURACY)
    print(Params)
    # NOTE(review): mode '+r' requires metrics.txt to exist and overwrites it
    # from the start - confirm whether 'w' or 'a' was intended.
    with open('metrics.txt', '+r') as sample:
        print(VALIDATION_ACCURACY, file=sample)
        print(Params, file=sample)
        print('...',file=sample)
crossvalidation(15,kf, 2)
Why is the memory still exhausted and how can I release it. Or if it is not possible, is there another option for a grid search and cross validation for an image segmentation model?
Thank you
After trying everything I could find to release memory, the only thing that solved the problem was adding
del model
gc.collect()
at the end of the for loop
I am working with TensorFlow object detection API, I have trained two different(SSD-mobilenet and FRCNN-inception-v2) models for my use case. Currently, my workflow is like this:
Take an input image, detect one particular object using SSD
mobilenet.
Crop the input image with the bounding box generated from
step 1 and then resize it to a fixed size(e.g. 200 X 300).
Feed this cropped and resized image to FRCNN-inception-V2 for detecting
smaller objects inside the ROI.
Currently at the time of inferencing, when I load two separate frozen graphs and follow the steps, I am getting my desired results. But I need only a single frozen graph because of my deployment requirement. I am new to TensorFlow and wanted to combine both graphs with crop and resizing process in between them.
Thanks, #matt and #Vedanshu for responding, Here is the updated code that works fine for my requirement, Please give suggestions, if it needs any improvement as I am still learning it.
# Dependencies
import tensorflow as tf
import numpy as np
# load graphs using pb file path
def load_graph(pb_file):
    """Read the serialized GraphDef at pb_file and import it into a new tf.Graph.

    The nodes are imported with name='' so they keep their original names.
    Returns the populated graph.
    """
    with tf.gfile.GFile(pb_file, 'rb') as fid:
        payload = fid.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(payload)

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph
# returns tensor dictionaries from graph
def get_inference(graph, count=0):
    """Collect the detection input/output tensors that exist in graph.

    Args:
        graph: graph exposing get_operations() and get_tensor_by_name().
        count: 0 looks tensors up as '<key>:0'; any other value looks them
            up as '<key>_<count>:0'.

    Returns:
        dict mapping each key found in the graph to its tensor; keys whose
        tensor is absent are omitted.
    """
    all_tensor_names = {output.name
                        for op in graph.get_operations()
                        for output in op.outputs}
    # Parenthesised on purpose: without it the conditional expression binds
    # looser than '+', and the key is dropped whenever count != 0
    suffix = ':0' if count == 0 else '_{}:0'.format(count)
    tensor_dict = {}
    for key in ['num_detections', 'detection_boxes', 'detection_scores',
                'detection_classes', 'detection_masks', 'image_tensor']:
        tensor_name = key + suffix
        if tensor_name in all_tensor_names:
            tensor_dict[key] = graph.get_tensor_by_name(tensor_name)
    return tensor_dict
# renames while_context because there is one while function for every graph
# open issue at https://github.com/tensorflow/tensorflow/issues/22162
def rename_frame_name(graphdef, suffix):
    """Append suffix to 'while_context' in the frame_name of every while-loop node.

    Only nodes whose name contains "while" and that carry a "frame_name"
    attribute are touched; graphdef is modified in place.

    Args:
        graphdef: GraphDef whose nodes are rewritten.
        suffix: text appended after each "while_context" occurrence.
    """
    old = b"while_context"
    new = old + suffix.encode('utf-8')
    for n in graphdef.node:
        if "while" in n.name and "frame_name" in n.attr:
            # Work on the raw bytes of the attribute; str() of the attribute
            # message would bake its text-format wrapper into the frame name
            frame_attr = n.attr["frame_name"]
            frame_attr.s = frame_attr.s.replace(old, new)
if __name__ == '__main__':
    # Locations of the two frozen graphs to chain, and of the combined result
    frozenGraphPath1 = '...replace_with_your_path/some_frozen_graph.pb'
    frozenGraphPath2 = '...replace_with_your_path/some_frozen_graph.pb'
    combinedFrozenGraph = 'combined_frozen_inference_graph.pb'

    # Import each frozen graph into its own tf.Graph
    graph1 = load_graph(frozenGraphPath1)
    graph2 = load_graph(frozenGraphPath2)

    # Tensors of the first graph that feed the crop-and-resize stage
    tensor_dict1 = get_inference(graph1)

    with graph1.as_default():
        image_tensor = tensor_dict1['image_tensor']
        # [0] selects the first (only) batch entry of each detection output
        scores = tensor_dict1['detection_scores'][0]
        num_detections = tf.cast(tensor_dict1['num_detections'][0], tf.int32)
        detection_boxes = tensor_dict1['detection_boxes'][0]

        # Keep at most 5 boxes via NMS; the first model emits 100 detections
        # and cropping all of them ran out of memory
        selected_indices = tf.image.non_max_suppression(detection_boxes, scores, 5, iou_threshold=0.5)
        selected_boxes = tf.gather(detection_boxes, selected_indices)

        # Every selected box is cropped from image 0 of the batch
        box_image_index = tf.zeros(tf.shape(selected_indices), dtype=tf.int32)
        crop_size = [300, 60]  # resize to 300 X 60
        cropped_img = tf.image.crop_and_resize(image_tensor,
                                               selected_boxes,
                                               box_image_index,
                                               crop_size)
        # Named so it can be fetched by name after re-import into the combined graph
        cropped_img = tf.cast(cropped_img, tf.uint8, name='cropped_img')

    gdef1 = graph1.as_graph_def()
    gdef2 = graph2.as_graph_def()

    g1name = "graph1"
    g2name = "graph2"

    # Make the while-loop frame names of the two graphs distinct before merging
    rename_frame_name(gdef1, g1name)
    rename_frame_name(gdef2, g2name)

    # Import both graphs into one, feeding the crops of the first model
    # into the image input of the second, then write the result to disk
    with tf.Graph().as_default() as g_combined:
        x, y = tf.import_graph_def(gdef1, return_elements=['image_tensor:0', 'cropped_img:0'])
        z, = tf.import_graph_def(gdef2, input_map={"image_tensor:0": y}, return_elements=['detection_boxes:0'])
        tf.train.write_graph(g_combined, "./", combinedFrozenGraph, as_text=False)
You can load output of one graph into another using input_map in import_graph_def. Also you have to rename the while_context because there is one while function for every graph. Something like this:
def get_frozen_graph(graph_file):
    """Parse the frozen graph stored at graph_file and return it as a GraphDef."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_file, "rb") as f:
        serialized = f.read()
    graph_def.ParseFromString(serialized)
    return graph_def
def rename_frame_name(graphdef, suffix):
    """Append suffix to 'while_context' in the frame_name of every while-loop node.

    Works around https://github.com/tensorflow/tensorflow/issues/22162#issuecomment-428091121
    Only nodes whose name contains "while" and that carry a "frame_name"
    attribute are touched; graphdef is modified in place.

    Args:
        graphdef: GraphDef whose nodes are rewritten.
        suffix: text appended after each "while_context" occurrence.
    """
    old = b"while_context"
    new = old + suffix.encode('utf-8')
    for n in graphdef.node:
        if "while" in n.name and "frame_name" in n.attr:
            # Work on the raw bytes of the attribute; str() of the attribute
            # message would bake its text-format wrapper into the frame name
            frame_attr = n.attr["frame_name"]
            frame_attr.s = frame_attr.s.replace(old, new)
...
l1_graph = tf.Graph()
with l1_graph.as_default():
trt_graph1 = get_frozen_graph(pb_fname1)
[tf_input1, tf_scores1, tf_boxes1, tf_classes1, tf_num_detections1] = tf.import_graph_def(trt_graph1,
return_elements=['image_tensor:0', 'detection_scores:0', 'detection_boxes:0', 'detection_classes:0','num_detections:0'])
input1 = tf.identity(tf_input1, name="l1_input")
boxes1 = tf.identity(tf_boxes1[0], name="l1_boxes") # index by 0 to remove batch dimension
scores1 = tf.identity(tf_scores1[0], name="l1_scores")
classes1 = tf.identity(tf_classes1[0], name="l1_classes")
num_detections1 = tf.identity(tf.dtypes.cast(tf_num_detections1[0], tf.int32), name="l1_num_detections")
...
# Make your output tensor
tf_out = # your output tensor (here, crop the input image with the bounding box generated from step 1 and then resize it to a fixed size(e.g. 200 X 300).)
...
connected_graph = tf.Graph()
with connected_graph.as_default():
l1_graph_def = l1_graph.as_graph_def()
g1name = 'ved'
rename_frame_name(l1_graph_def, g1name)
tf.import_graph_def(l1_graph_def, name=g1name)
...
trt_graph2 = get_frozen_graph(pb_fname2)
g2name = 'level2'
rename_frame_name(trt_graph2, g2name)
[tf_scores, tf_boxes, tf_classes, tf_num_detections] = tf.import_graph_def(trt_graph2,
input_map={'image_tensor': tf_out},
return_elements=['detection_scores:0', 'detection_boxes:0', 'detection_classes:0','num_detections:0'])
#######
# Export the graph
with connected_graph.as_default():
print('\nSaving...')
cwd = os.getcwd()
path = os.path.join(cwd, 'saved_model')
shutil.rmtree(path, ignore_errors=True)
inputs_dict = {
"image_tensor": tf_input
}
outputs_dict = {
"detection_boxes_l1": tf_boxes_l1,
"detection_scores_l1": tf_scores_l1,
"detection_classes_l1": tf_classes_l1,
"max_num_detection": tf_max_num_detection,
"detection_boxes_l2": tf_boxes_l2,
"detection_scores_l2": tf_scores_l2,
"detection_classes_l2": tf_classes_l2
}
tf.saved_model.simple_save(
tf_sess_main, path, inputs_dict, outputs_dict
)
print('Ok')
In Tensorflow, my model is based on a pre-trained model, and I added a few more variables and remove some in the pre-trained model. When I restore the variables from the checkpoint file, I have to explicitly specify all variables I added to the graph that need to be excluded. For example, I did
exclude = # explicitly list all variables to exclude
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
Is there a simpler way to do this? Namely, as long as a variable is not in checkpoint, then don't try to restore.
You should first find the variables that are useful (i.e. that also exist in your graph) and then restore only the intersection of those with the variables stored in the checkpoint, rather than everything in the checkpoint.
# tf.train.list_variables yields (name, shape) tuples, so the graph variables
# must be matched by name; intersecting Variable objects with those tuples
# would always be empty.
checkpoint_var_names = {name for name, _ in tf.train.list_variables(checkpoint_dir)}
variables_can_be_restored = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
                             if v.op.name in checkpoint_var_names]
then restore it after defining a saver like this:
temp_saver = tf.train.Saver(variables_can_be_restored)
ckpt_state = tf.train.get_checkpoint_state(checkpoint_dir, lastest_filename)
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)
The only thing you can do is to first build the same model as in the checkpoint and then restore the checkpoint values into that model. After restoring the variables, you can add new layers, delete existing layers, or change the weights of the layers.
But there is an important point you need to be careful about. After adding new layers you need to initialize them. If you use tf.global_variables_initializer(), you will lose the values of the restored layers. So you should initialize only the uninitialized variables; you can use the following function for this.
def initialize_uninitialized(sess):
    """Run the initializers of only those global variables that are still uninitialized.

    Variables that already hold a value (for example restored from a
    checkpoint) are left untouched.
    """
    candidates = tf.global_variables()
    # One run call evaluates the initialized-flag of every variable
    flags = sess.run([tf.is_variable_initialized(var) for var in candidates])
    pending = [var for var, ready in zip(candidates, flags) if not ready]
    if pending:
        sess.run(tf.variables_initializer(pending))
This is a more complete answer that works for a non-distributed setting:
from tensorflow.contrib.framework.python.framework import checkpoint_utils
slim = tf.contrib.slim
def scan_checkpoint_for_vars(checkpoint_path, vars_to_check):
    """Partition vars_to_check by whether their name appears in the checkpoint.

    A variable's name is compared without its ':<index>' suffix against the
    names listed by checkpoint_utils.list_variables.

    Returns:
        (vars_in_checkpoint, vars_not_in_checkpoint), each preserving the
        order of vars_to_check.
    """
    stored_names = {entry[0] for entry in checkpoint_utils.list_variables(checkpoint_path)}
    found, missing = [], []
    for var in vars_to_check:
        base_name = var.name[:var.name.index(":")]
        bucket = found if base_name in stored_names else missing
        bucket.append(var)
    return found, missing
def create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint):
    """Build a Scaffold that restores only the checkpointed variables.

    The scaffold's saver covers vars_in_checkpoint only. Its local init op
    initializes vars_not_in_checkpoint, and it is considered ready for that
    local init once report_uninitialized_variables over vars_in_checkpoint
    is empty.
    """
    restored_ready_op = tf.report_uninitialized_variables(var_list=vars_in_checkpoint)
    init_missing_op = tf.variables_initializer(vars_not_in_checkpoint)
    partial_saver = tf.train.Saver(vars_in_checkpoint)
    return tf.train.Scaffold(saver=partial_saver,
                             ready_for_local_init_op=restored_ready_op,
                             local_init_op=init_missing_op)
all_vars = slim.get_variables()
ckpoint_file = tf.train.latest_checkpoint(output_chkpt_dir)
vars_in_checkpoint, vars_not_in_checkpoint = scan_checkpoint_for_vars(ckpoint_file, all_vars)
is_checkpoint_complete = len(vars_not_in_checkpoint) == 0
# Create session that can handle current checkpoint
if (is_checkpoint_complete):
# Checkpoint is full - all variables can be found there
print('Using normal session')
sess = tf.train.MonitoredTrainingSession(checkpoint_dir = output_chkpt_dir,
save_checkpoint_secs = save_checkpoint_secs,
save_summaries_secs = save_summaries_secs)
else:
# Checkpoint is partial - some variables need to be initialized
print('Using easy going session')
eg_scaffold = create_easy_going_scaffold(vars_in_checkpoint, vars_not_in_checkpoint)
# Save all variables to next checkpoint
saver = tf.train.Saver()
hooks = [tf.train.CheckpointSaverHook(checkpoint_dir = output_chkpt_dir,
save_secs = save_checkpoint_secs,
saver = saver)]
# Such session is a little slower during the first iteration
sess = tf.train.MonitoredTrainingSession(checkpoint_dir = output_chkpt_dir,
scaffold = eg_scaffold,
hooks = hooks,
save_summaries_secs = save_summaries_secs,
save_checkpoint_secs = None)
with sess:
.....
I'm trying to build a convolutional lstm autoencoder (that also predicts future and past) with Tensorflow, and it works to a certain degree, but the error sometimes jumps back up, so essentially, it never converges.
The model is as follows:
The encoder starts with a 64x64 frame from a 20-frame bouncing MNIST video for each time step of the LSTM. Every stacked LSTM layer halves the spatial size and increases the depth via 2x2 convolutions with a stride of 2 (so --> 32x32x3 --> ... --> 1x1x96).
On the other hand, the lstm performs 3x3 convolutions with a stride of 1 on its state. Both results are concatenated to form the new state. In the same way, the decoder uses transposed convolutions to go back to the original format. Then the squared error is calculated.
The error starts at around 2700 and it takes around 20 hours (geforce1060) to get down to ~1700. At which point the jumping back up (and it sometimes jumps back up to 2300 or even ridiculous values like 440300) happens often enough that I can't really get any lower. Also at that point, it can usually pinpoint where the number should be, but its too fuzzy to actually make out the digit...
I tried different learning rates and optimizers, so if anybody knows why that jumping happens, that'd make me happy :)
Here is a graph of the loss with epochs:
import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
#based on code by loliverhennigh (Github)
class ConvCell(tf.contrib.rnn.RNNCell):
    """Convolutional LSTM-style cell that also rescales its input.

    The input is passed through a 2x2, stride-2 convolution (or transposed
    convolution when transpose=True) and the previous hidden state through a
    3x3, stride-1 convolution; both results are concatenated along the
    channel axis to form the candidate state. Three single-channel gates
    (i, f, o) are computed from the input, the hidden state and the cell state.
    """
    count = 0 #exists only to remove issues with variable scope

    def __init__(self, shape, num_features, transpose = False):
        """Store the state shape and feature count of this cell.

        Args:
            shape: 4-element state shape; __call__ reads it as
                [batch, height, width, channels].
            num_features: output channels of each of the two convolution paths.
            transpose: if True, the input path uses a transposed convolution.
        """
        self.shape = shape
        self.num_features = num_features
        self._state_is_tuple = True
        self._transpose = transpose
        # Each instance gets its own number, used in the default variable scope name
        ConvCell.count+=1
        self.count = ConvCell.count

    # state_size and output_size are declared as properties; as plain
    # methods, attribute access would return the bound method, not a size
    @property
    def state_size(self):
        return (tf.contrib.rnn.LSTMStateTuple(self.shape[0:4],self.shape[0:4]))

    @property
    def output_size(self):
        return tf.TensorShape(self.shape[1:4])

    #here comes to the actual conv lstm implementation, if transpose = true, it performs a deconvolution on the input
    def __call__(self, inputs, state, scope=None):
        """Run one step of the cell and return (new_h, LSTMStateTuple(new_c, new_h))."""
        with tf.variable_scope(scope or type(self).__name__+str(self.count)):
            c, h = state
            state_shape = h.shape
            input_shape = inputs.shape
            #filter variables and convolutions on data coming from the same cell, a time step previous
            h_filters = tf.get_variable("h_filters",[3,3,state_shape[3],self.num_features])
            h_filters_gates = tf.get_variable("h_filters_gates",[3,3,state_shape[3],3])
            h_partial = tf.nn.conv2d(h,h_filters,[1,1,1,1],'SAME')
            h_partial_gates = tf.nn.conv2d(h,h_filters_gates,[1,1,1,1],'SAME')
            c_filters = tf.get_variable("c_filters",[3,3,state_shape[3],3])
            c_partial = tf.nn.conv2d(c,c_filters,[1,1,1,1],'SAME')
            #filters and convolutions/deconvolutions on data coming from the cell input
            if self._transpose:
                x_filters = tf.get_variable("x_filters",[2,2,self.num_features,input_shape[3]])
                x_filters_gates = tf.get_variable("x_filters_gates",[2,2,3,input_shape[3]])
                x_partial = tf.nn.conv2d_transpose(inputs,x_filters,[int(state_shape[0]),int(state_shape[1]),int(state_shape[2]),self.num_features],[1,2,2,1],'VALID')
                x_partial_gates = tf.nn.conv2d_transpose(inputs,x_filters_gates,[int(state_shape[0]),int(state_shape[1]),int(state_shape[2]),3],[1,2,2,1],'VALID')
            else:
                x_filters = tf.get_variable("x_filters",[2,2,input_shape[3],self.num_features])
                x_filters_gates = tf.get_variable("x_filters_gates",[2,2,input_shape[3],3])
                x_partial = tf.nn.conv2d(inputs,x_filters,[1,2,2,1],'VALID')
                x_partial_gates = tf.nn.conv2d(inputs,x_filters_gates,[1,2,2,1],'VALID')
            #some more lstm gate business
            gate_bias = tf.get_variable("gate_bias",[1,1,1,3])
            h_bias = tf.get_variable("h_bias",[1,1,1,self.num_features*2])
            gates = h_partial_gates + x_partial_gates + c_partial + gate_bias
            # one channel each for the input, forget and output gate
            i,f,o = tf.split(gates,3,axis=3)
            #concatenate the units coming from the spacial and the temporal dimension to build a unified state
            concat = tf.concat([h_partial,x_partial],3) + h_bias
            new_c = tf.nn.relu(concat)*tf.sigmoid(i)+c*tf.sigmoid(f)
            new_h = new_c * tf.sigmoid(o)
            new_state = tf.contrib.rnn.LSTMStateTuple(new_c,new_h)
            return new_h, new_state #its redundant, but thats how tensorflow likes it, apparently
#global variables
LEARNING_RATE = 0.005
ITERATIONS_PER_EPOCH = 80
BATCH_SIZE = 75
TEST = False #manual switch to go from training to testing
if TEST:
BATCH_SIZE = 1
inputs = tf.placeholder(tf.float32, (20, BATCH_SIZE, 64, 64,1))
shape0 = [BATCH_SIZE,64,64,2]
shape1 = [BATCH_SIZE,32,32,6]
shape2 = [BATCH_SIZE,16,16,12]
shape3 = [BATCH_SIZE,8,8,24]
shape4 = [BATCH_SIZE,4,4,48]
shape5 = [BATCH_SIZE,2,2,96]
shape6 = [BATCH_SIZE,1,1,192]
#apparently tf.multirnncell has very specific requirements for the initial states oO
initial_state1 = (tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape1),tf.zeros(shape1)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape2),tf.zeros(shape2)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape3),tf.zeros(shape3)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape4),tf.zeros(shape4)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape5),tf.zeros(shape5)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape6),tf.zeros(shape6)))
initial_state2 = (tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape5),tf.zeros(shape5)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape4),tf.zeros(shape4)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape3),tf.zeros(shape3)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape2),tf.zeros(shape2)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape1),tf.zeros(shape1)),tf.contrib.rnn.LSTMStateTuple(tf.zeros(shape0),tf.zeros(shape0)))
#encoding part of the autoencoder graph
cell1 = ConvCell(shape1,3)
cell2 = ConvCell(shape2,6)
cell3 = ConvCell(shape3,12)
cell4 = ConvCell(shape4,24)
cell5 = ConvCell(shape5,48)
cell6 = ConvCell(shape6,96)
mcell = tf.contrib.rnn.MultiRNNCell([cell1,cell2,cell3,cell4,cell5,cell6])
rnn_outputs, rnn_states = tf.nn.dynamic_rnn(mcell, inputs[0:20,:,:,:],initial_state=initial_state1,dtype=tf.float32, time_major=True)
#decoding part of the autoencoder graph, forward block and backwards block
cell9a = ConvCell(shape5,48,transpose = True)
cell10a = ConvCell(shape4,24,transpose = True)
cell11a = ConvCell(shape3,12,transpose = True)
cell12a = ConvCell(shape2,6,transpose = True)
cell13a = ConvCell(shape1,3,transpose = True)
cell14a = ConvCell(shape0,1,transpose = True)
mcella = tf.contrib.rnn.MultiRNNCell([cell9a,cell10a,cell11a,cell12a,cell13a,cell14a])
cell9b = ConvCell(shape5,48,transpose = True)
cell10b = ConvCell(shape4,24,transpose = True)
cell11b= ConvCell(shape3,12,transpose = True)
cell12b = ConvCell(shape2,6,transpose = True)
cell13b = ConvCell(shape1,3,transpose = True)
cell14b = ConvCell(shape0,1,transpose = True)
mcellb = tf.contrib.rnn.MultiRNNCell([cell9b,cell10b,cell11b,cell12b,cell13b,cell14b])
def PredictionLayer(rnn_outputs,viewPoint = 11, reverse = False):
    """Decode frames from the encoder output at step viewPoint-1.

    The decoder input ("vision") is that single encoder output followed by
    zeros. With reverse=False the mcella decoder is run and compared with
    input frames viewPoint-1..19; with reverse=True the mcellb decoder is
    run and compared with frames viewPoint-2 down to 0.

    Returns:
        The channel-averaged decoder output when TEST is set, otherwise the
        summed squared error against the corresponding input frames.
    """
    if reverse:
        predLength = viewPoint - 2
    else:
        predLength = 20 - viewPoint
    # vision is the input for the decoder
    seed = rnn_outputs[viewPoint-1:viewPoint,:,:,:]
    padding = tf.zeros([predLength,BATCH_SIZE,1,1,192])
    vision = tf.concat([seed,padding],0)

    decoder = mcellb if reverse else mcella
    rnn_outputs2, rnn_states = tf.nn.dynamic_rnn(decoder, vision, initial_state = initial_state2, time_major=True)
    mean = tf.reduce_mean(rnn_outputs2,4)
    if TEST:
        return mean

    if reverse:
        target = inputs[viewPoint-2::-1,:,:,:,0]
    else:
        target = inputs[viewPoint-1:20,:,:,:,0]
    return tf.reduce_sum(tf.square(mean-target))
if TEST:
    # Backward prediction (flipped into chronological order) followed by the
    # forward prediction, both decoded from view point 11.
    # The decoder builder is PredictionLayer; createPredictionLayer is not defined.
    mean = tf.concat([PredictionLayer(rnn_outputs,11,True)[::-1,:,:,:],PredictionLayer(rnn_outputs,11)],0)
else: #training part of the graph
    # Sum the forward and backward reconstruction errors over several view points
    error = tf.zeros([1])
    for i in range(8,15): #range size of 7 or less works, 9 or more does not, no idea why
        error += PredictionLayer(rnn_outputs, i)
        error += PredictionLayer(rnn_outputs, i, True)
    train_fn = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE).minimize(error)
################################################################################
## TRAINING LOOP ##
################################################################################
#code based on siemanko/tf_lstm.py (Github)
# Single checkpoint location used for both restoring and saving; built with
# os.path.join so it works regardless of the platform's path separator
checkpoint_dir = os.path.join('.', 'conv_lstm_multiples_v2')
os.makedirs(checkpoint_dir, exist_ok=True)

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8)
saver = tf.train.Saver(restore_sequentially=True, allow_empty=True,)
session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
session.run(tf.global_variables_initializer())
vids = np.load("mnist_test_seq.npy") #20/10000/64/64 , moving mnist dataset from http://www.cs.toronto.edu/~nitish/unsupervised_video/
vids = vids[:,0:6000,:,:] #training set

# Resume from the newest checkpoint if there is one; on a first run there is
# none and the freshly initialised variables are kept
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    saver.restore(session, latest_checkpoint)

if not TEST:
    for epoch in range(1000):
        epoch_error = 0
        #randomize batches each epoch (shuffle along the video axis)
        vids = np.swapaxes(vids,0,1)
        np.random.shuffle(vids)
        vids = np.swapaxes(vids,0,1)
        for i in range(ITERATIONS_PER_EPOCH):
            #running the graph and feeding data
            err,_ = session.run([error, train_fn], {inputs: np.expand_dims(vids[:,i*BATCH_SIZE:(i+1)*BATCH_SIZE,:,:],axis=4)})
            print(err)
            epoch_error += err
        #training error each epoch and regular saving
        epoch_error /= (ITERATIONS_PER_EPOCH*BATCH_SIZE*4096*20*7)
        if (epoch+1) % 5 == 0:
            saver.save(session,os.path.join(checkpoint_dir, 'iteration'),global_step=epoch)
            print("saved")
        print("Epoch %d, train error: %f" % (epoch, epoch_error))
else:
    #testing: `mean` only exists in the graph when TEST is set
    plt.ion()
    f, axarr = plt.subplots(2)
    vids = np.load("mnist_test_seq.npy")
    for i in range(6000,10000):
        img = session.run([mean], {inputs: np.expand_dims(vids[:,i:i+1,:,:],axis=4)})
        for j in range(20):
            # top: predicted frame, bottom: ground-truth frame
            axarr[0].imshow(img[0][j,0,:,:])
            axarr[1].imshow(vids[j,i,:,:])
            plt.show()
            plt.pause(0.1)
Usually this happens when gradients' magnitude is very high at some point and causes your network parameters to change a lot. To verify that it is indeed the case, you can produce the same plot of gradient magnitudes and see if they jump right before the loss jump. Assuming this is the case, the classic approach is to use gradient clipping (or go all the way to natural gradient).