How to avoid memory leakage in an autoregressive model within tensorflow - tensorflow

Recently, I have been training an LSTM with an attention mechanism for regression in TensorFlow 2.9, and I ran into a problem during training with model.fit():
At the beginning, the training time is okay, around 7s/step. However, it keeps increasing during training, and after a number of steps (around 1000) it can reach 50s/step. Below is part of the code for my model:
class AttentionModel(tf.keras.Model):
    """Encoder/decoder model that decodes six steps autoregressively.

    Each decoder step receives the previous step's logits (zeros for the
    first step) together with the encoder output, and the six per-step
    logits are concatenated along axis 1.
    """

    def __init__(self, encoder_output_dim, dec_units, dense_dim, batch):
        super().__init__()
        self.dense_dim = dense_dim
        self.batch = batch
        self.encoder = Encoder(encoder_output_dim)
        self.decoder = Decoder(dec_units, dense_dim)

    def call(self, inputs):
        encoder_output, state = self.encoder(inputs)
        # Per-step decoder logits, collected in a plain Python list.
        tempt = []
        # The first decoder input is all zeros, sized by the configured batch.
        step_features = np.zeros((self.batch, 1, 1))
        for _ in range(6):
            dec_result, state = self.decoder(
                DecoderInput(new_features=step_features,
                             enc_output=encoder_output),
                state,
            )
            # Feed this step's logits back in as the next step's input.
            step_features = dec_result.logits
            tempt.append(step_features)
        return tf.concat(tempt, 1)
In the official documents for tf.function, I notice: "Don't rely on Python side effects like object mutation or list appends".
Since I use a dynamic Python list with append() to record the intermediate variables, I guess a new tf.Graph is added on every training step. Is this the reason my training is getting slower and slower?
Additionally, what should I use instead of python list to avoid this? I have tried with a numpy.zeros matrix but it will lead to another problem:
tempt = np.zeros(shape=(1,6))
...
for i in range(6):
dec_inputs = DecoderInput(new_features=new_features, enc_output=encoder_output)
dec_result, dec_state = self.decoder(dec_inputs, dec_initial_state)
tempt[i]=(dec_result.logits)
...
Cannot convert a symbolic tf.Tensor (decoder/dense_3/BiasAdd:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.

Related

Metrics using batches v/s metrics using full dataset

I am training an image classification model using a pre-trained mobile network. During training, I am seeing very high values (more than 70%) for accuracy, precision, recall, and F1-score on both the training dataset and the validation dataset.
For me, this is an indication that my model is learning fine.
But when I checked these metrics on the unbatched training and unbatched validation datasets, they are very low: not even 1%.
By "unbatched dataset" I mean that I am not calculating these metrics per batch and then averaging them to get the final values, which is what TensorFlow/Keras does during model training. Instead, I am calculating these metrics on the full dataset in a single run.
I am unable to find out what is causing this Behaviour. Please help me understand what is causing this difference and how to ensure that results are consistent on both, i.e. a minor difference is acceptable.
Code that I used for evaluating metrics
My old code
def test_model(model, data, CLASSES, label_one_hot=True, average="micro",
               threshold_analysis=False, thres_analysis_start_point=0.0,
               thres_analysis_end_point=0.95, thres_step=0.05, classwise_analysis=False,
               produce_confusion_matrix=False):
    """Score `model` on the batched `(image, label)` dataset `data`.

    Labels and predictions are collected from the *same* batches in a single
    pass over `data`. Reading images and labels in two separate passes pairs
    predictions with the wrong labels whenever the pipeline does not yield
    the same order on every iteration (e.g. a shuffle that reshuffles).

    Args:
        model: Keras model; `predict_on_batch` is called on each image batch.
        data: iterable of `(images, labels)` batches.
        CLASSES: sequence of class names; its length defines the label ids.
        label_one_hot: when True, labels are one-hot and are arg-maxed.
        average: averaging mode for the overall scikit-learn scores.
        threshold_analysis, thres_analysis_start_point,
        thres_analysis_end_point, thres_step: accepted for interface
            compatibility; not used by this function.
        classwise_analysis: when True, also return per-class scores.
        produce_confusion_matrix: when True, also return the confusion matrix.

    Returns:
        `overall` alone, or a tuple `(overall[, classwise][, cmat])` that
        contains the optional parts that were requested, in that order.

    Raises:
        ValueError: if `data` yields no batches.
    """
    label_batches = []
    probability_batches = []
    for images, labels in data:
        probability_batches.append(np.asarray(model.predict_on_batch(images)))
        label_batches.append(np.asarray(labels))
    if not label_batches:
        raise ValueError("data yielded no batches")

    cm_correct_labels = np.concatenate(label_batches, axis=0)
    cm_probabilities = np.concatenate(probability_batches, axis=0)
    if label_one_hot is True:
        cm_correct_labels = np.argmax(cm_correct_labels, axis=-1)
    cm_predictions = np.argmax(cm_probabilities, axis=-1)

    class_ids = list(range(len(CLASSES)))
    results = []
    # Silence metric warnings only inside this scope so the caller's warning
    # filters are restored on every return path.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        results.append({
            'overall_f1_score': f1_score(
                cm_correct_labels, cm_predictions, labels=class_ids, average=average),
            'overall_precision': precision_score(
                cm_correct_labels, cm_predictions, labels=class_ids, average=average),
            'overall_recall': recall_score(
                cm_correct_labels, cm_predictions, labels=class_ids, average=average),
        })
        if classwise_analysis is True:
            # Class wise precision, recall and f1_score
            results.append({
                'classwise_f1_score': f1_score(
                    cm_correct_labels, cm_predictions, labels=class_ids, average=None),
                'classwise_precision': precision_score(
                    cm_correct_labels, cm_predictions, labels=class_ids, average=None),
                'classwise_recall': recall_score(
                    cm_correct_labels, cm_predictions, labels=class_ids, average=None),
                'class_names': CLASSES,
            })
        if produce_confusion_matrix is True:
            results.append(
                confusion_matrix(cm_correct_labels, cm_predictions, labels=class_ids))

    if len(results) == 1:
        return results[0]
    return tuple(results)
Just to ensure that my model testing function is correct I write a newer version of code in TensorFlow.
def eval_model(y_true, y_pred):
    """Compute accuracy, recall, precision and micro/macro F1 in one shot.

    Every metric is created fresh, updated once with the full `y_true` /
    `y_pred`, read out as a NumPy value and then reset.

    Returns:
        dict mapping each metric name to its value.
    """
    trackers = {
        'unbatch_accuracy': tf.keras.metrics.CategoricalAccuracy(name='unbatch_accuracy'),
        'unbatch_recall': tf.keras.metrics.Recall(name='unbatch_recall'),
        'unbatch_precision': tf.keras.metrics.Precision(name='unbatch_precision'),
        'unbatch_f1_micro': tfa.metrics.F1Score(
            name='unbatch_f1_micro', num_classes=n_labels, average='micro'),
        'unbatch_f1_macro': tfa.metrics.F1Score(
            name='unbatch_f1_macro', num_classes=n_labels, average='macro'),
    }
    for tracker in trackers.values():
        tracker.update_state(y_true, y_pred)
    eval_results = {key: tracker.result().numpy() for key, tracker in trackers.items()}
    for tracker in trackers.values():
        tracker.reset_states()
    return eval_results
The results are nearly the same by using both of the functions.
Please suggest what is going on here.
I think this suggestion MAY help you, though I am not sure. In your code, you added
unbatch_accuracy.reset_states()
unbatch_recall.reset_states()
unbatch_precision.reset_states()
unbatch_f1_micro.reset_states()
unbatch_f1_macro.reset_states()
Resetting the states at each epoch may mean that the result is not a cumulative one.
After spending many hours, I found the issue was due to the shuffle function. I was using the below function to shuffle, batch and prefetch the dataset.
def shuffle_batch_prefetch(dataset, prefetch_size=1, batch_size=16,
                           shuffle_buffer_size=None,
                           drop_remainder=False,
                           interleave_num_pcall=None):
    """Shuffle, then batch, then prefetch `dataset`.

    Args:
        dataset: dataset object exposing `apply`, `batch` and `prefetch`.
        prefetch_size: buffer size passed to `prefetch`.
        batch_size: number of elements per batch.
        shuffle_buffer_size: shuffle buffer size; required.
        drop_remainder: forwarded to `batch`.
        interleave_num_pcall: accepted but not used here.

    Raises:
        ValueError: if `shuffle_buffer_size` is None.
    """
    if shuffle_buffer_size is None:
        raise ValueError("shuffle_buffer_size can't be None")

    # NOTE(review): the shuffle is seeded but otherwise uses the library
    # defaults -- confirm whether callers need the same order on every pass.
    shuffled = dataset.apply(
        lambda ds: ds.shuffle(buffer_size=shuffle_buffer_size, seed=108))
    batched = shuffled.batch(batch_size, drop_remainder=drop_remainder)
    return batched.prefetch(buffer_size=prefetch_size)
Part of the function that causes the problem
def shuffle_fn(ds):
return ds.shuffle(buffer_size=shuffle_buffer_size, seed=108)
dataset = dataset.apply(shuffle_fn)
I removed the shuffle part and metrics are back as per the expectation.
Function after removing the shuffle part
def shuffle_batch_prefetch(dataset, prefetch_size=1, batch_size=16,
                           drop_remainder=False,
                           interleave_num_pcall=None):
    """Batch and then prefetch `dataset`; no shuffling is applied.

    `interleave_num_pcall` is accepted but not used here.
    """
    batched = dataset.batch(batch_size, drop_remainder=drop_remainder)
    return batched.prefetch(buffer_size=prefetch_size)
Results after removing the shuffle part
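I am still not able to understand why shuffling causes this error, since shuffling is a recommended practice before training. A likely explanation is that `Dataset.shuffle` reshuffles on every iteration by default (`reshuffle_each_iteration=True`), so if the images and the labels are read in two separate passes over the dataset, the predictions and the labels no longer line up. In any case, I had already shuffled the training data at read time, so removing this step was not a problem for me.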

Tensorflow, Keras: In a multi-class classification, accuracy is high, but precision, recall, and f1-score is zero for most classes

General Explanation:
My code works fine, but the results are weird. I don't know whether the problem is with
the network structure,
or the way I feed the data to the network,
or anything else.
I have been struggling with this error for several weeks, and so far I have changed the loss function, optimizer, data generator, etc., but I could not solve it. I appreciate any help.
If the following information is not enough, let me know, please.
Field of study:
I am using TensorFlow and Keras for multiclass classification. The dataset has 36 binary human attributes. I have used ResNet50, and for each part of the body (head, upper body, lower body, shoes, accessories) I have added a separate branch to the network. The network has 1 input image with 36 labels and 36 output nodes (36 dense layers with sigmoid activation).
Problem:
The problem is that the accuracy that Keras reports is high, but the F1-score is very low or zero for most of the outputs (even when I use the F1-score as a metric when compiling the network, the F1-score for validation is very bad).
After training, when I use the network in prediction mode, it always returns one/zero for some classes. This means that the network is not able to learn (even when I use a weighted loss function or a focal loss function).
Why is it weird? Because state-of-the-art methods report a high F1-score even after the first epoch (e.g. https://github.com/chufengt/iccv19_attribute, which I have run on my PC and got good results after one epoch).
Parts of the Codes:
print("setup model ...")
input_image = KL.Input(args.img_input_shape, name= "input_1")
C1, C2, C3, C4, C5 = resnet_graph(input_image, architecture="resnet50", stage5=False, train_bn=True)
output_layers = merged_model (input_features=C4)
model = Model(inputs=input_image, outputs=output_layers, name='SoftBiometrics_Model')
...
print("model compiling ...")
OPTIM = optimizers.Adadelta(lr=args.learning_rate, rho=0.95)
model.compile(optimizer=OPTIM, loss=binary_focal_loss(alpha=.25, gamma=2), metrics=['acc',get_f1])
plot_model(model, to_file='model.png')
...
img_datagen = ImageDataGenerator(rotation_range=6, width_shift_range=0.03, height_shift_range=0.03, brightness_range=[0.85,1.15], shear_range=0.06, zoom_range=0.09, horizontal_flip=True, preprocessing_function=preprocess_input_resnet, rescale=1/255.)
img_datagen_test = ImageDataGenerator(preprocessing_function=preprocess_input_resnet, rescale=1/255.)
def multiple_outputs(generator, dataframe, batch_size, x_col):
    """Yield `(inputs, outputs)` dict pairs for a multi-output model, forever.

    Wraps `generator.flow_from_dataframe` (shuffled, `multi_output` mode, one
    label column per entry of `args.Categories`) and re-keys every batch:
    images under `'input_1'`, and each category's labels under
    `'<Category>_output'`.
    """
    target_size = (args.img_input_shape[0], args.img_input_shape[1])
    flow = generator.flow_from_dataframe(
        dataframe=dataframe,
        directory=None,
        x_col=x_col,
        y_col=args.Categories,
        target_size=target_size,
        class_mode="multi_output",
        classes=None,
        batch_size=batch_size,
        shuffle=True,
    )
    while True:
        batch = flow.next()
        images, labels = batch[0], batch[1]
        # One label array per category, keyed by the matching output layer name.
        outputs = {}
        for position, category in enumerate(args.Categories):
            outputs["{}_output".format(category)] = np.array(labels[position])
        yield {'input_1': images}, outputs
trainGen = multiple_outputs (generator = img_datagen, dataframe=Train_df_img, batch_size=args.BATCH_SIZE, x_col="Train_Filenames")
testGen = multiple_outputs (generator = img_datagen_test, dataframe=Test_df_img, batch_size=args.BATCH_SIZE, x_col="Test_Filenames")
STEP_SIZE_TRAIN = len(Train_df_img["Train_Filenames"]) // args.BATCH_SIZE
STEP_SIZE_VALID = len(Test_df_img["Test_Filenames"]) // args.BATCH_SIZE
...
print("Fitting the model to the data ...")
history = model.fit_generator(generator=trainGen,
epochs=args.Number_of_epochs,
steps_per_epoch=STEP_SIZE_TRAIN,
validation_data=testGen,
validation_steps=STEP_SIZE_VALID,
callbacks= [chekpont],
verbose=1)
There is a possibility that you are passing binary f1-score to compile function. This should fix the problem -
pip install tensorflow-addons
...
import tensorflow_addons as tfa
# Choose ONE averaging mode explicitly: the Python expression
# `'micro' or 'macro'` always evaluates to 'micro'.
f1 = tfa.metrics.F1Score(num_classes=36, average='micro')  # or average='macro'
model.compile(..., metrics=[f1])
You can read more about how f1-micro and f1-macro is calculated and which can be useful here.
Somehow, the predict_generator() of Keras' model does not work as expected. I would rather loop through all test images one-by-one and get the prediction for each image in each iteration. I am using Plaid-ML Keras as my backend and to get prediction I am using the following code.
import os
from PIL import Image
import keras
import numpy
print("Prediction result:")
image_dir = "/path/to/test/images"
files = os.listdir(image_dir)
correct = 0
total = 0
# Maps the predicted class index to a human-readable label.
classes = {
    0: 'This is Cat',
    1: 'This is Dog',
}
for file_name in files:
    total += 1
    # Close the file handle as soon as the pixels have been converted.
    with Image.open(os.path.join(image_dir, file_name)) as handle:
        image = handle.convert('RGB')
    image = image.resize((100, 100))
    # Scale to [0, 1] and add the leading batch axis expected by the model.
    image = numpy.expand_dims(numpy.asarray(image) / 255, axis=0)
    pred = model.predict_classes([image])[0]
    sign = classes[pred]
    # Compare case-insensitively: the labels are capitalised ('Cat', 'Dog'),
    # so a plain `"cat" in sign` test can never match.
    name_lower = file_name.lower()
    sign_lower = sign.lower()
    if any(animal in name_lower and animal in sign_lower
           for animal in ("cat", "dog")):
        print(correct, ". ", file_name, sign)
        correct += 1
# Guard against an empty directory instead of dividing by zero.
print("accuracy: ", (correct / total) if total else 0.0)

Why am I getting shape errors when trying to pass a batch from the Tensorflow Dataset API to my session operations?

I am dealing with an issue in my conversion over to the Dataset API and I guess I just don't have enough experience yet with the API to know how to handle the below situation. We currently have image augmentation that we perform currently using queueing and batching. I was tasked with checking out the new Dataset API and converting over our existing implementation using it rather than queues.
What we would like to do is get a reference to all the paths and handle all operations from just that reference. As you see in the dataset initialization, I have mapped the parse_fn to the dataset itself which then goes about reading the file and extracting the initial values from the filenames. However when I then go about calling the iterators next_batch method and then pass those values to get_summary, I'm now getting an error around shape. I have been trying a number of things which just keeps changing the error and so I felt I should see if anyone on SO saw possibly that I was going about this all wrong and should be taking a different route. Does anything jump out as absolutely wrong in my use of the Dataset API?
Should I not be calling the ops this way any longer? I noticed the majority of the examples I saw they would get the batch, pass the variables to the op and then capture that in a variable and pass that to sess.run, however I haven't found an easy way of doing that as of yet with our setup that wasn't erroring so this was the approach I took instead (but its still erroring). I'll be continuing to try to trace down the problem and post here should I find anything, but if anyone sees something please advise. Thanks!
Current Error:
... in get_summary summary, acc = sess.run([self._summary_op,
self._accuracy], feed_dict=feed_dict) ValueError: Cannot feed value of
shape (32,) for Tensor 'ph_input_labels:0', which has shape '(?, 1)
Below is the block where the get_summary method is called and error is fired:
def perform_train():
if __name__ == '__main__':
#Get all our image paths
filenames = data_layer_train.get_image_paths()
next_batch, iterator = preproc_image_fn(filenames=filenames)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
with sess.graph.as_default():
# Set the random seed for tensorflow
tf.set_random_seed(cfg.RNG_SEED)
classifier_network = c_common.create_model(len(products_to_class_dict), is_training=True)
optimizer, global_step_var = c_common.create_optimizer(classifier_network)
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
# Init tables and dataset iterator
sess.run(tf.tables_initializer())
sess.run(iterator.initializer)
cur_epoch = 0
blobs = None
try:
epoch_size = data_layer_train.get_steps_per_epoch()
num_steps = num_epochs * epoch_size
for step in range(num_steps):
timer_summary.tic()
if blobs is None:
#Now populate from our training dataset
blobs = sess.run(next_batch)
# *************** Below is where it is erroring *****************
summary_train, acc = classifier_network.get_summary(sess, blobs["images"], blobs["labels"], blobs["weights"])
...
Believe the error is in preproc_image_fn:
def preproc_image_fn(filenames, images=None, labels=None, image_paths=None, cells=None, weights=None):
def _parse_fn(filename, label, weight):
augment_instance = False
paths=[]
selected_cells=[]
if vals.FIRST_ITER:
#Perform our check of the path to see if _data_augmentation is within it
#If so set augment_instance to true and replace the substring with an empty string
new_filename = tf.regex_replace(filename, "_data_augmentation", "")
contains = tf.equal(tf.size(tf.string_split([filename], "")), tf.size(tf.string_split([new_filename])))
filename = new_filename
if contains is True:
augment_instance = True
core_file = tf.string_split([filename], '\\').values[-1]
product_id = tf.string_split([core_file], ".").values[0]
label = search_tf_table_for_entry(product_id)
weight = data_layer_train.get_weights(product_id)
image_string = tf.read_file(filename)
img = tf.image.decode_image(image_string, channels=data_layer_train._channels)
img.set_shape([None, None, None])
img = tf.image.resize_images(img, [data_layer_train._target_height, data_layer_train._target_width])
#Previously I was returning the below, but I was getting an error from the op when assigning feed_dict stating that it didnt like the dictionary
#retval = dict(zip([filename], [img])), label, weight
retval = img, label, weight
return retval
num_files = len(filenames)
filenames = tf.constant(filenames)
#*********** Setup dataset below ************
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels, weights))
dataset=dataset.map(_parse_fn)
dataset = dataset.repeat()
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
batch_features, batch_labels, batch_weights = iterator.get_next()
return {'images': batch_features, 'labels': batch_labels, 'weights': batch_weights}, iterator
def search_tf_table_for_entry(self, product_id):
    '''Return the class-table lookup for `product_id`, or -1 when it is None.

    When `product_id` is None and this is not a real evaluation run, the miss
    is logged before -1 is returned.
    '''
    if product_id is None:
        if not self._real_eval:
            logger().info("class not found in training {} ".format(product_id))
        return -1
    return self._products_to_class_table.lookup(product_id)
Where I create the model and have the placeholders used previously:
...
def create_model(self):
weights_regularizer = tf.contrib.layers.l2_regularizer(cfg.TRAIN.WEIGHT_DECAY)
biases_regularizer = weights_regularizer
# Input data.
self._input_images = tf.placeholder(
tf.float32, shape=(None, self._image_height, self._image_width, self._num_channels), name="ph_input_images")
self._input_labels = tf.placeholder(tf.int64, shape=(None, 1), name="ph_input_labels")
self._input_weights = tf.placeholder(tf.float32, shape=(None, 1), name="ph_input_weights")
self._is_training = tf.placeholder(tf.bool, name='ph_is_training')
self._keep_prob = tf.placeholder(tf.float32, name="ph_keep_prob")
self._accuracy = tf.reduce_mean(tf.cast(self._correct_prediction, tf.float32))
...
self.create_summaries()
def create_summaries(self):
val_summaries = []
with tf.device("/cpu:0"):
for var in self._act_summaries:
self._add_act_summary(var)
for var in self._train_summaries:
self._add_train_summary(var)
self._summary_op = tf.summary.merge_all()
self._summary_op_val = tf.summary.merge(val_summaries)
def get_summary(self, sess, images, labels, weights):
    """Run the summary and accuracy ops on one batch in inference mode.

    The label and weight placeholders are declared with shape `(None, 1)`,
    while batches drawn from a dataset of scalar labels/weights arrive with
    shape `(batch,)`. Both are reshaped to a column here so either layout can
    be fed; values already shaped `(batch, 1)` pass through unchanged.

    Args:
        sess: session used to run the ops.
        images: image batch fed to the image placeholder.
        labels: labels of shape `(batch,)` or `(batch, 1)`.
        weights: weights of shape `(batch,)` or `(batch, 1)`.

    Returns:
        `(summary, accuracy)` as returned by `sess.run`.
    """
    import numpy as np  # local import: keeps the block self-contained

    feed_dict = {
        self._input_images: images,
        self._input_labels: np.reshape(labels, (-1, 1)),
        self._input_weights: np.reshape(weights, (-1, 1)),
        self._is_training: False,
    }
    summary, acc = sess.run([self._summary_op, self._accuracy], feed_dict=feed_dict)
    return summary, acc
Since the error says:
Cannot feed value of shape (32,) for Tensor 'ph_input_labels:0', which has shape '(?, 1)
My guess is your labels in get_summary has the shape [32]. Can you just reshape it to (32, 1)? Or maybe reshape the label earlier in _parse_fn?

tensorflow error - you must feed a value for placeholder tensor 'in'

I'm trying to implement queues for my tensorflow prediction but get the following error -
you must feed a value for placeholder tensor 'in' with dtype float and shape [1024,1024,3]
The program works fine if I use the feed_dict, Trying to replace feed_dict with queues.
The program basically takes a list of positions and passes the image np array to the input tensor.
for each in positions:
y,x = each
images = img[y:y+1024,x:x+1024,:]
a = images.astype('float32')
q = tf.FIFOQueue(capacity=200,dtypes=dtypes)
enqueue_op = q.enqueue(a)
qr = tf.train.QueueRunner(q, [enqueue_op] * 1)
tf.train.add_queue_runner(qr)
data = q.dequeue()
graph=load_graph('/home/graph/frozen_graph.pb')
with tf.Session(graph=graph,config=tf.ConfigProto(log_device_placement=True)) as sess:
p_boxes = graph.get_tensor_by_name("cat:0")
p_confs = graph.get_tensor_by_name("sha:0")
y = [p_confs, p_boxes]
x = graph.get_tensor_by_name("in:0")
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord,sess=sess)
confs, boxes = sess.run(y)
coord.request_stop()
coord.join(threads)
How can I make sure the input data that I populated to the queue is recognized while running the graph in the session.
In my original run I call the
confs, boxes = sess.run([p_confs, p_boxes], feed_dict=feed_dict_testing)
I'd suggest not using queues for this problem, and switching to the new tf.data API. In particular tf.data.Dataset.from_generator() makes it easier to feed in data from a Python function. You can rewrite your code to be much simpler, as follows:
def generator():
    """Yield `img[y:y+1024, x:x+1024, :]` as float32 for each `(y, x)` in `positions`."""
    for y, x in positions:
        images = img[y:y+1024, x:x+1024, :]
        yield images.astype('float32')


# `img` is indexed with three axes (rows, columns, channels), so the channel
# count is `img.shape[2]`; `img.shape[3]` would raise an IndexError.
dataset = tf.data.Dataset.from_generator(
    generator, tf.float32, [1024, 1024, img.shape[2]])
# Add any extra transformations in here, like `dataset.batch()` or
# `dataset.repeat()`.
# ...
iterator = dataset.make_one_shot_iterator()
data = iterator.get_next()
Note that in your program, there's no connection between the data tensor and the graph you loaded in load_graph() (at least, assuming that load_graph() doesn't grab data from the global state!). You will probably need to use tf.import_graph_def() and the input_map argument to associate data with one of the tensors in your frozen graph (possibly "in:0"?) to complete the task.

Store RNN states using graph collections

I frequently use tf.add_to_collection to have Tensorflow automatically serialize intermediary results into a checkpoint. I find this the most convenient way to later fetch pointers to interesting tensors when a model was restored from a checkpoint. However, I realized that RNN state tuples cannot easily be added to a graph collection. Consider the following dummy example in TF 1.3:
import tensorflow as tf
import numpy as np
in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]
cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.BasicLSTMCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
inputs=in_,
initial_state=cell.zero_state(batch_size, dtype=tf.float32))
tf.add_to_collection('states', last_state)
loss = tf.reduce_mean(in_ - outputs)
loss_s = tf.summary.scalar('loss', loss)
writer = tf.summary.FileWriter('.', tf.get_default_graph())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
l, s = sess.run([loss, loss_s], feed_dict={in_: np.ones([1, 5, 1])})
writer.add_summary(s)
This will produce the following warning:
WARNING:tensorflow:Error encountered when serializing states.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'tuple' object has no attribute 'name'
It seems that the serialization cannot handle tuples, and of course the last_state variable is a tuple. May be one could loop through the tuple and add each element individually to the collection, but that seems too complicated. What's a better way of handling this? In the end, I would like to access last_state again when the model is restored, ideally without needing access to the original code that created the model.
Actually, looping through every element of the state is not too complicated, and straight-forward to implement:
def add_to_collection_rnn_state(name, rnn_state):
    """Store each layer's `c` and then `h` tensor in the graph collection `name`.

    Every element of `rnn_state` must expose `.c` and `.h` (LSTM state tuples).
    """
    for layer_state in rnn_state:
        for field in ('c', 'h'):
            tf.add_to_collection(name, getattr(layer_state, field))
And then to load it:
def get_collection_rnn_state(name):
    """Rebuild a tuple of `LSTMStateTuple`s from the collection `name`.

    The collection is read as consecutive `(c, h)` pairs, one pair per layer.
    """
    stored = tf.get_collection(name)
    return tuple(
        tf.nn.rnn_cell.LSTMStateTuple(stored[start], stored[start + 1])
        for start in range(0, len(stored), 2)
    )
Note that this assumes that one collection only stores one state, i.e. use a different collection for every state you want to store, e.g. like this:
add_to_collection_rnn_state('states', last_state)
add_to_collection_rnn_state('init_state', init_state)
Edit
As pointed out correctly in the comments, the proposed solution only works for LSTMCells (that are represented as tuples as well). A more general solution that can handle GRU cells or potentially custom cells and mixes thereof, could look like this:
import tensorflow as tf
import numpy as np
def add_to_collection_rnn_state(name, rnn_state):
    """Store a (possibly mixed) multi-layer RNN state in graph collections.

    For every layer the state's class name is appended to the collection
    `name + '__names__'`, and its tensor(s) to the collection `name`: each
    component in order when the state is a tuple/list (e.g. an
    `LSTMStateTuple`), otherwise the state itself.

    Args:
        name: base name of the collection that receives the tensors.
        rnn_state: iterable with one state entry per layer.
    """
    # store the name of each cell type in a different collection
    coll_of_names = name + '__names__'
    for layer in rnn_state:
        tf.add_to_collection(coll_of_names, layer.__class__.__name__)
        # Decide by type rather than by trying to iterate: a plain tensor can
        # itself be iterable in some TensorFlow versions/modes, which would
        # store its slices instead of the state, and a broad
        # `except TypeError` would also hide unrelated errors.
        if isinstance(layer, (tuple, list)):
            for component in layer:
                tf.add_to_collection(name, component)
        else:
            # single-tensor state: add it directly
            tf.add_to_collection(name, layer)
def get_collection_rnn_state(name):
    """Rebuild the per-layer state tuple stored under `name`.

    The class names recorded in `name + '__names__'` decide how many tensors
    each layer consumes from the collection: two for an `LSTMStateTuple`, one
    for anything else.
    """
    stored = tf.get_collection(name)
    kinds = tf.get_collection(name + '__names__')
    restored = []
    cursor = 0
    for kind in kinds:
        if 'LSTMStateTuple' not in kind:  # add more cell types here
            restored.append(stored[cursor])
            cursor += 1
            continue
        restored.append(
            tf.nn.rnn_cell.LSTMStateTuple(stored[cursor], stored[cursor + 1]))
        cursor += 2
    return tuple(restored)
in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]
cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.GRUCell(num_units=256)
cell3 = tf.nn.rnn_cell.BasicRNNCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2, cell3])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
inputs=in_,
initial_state=cell.zero_state(batch_size, dtype=tf.float32))
add_to_collection_rnn_state('last_state', last_state)
last_state_r = get_collection_rnn_state('last_state')
Comparing last_state and last_state_r reveals that both are identical (which they should be). Note that I am using a different collection to store the names because tensorflow can only serialize a collection when all elements in the collection are of the same type. E.g. mixing strings with Tensors in the same collection does not work.