TensorFlow Estimator MirroredStrategy AssertionError - tensorflow

I'm using TF 1.10.0 on Ubuntu 16.04.
I'm using the Estimator API to build a language model and want to preserve the last hidden states to initialize the states for the next batch. It looks something like this (ref: https://stackoverflow.com/a/41240243/6948766):
init_state = get_state_variables(params.batch_size, lstm_cell)
# list of [batch_size, max_steps, lstm_dim]
with tf.variable_scope(direction):
    _lstm_output_unpacked, final_state = tf.nn.static_rnn(
        lstm_cell,
        tf.unstack(lstm_input, axis=1),
        initial_state=init_state,
        dtype=DTYPE)
    self.state_update_op.append(
        get_state_update_op(init_state, final_state))
And the training code looks like this:
dist_strategy = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(
    train_distribute=dist_strategy)
estimator = tf.estimator.Estimator(model_fn=model.model_fn,
                                   model_dir='model/',
                                   params=params,
                                   config=run_config)
But I get the following errors:
Traceback (most recent call last):
File "src/train_eval.py", line 64, in <module>
main()
File "src/train_eval.py", line 60, in main
run(p, arg.mode)
File "src/train_eval.py", line 47, in run
params, 'train'), max_steps=1000000)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 356, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 1179, in _train_model
return self._train_model_distributed(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 1290, in _train_model_distributed
self.config)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/distribute.py", line 718, in call_for_each_tower
return self._call_for_each_tower(fn, *args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 552, in _call_for_each_tower
return _call_for_each_tower(self, fn, *args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 183, in _call_for_each_tower
coord.join(threads)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/coordinator.py", line 389, in join
six.reraise(*self._exc_info_to_raise)
File "/usr/local/lib/python3.5/dist-packages/six.py", line 693, in reraise
raise value
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
yield
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/distribute/python/mirrored_strategy.py", line 166, in _call_for_each_tower
merge_args = values.regroup({t.device: t.merge_args for t in threads})
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/distribute/python/values.py", line 585, in regroup
for i in range(len(v0)))
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/distribute/python/values.py", line 585, in <genexpr>
for i in range(len(v0)))
File "/usr/local/lib/python3.5/dist-packages/tensorflow/contrib/distribute/python/values.py", line 576, in regroup
(len(v), len(v0), v, v0))
AssertionError: len(v) == 33, len(v0) == 2, v: [(<tensorflow.python.framework.ops.IndexedSlices object at 0x7fe818e9c518>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'char_embed/replica_3:0' shape=(11204, 64) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'char_embed/replica_2:0' shape=(11204, 64) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'char_embed:0' shape=(11204, 64) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'char_embed/replica_1:0' shape=(11204, 64) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_1:0' shape=(1, 1, 64, 32) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/W_cnn_0/replica_3:0' shape=(1, 1, 64, 32) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/W_cnn_0/replica_2:0' shape=(1, 1, 64, 32) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/W_cnn_0:0' shape=(1, 1, 64, 32) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/W_cnn_0/replica_1:0' shape=(1, 1, 64, 32) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_2:0' shape=(32,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/b_cnn_0/replica_3:0' shape=(32,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/b_cnn_0/replica_2:0' shape=(32,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/b_cnn_0:0' shape=(32,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/b_cnn_0/replica_1:0' shape=(32,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_3:0' shape=(1, 2, 64, 32) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/W_cnn_1/replica_3:0' shape=(1, 2, 64, 32) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/W_cnn_1/replica_2:0' shape=(1, 2, 64, 32) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 
'CNN/W_cnn_1:0' shape=(1, 2, 64, 32) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/W_cnn_1/replica_1:0' shape=(1, 2, 64, 32) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_4:0' shape=(32,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/b_cnn_1/replica_3:0' shape=(32,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/b_cnn_1/replica_2:0' shape=(32,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/b_cnn_1:0' shape=(32,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/b_cnn_1/replica_1:0' shape=(32,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_5:0' shape=(1, 3, 64, 64) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/W_cnn_2/replica_3:0' shape=(1, 3, 64, 64) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/W_cnn_2/replica_2:0' shape=(1, 3, 64, 64) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/W_cnn_2:0' shape=(1, 3, 64, 64) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/W_cnn_2/replica_1:0' shape=(1, 3, 64, 64) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_6:0' shape=(64,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/b_cnn_2/replica_3:0' shape=(64,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/b_cnn_2/replica_2:0' shape=(64,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/b_cnn_2:0' shape=(64,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/b_cnn_2/replica_1:0' shape=(64,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_7:0' shape=(1, 4, 64, 128) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/W_cnn_3/replica_3:0' shape=(1, 4, 64, 128) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 
'CNN/W_cnn_3/replica_2:0' shape=(1, 4, 64, 128) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/W_cnn_3:0' shape=(1, 4, 64, 128) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/W_cnn_3/replica_1:0' shape=(1, 4, 64, 128) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_8:0' shape=(128,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/b_cnn_3/replica_3:0' shape=(128,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/b_cnn_3/replica_2:0' shape=(128,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/b_cnn_3:0' shape=(128,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/b_cnn_3/replica_1:0' shape=(128,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_9:0' shape=(1, 5, 64, 256) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/W_cnn_4/replica_3:0' shape=(1, 5, 64, 256) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/W_cnn_4/replica_2:0' shape=(1, 5, 64, 256) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/W_cnn_4:0' shape=(1, 5, 64, 256) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/W_cnn_4/replica_1:0' shape=(1, 5, 64, 256) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_10:0' shape=(256,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/b_cnn_4/replica_3:0' shape=(256,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/b_cnn_4/replica_2:0' shape=(256,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/b_cnn_4:0' shape=(256,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/b_cnn_4/replica_1:0' shape=(256,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_11:0' shape=(1, 6, 64, 512) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': 
<tf.Variable 'CNN/W_cnn_5/replica_3:0' shape=(1, 6, 64, 512) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/W_cnn_5/replica_2:0' shape=(1, 6, 64, 512) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/W_cnn_5:0' shape=(1, 6, 64, 512) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/W_cnn_5/replica_1:0' shape=(1, 6, 64, 512) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_12:0' shape=(512,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/b_cnn_5/replica_3:0' shape=(512,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/b_cnn_5/replica_2:0' shape=(512,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/b_cnn_5:0' shape=(512,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/b_cnn_5/replica_1:0' shape=(512,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_13:0' shape=(1, 7, 64, 1024) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/W_cnn_6/replica_3:0' shape=(1, 7, 64, 1024) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/W_cnn_6/replica_2:0' shape=(1, 7, 64, 1024) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/W_cnn_6:0' shape=(1, 7, 64, 1024) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/W_cnn_6/replica_1:0' shape=(1, 7, 64, 1024) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_14:0' shape=(1024,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN/b_cnn_6/replica_3:0' shape=(1024,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN/b_cnn_6/replica_2:0' shape=(1024,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN/b_cnn_6:0' shape=(1024,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN/b_cnn_6/replica_1:0' shape=(1024,) dtype=float32>})), (<tf.Tensor 
'clip_by_global_norm/clip_by_global_norm/_15:0' shape=(2048, 512) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_proj/W_proj/replica_3:0' shape=(2048, 512) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_proj/W_proj/replica_2:0' shape=(2048, 512) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN_proj/W_proj:0' shape=(2048, 512) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_proj/W_proj/replica_1:0' shape=(2048, 512) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_16:0' shape=(512,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_proj/b_proj/replica_3:0' shape=(512,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_proj/b_proj/replica_2:0' shape=(512,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN_proj/b_proj:0' shape=(512,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_proj/b_proj/replica_1:0' shape=(512,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_17:0' shape=(2048, 2048) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_high_0/W_carry/replica_3:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_high_0/W_carry/replica_2:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN_high_0/W_carry:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_high_0/W_carry/replica_1:0' shape=(2048, 2048) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_18:0' shape=(2048,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_high_0/b_carry/replica_3:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_high_0/b_carry/replica_2:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:0': 
<tf.Variable 'CNN_high_0/b_carry:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_high_0/b_carry/replica_1:0' shape=(2048,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_19:0' shape=(2048, 2048) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_high_0/W_transform/replica_3:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_high_0/W_transform/replica_2:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN_high_0/W_transform:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_high_0/W_transform/replica_1:0' shape=(2048, 2048) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_20:0' shape=(2048,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_high_0/b_transform/replica_3:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_high_0/b_transform/replica_2:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN_high_0/b_transform:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_high_0/b_transform/replica_1:0' shape=(2048,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_21:0' shape=(2048, 2048) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_high_1/W_carry/replica_3:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_high_1/W_carry/replica_2:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN_high_1/W_carry:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_high_1/W_carry/replica_1:0' shape=(2048, 2048) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_22:0' shape=(2048,) dtype=float32>, 
MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_high_1/b_carry/replica_3:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_high_1/b_carry/replica_2:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN_high_1/b_carry:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_high_1/b_carry/replica_1:0' shape=(2048,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_23:0' shape=(2048, 2048) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_high_1/W_transform/replica_3:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_high_1/W_transform/replica_2:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN_high_1/W_transform:0' shape=(2048, 2048) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_high_1/W_transform/replica_1:0' shape=(2048, 2048) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_24:0' shape=(2048,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'CNN_high_1/b_transform/replica_3:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'CNN_high_1/b_transform/replica_2:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'CNN_high_1/b_transform:0' shape=(2048,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'CNN_high_1/b_transform/replica_1:0' shape=(2048,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_25:0' shape=(1024, 16384) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/kernel/replica_3:0' shape=(1024, 16384) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/kernel/replica_2:0' shape=(1024, 16384) dtype=float32>, 
'/replica:0/task:0/device:GPU:0': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(1024, 16384) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/kernel/replica_1:0' shape=(1024, 16384) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_26:0' shape=(16384,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/bias/replica_3:0' shape=(16384,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/bias/replica_2:0' shape=(16384,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(16384,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/bias/replica_1:0' shape=(16384,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_27:0' shape=(4096, 512) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/projection/kernel/replica_3:0' shape=(4096, 512) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/projection/kernel/replica_2:0' shape=(4096, 512) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/projection/kernel:0' shape=(4096, 512) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_0/lstm_cell/projection/kernel/replica_1:0' shape=(4096, 512) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_28:0' shape=(1024, 16384) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/kernel/replica_3:0' shape=(1024, 16384) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 
'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/kernel/replica_2:0' shape=(1024, 16384) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(1024, 16384) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/kernel/replica_1:0' shape=(1024, 16384) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_29:0' shape=(16384,) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/bias/replica_3:0' shape=(16384,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/bias/replica_2:0' shape=(16384,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(16384,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/bias/replica_1:0' shape=(16384,) dtype=float32>})), (<tf.Tensor 'clip_by_global_norm/clip_by_global_norm/_30:0' shape=(4096, 512) dtype=float32>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/projection/kernel/replica_3:0' shape=(4096, 512) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/projection/kernel/replica_2:0' shape=(4096, 512) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/projection/kernel:0' shape=(4096, 512) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'fw/rnn/multi_rnn_cell/cell_1/lstm_cell/projection/kernel/replica_1:0' shape=(4096, 512) dtype=float32>})), (<tensorflow.python.framework.ops.IndexedSlices object at 0x7fe818e9c390>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'softmax/W/replica_3:0' shape=(556527, 512) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 
'softmax/W/replica_2:0' shape=(556527, 512) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'softmax/W:0' shape=(556527, 512) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'softmax/W/replica_1:0' shape=(556527, 512) dtype=float32>})), (<tensorflow.python.framework.ops.IndexedSlices object at 0x7fe818e9c0b8>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'softmax/b/replica_3:0' shape=(556527,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'softmax/b/replica_2:0' shape=(556527,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'softmax/b:0' shape=(556527,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'softmax/b/replica_1:0' shape=(556527,) dtype=float32>}))], v0: [(<tensorflow.python.framework.ops.IndexedSlices object at 0x7fe8127cfe10>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'softmax/W/replica_3:0' shape=(556527, 512) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'softmax/W/replica_2:0' shape=(556527, 512) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'softmax/W:0' shape=(556527, 512) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'softmax/W/replica_1:0' shape=(556527, 512) dtype=float32>})), (<tensorflow.python.framework.ops.IndexedSlices object at 0x7fe8127cfe80>, MirroredVariable({'/replica:0/task:0/device:GPU:3': <tf.Variable 'softmax/b/replica_3:0' shape=(556527,) dtype=float32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'softmax/b/replica_2:0' shape=(556527,) dtype=float32>, '/replica:0/task:0/device:GPU:0': <tf.Variable 'softmax/b:0' shape=(556527,) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'softmax/b/replica_1:0' shape=(556527,) dtype=float32>}))]
Thanks in advance.
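Reading the assertion message: values.regroup was handed per-device lists of different lengths, one with 33 (gradient, variable) pairs covering every layer and one with only 2 (softmax/W and softmax/b). The sketch below is a simplified, pure-Python stand-in for that kind of length check, not TensorFlow's actual implementation; check_same_length and the device keys are hypothetical names used only for illustration.

```python
def check_same_length(per_device):
    """Check that every device reported the same number of entries.

    per_device maps a device name to the list of (gradient, variable)
    pairs produced on that device. Simplified illustration only.
    """
    lists = list(per_device.values())
    v0 = lists[0]
    for v in lists[1:]:
        assert len(v) == len(v0), (
            "len(v) == %d, len(v0) == %d" % (len(v), len(v0)))
    return len(v0)


# One tower reports gradients for two variables, the other for 33.
towers = {
    "GPU:0": [("grad", "softmax/W"), ("grad", "softmax/b")],
    "GPU:1": [("grad", "var_%d" % i) for i in range(33)],
}
try:
    check_same_length(towers)
except AssertionError as err:
    print("mismatch:", err)
```

A mismatch like this suggests the towers did not compute gradients for the same set of variables; the question does not show enough of the model to say why.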

Related

How can I write a DSSIM+MAE loss function for model training

How can I write code for a DSSIM+MAE loss function from this formula:
Loss = α * MAE + (1 - α) * DSSIM
with
Mean Absolute Error (MAE) = (1/M) * ∑|yi - xi|
DSSIM = 1 - SSIM
SSIM = (numerator1 * numerator2) / (denominator1 * denominator2)
numerator1 = 2 * μ12 + C1 # μ12 = μ1 * μ2
numerator2 = 2 * σ12 + C2
denominator1 = μ1_sq + μ2_sq + C1
denominator2 = σ1_sq + σ2_sq + C2
where α is a trade-off parameter between MAE and DSSIM,
M is the total number of pixels in the image,
μ is the mean value of the image,
σ is the standard deviation of the image, and σ12 is the covariance of the two images x and y. C1 and C2 are two constants that stabilize
the division when the denominator is weak. In my implementation,
I set α = 0.75, C1 = (0.01L)^2 and C2 = (0.03L)^2,
where L is the dynamic range of the pixel values in the image.
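As a plain-Python reading of the formula above, here is a minimal sketch that treats each image as a flat list of pixel values and uses global statistics (one mean, variance and covariance per image pair). That is an assumption made for illustration: library SSIM implementations such as tf.image.ssim use local sliding windows instead. The function names are hypothetical.

```python
def mae(x, y):
    # Mean absolute error over all M pixels.
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


def global_ssim(x, y, dynamic_range):
    # C1 and C2 as defined in the question, with L = dynamic_range.
    c1 = (0.01 * dynamic_range) ** 2
    c2 = (0.03 * dynamic_range) ** 2
    n = len(x)
    mu1 = sum(x) / n
    mu2 = sum(y) / n
    var1 = sum((a - mu1) ** 2 for a in x) / n
    var2 = sum((b - mu2) ** 2 for b in y) / n
    cov = sum((a - mu1) * (b - mu2) for a, b in zip(x, y)) / n
    numerator = (2 * mu1 * mu2 + c1) * (2 * cov + c2)
    denominator = (mu1 ** 2 + mu2 ** 2 + c1) * (var1 + var2 + c2)
    return numerator / denominator


def dssim_mae(x, y, dynamic_range, alpha=0.75):
    # Loss = alpha * MAE + (1 - alpha) * DSSIM, with DSSIM = 1 - SSIM.
    dssim = 1.0 - global_ssim(x, y, dynamic_range)
    return alpha * mae(x, y) + (1.0 - alpha) * dssim
```

For two identical images both terms vanish, so the loss is zero up to rounding.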
So, this is my code
def custom_loss(y_true, y_pred):
    M = 512  # M = total number of pixels in the sCT image
    sum = 0
    y_pred = tf.cast(y_pred, tf.int32)
    y_true = tf.cast(y_true, tf.int32)
    print(y_pred.shape)
    print(y_true.shape)
    y_pred = y_pred[0]
    y_true = y_true[0]
    for i in range(n):
        sum = sum+abs(y_true[i] - y_pred[i])
    my_mae = sum / n
    dssim = tf.reduce_mean((1 - tf.image.ssim(y_true,y_pred, max_val=512,
                            filter_size=11,filter_sigma=1.5, k1=0.01, k2=0.03)) / 2)
    my_mae = tf.cast(my_mae, tf.float32)
    return (0.75*my_mae) + (1 - 0.75*dssim)
I get an error when I run
model.compile(optimizer='adam',loss= custom_loss,metrics=['accuracy'])
The error is:
Traceback (most recent call last):
File "C:/Users/CRA01/Desktop/Unet/custom loss.py", line 85, in
history = model.fit(Training_CBCT_dataset,Training_pCT_dataset,validation_split=0.2,batch_size=1, epochs=5, callbacks=[model_save_callback])
File "C:\Users\CRA01\miniconda3\envs\tf_2.9\lib\site-packages\keras\utils\traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "C:\Users\CRA01\AppData\Local\Temp_autograph_generated_fileqb1dimxg.py", line 15, in tf__train_function
retval = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
ValueError: in user code:
File "C:\Users\CRA01\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\training.py", line 1051, in train_function *
return step_function(self, iterator)
File "C:\Users\CRA01\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\training.py", line 1040, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "C:\Users\CRA01\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\training.py", line 1030, in run_step **
outputs = model.train_step(data)
File "C:\Users\CRA01\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\training.py", line 893, in train_step
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
File "C:\Users\CRA01\miniconda3\envs\tf_2.9\lib\site-packages\keras\optimizers\optimizer_v2\optimizer_v2.py", line 539, in minimize
return self.apply_gradients(grads_and_vars, name=name)
File "C:\Users\CRA01\miniconda3\envs\tf_2.9\lib\site-packages\keras\optimizers\optimizer_v2\optimizer_v2.py", line 640, in apply_gradients
grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
File "C:\Users\CRA01\miniconda3\envs\tf_2.9\lib\site-packages\keras\optimizers\optimizer_v2\utils.py", line 73, in filter_empty_gradients
raise ValueError(f"No gradients provided for any variable: {variable}. "
ValueError: No gradients provided for any variable: (['conv2d/kernel:0', 'conv2d/bias:0', 'conv2d_1/kernel:0', 'conv2d_1/bias:0', 'conv2d_2/kernel:0', 'conv2d_2/bias:0', 'conv2d_3/kernel:0', 'conv2d_3/bias:0', 'conv2d_4/kernel:0', 'conv2d_4/bias:0', 'conv2d_5/kernel:0', 'conv2d_5/bias:0', 'conv2d_6/kernel:0', 'conv2d_6/bias:0', 'conv2d_7/kernel:0', 'conv2d_7/bias:0', 'conv2d_8/kernel:0', 'conv2d_8/bias:0', 'conv2d_9/kernel:0', 'conv2d_9/bias:0', 'conv2d_10/kernel:0', 'conv2d_10/bias:0', 'conv2d_11/kernel:0', 'conv2d_11/bias:0', 'conv2d_12/kernel:0', 'conv2d_12/bias:0', 'conv2d_13/kernel:0', 'conv2d_13/bias:0', 'conv2d_14/kernel:0', 'conv2d_14/bias:0', 'conv2d_15/kernel:0', 'conv2d_15/bias:0', 'conv2d_16/kernel:0', 'conv2d_16/bias:0', 'conv2d_17/kernel:0', 'conv2d_17/bias:0', 'conv2d_18/kernel:0', 'conv2d_18/bias:0', 'conv2d_19/kernel:0', 'conv2d_19/bias:0', 'conv2d_20/kernel:0', 'conv2d_20/bias:0', 'conv2d_21/kernel:0', 'conv2d_21/bias:0', 'conv2d_22/kernel:0', 'conv2d_22/bias:0', 'conv2d_23/kernel:0', 'conv2d_23/bias:0', 'conv2d_24/kernel:0', 'conv2d_24/bias:0', 'conv2d_25/kernel:0', 'conv2d_25/bias:0', 'conv2d_26/kernel:0', 'conv2d_26/bias:0'],). 
Provided `grads_and_vars` is ((None, <tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 32) dtype=float32>), (None, <tf.Variable 'conv2d/bias:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'conv2d_1/kernel:0' shape=(3, 3, 32, 32) dtype=float32>), (None, <tf.Variable 'conv2d_1/bias:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'conv2d_2/kernel:0' shape=(3, 3, 32, 64) dtype=float32>), (None, <tf.Variable 'conv2d_2/bias:0' shape=(64,) dtype=float32>), (None, <tf.Variable 'conv2d_3/kernel:0' shape=(3, 3, 64, 64) dtype=float32>), (None, <tf.Variable 'conv2d_3/bias:0' shape=(64,) dtype=float32>), (None, <tf.Variable 'conv2d_4/kernel:0' shape=(3, 3, 64, 128) dtype=float32>), (None, <tf.Variable 'conv2d_4/bias:0' shape=(128,) dtype=float32>), (None, <tf.Variable 'conv2d_5/kernel:0' shape=(3, 3, 128, 128) dtype=float32>), (None, <tf.Variable 'conv2d_5/bias:0' shape=(128,) dtype=float32>), (None, <tf.Variable 'conv2d_6/kernel:0' shape=(3, 3, 128, 256) dtype=float32>), (None, <tf.Variable 'conv2d_6/bias:0' shape=(256,) dtype=float32>), (None, <tf.Variable 'conv2d_7/kernel:0' shape=(3, 3, 256, 256) dtype=float32>), (None, <tf.Variable 'conv2d_7/bias:0' shape=(256,) dtype=float32>), (None, <tf.Variable 'conv2d_8/kernel:0' shape=(3, 3, 256, 512) dtype=float32>), (None, <tf.Variable 'conv2d_8/bias:0' shape=(512,) dtype=float32>), (None, <tf.Variable 'conv2d_9/kernel:0' shape=(3, 3, 512, 512) dtype=float32>), (None, <tf.Variable 'conv2d_9/bias:0' shape=(512,) dtype=float32>), (None, <tf.Variable 'conv2d_10/kernel:0' shape=(3, 3, 512, 1024) dtype=float32>), (None, <tf.Variable 'conv2d_10/bias:0' shape=(1024,) dtype=float32>), (None, <tf.Variable 'conv2d_11/kernel:0' shape=(3, 3, 1024, 1024) dtype=float32>), (None, <tf.Variable 'conv2d_11/bias:0' shape=(1024,) dtype=float32>), (None, <tf.Variable 'conv2d_12/kernel:0' shape=(3, 3, 1024, 2048) dtype=float32>), (None, <tf.Variable 'conv2d_12/bias:0' shape=(2048,) dtype=float32>), (None, <tf.Variable 'conv2d_13/kernel:0' 
shape=(3, 3, 2048, 2048) dtype=float32>), (None, <tf.Variable 'conv2d_13/bias:0' shape=(2048,) dtype=float32>), (None, <tf.Variable 'conv2d_14/kernel:0' shape=(3, 3, 3072, 1024) dtype=float32>), (None, <tf.Variable 'conv2d_14/bias:0' shape=(1024,) dtype=float32>), (None, <tf.Variable 'conv2d_15/kernel:0' shape=(3, 3, 1024, 1024) dtype=float32>), (None, <tf.Variable 'conv2d_15/bias:0' shape=(1024,) dtype=float32>), (None, <tf.Variable 'conv2d_16/kernel:0' shape=(3, 3, 1536, 512) dtype=float32>), (None, <tf.Variable 'conv2d_16/bias:0' shape=(512,) dtype=float32>), (None, <tf.Variable 'conv2d_17/kernel:0' shape=(3, 3, 512, 512) dtype=float32>), (None, <tf.Variable 'conv2d_17/bias:0' shape=(512,) dtype=float32>), (None, <tf.Variable 'conv2d_18/kernel:0' shape=(3, 3, 768, 256) dtype=float32>), (None, <tf.Variable 'conv2d_18/bias:0' shape=(256,) dtype=float32>), (None, <tf.Variable 'conv2d_19/kernel:0' shape=(3, 3, 256, 256) dtype=float32>), (None, <tf.Variable 'conv2d_19/bias:0' shape=(256,) dtype=float32>), (None, <tf.Variable 'conv2d_20/kernel:0' shape=(3, 3, 384, 128) dtype=float32>), (None, <tf.Variable 'conv2d_20/bias:0' shape=(128,) dtype=float32>), (None, <tf.Variable 'conv2d_21/kernel:0' shape=(3, 3, 128, 128) dtype=float32>), (None, <tf.Variable 'conv2d_21/bias:0' shape=(128,) dtype=float32>), (None, <tf.Variable 'conv2d_22/kernel:0' shape=(3, 3, 192, 64) dtype=float32>), (None, <tf.Variable 'conv2d_22/bias:0' shape=(64,) dtype=float32>), (None, <tf.Variable 'conv2d_23/kernel:0' shape=(3, 3, 64, 64) dtype=float32>), (None, <tf.Variable 'conv2d_23/bias:0' shape=(64,) dtype=float32>), (None, <tf.Variable 'conv2d_24/kernel:0' shape=(3, 3, 96, 32) dtype=float32>), (None, <tf.Variable 'conv2d_24/bias:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'conv2d_25/kernel:0' shape=(3, 3, 32, 32) dtype=float32>), (None, <tf.Variable 'conv2d_25/bias:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'conv2d_26/kernel:0' shape=(1, 1, 32, 1) dtype=float32>), (None, 
<tf.Variable 'conv2d_26/bias:0' shape=(1,) dtype=float32>)).
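A likely cause of "No gradients provided" here is the cast to tf.int32: integer casts are not differentiable, so no gradient can flow from the loss back to the weights. The following is a minimal sketch of a float-only version, assuming inputs shaped [batch, height, width, channels] and images at least 11 x 11 pixels; make_dssim_mae_loss and its max_val argument are hypothetical names, and max_val must be set to the real dynamic range of the data. The weighting follows the formula in the question, (1 - α) * DSSIM with DSSIM = 1 - SSIM.

```python
import tensorflow as tf


def make_dssim_mae_loss(max_val, alpha=0.75):
    """Build a Keras-compatible loss: alpha * MAE + (1 - alpha) * DSSIM."""

    def loss_fn(y_true, y_pred):
        # Stay in floating point so gradients can reach the model weights.
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        # Mean absolute error per image, shape [batch].
        mae = tf.reduce_mean(tf.abs(y_true - y_pred), axis=[1, 2, 3])
        # tf.image.ssim returns one SSIM value per image, shape [batch].
        ssim = tf.image.ssim(y_true, y_pred, max_val=max_val,
                             filter_size=11, filter_sigma=1.5,
                             k1=0.01, k2=0.03)
        dssim = 1.0 - ssim
        return alpha * mae + (1.0 - alpha) * dssim

    return loss_fn
```

It could then be passed to compile as, for example, loss=make_dssim_mae_loss(max_val=1.0) for images scaled to [0, 1].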

Implementation of a WGAN-GP in tensorflow

Using tensorflow, I'm trying to reimplement the following architecture (for now I'm focusing on the Generator part):
So far I have defined the generator in the following way:
N_Z = 128
generator = [
tf.keras.layers.Dense(units=6144, activation="relu"),
tf.keras.layers.Reshape(target_shape=(6, 4, 256)),
tf.keras.layers.Conv2DTranspose(
filters=128, kernel_size=(5,5), strides=(2, 2), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=128, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=64, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=64, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=32, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=32, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Conv2DTranspose(
filters=1, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
)
]
Generator = tf.keras.models.Sequential(generator)
But if I take some random noise and let the model process it, this is the final shape I get back:
noise = tf.random.normal((64,128))
result = Generator(noise)
result.shape
TensorShape([64, 28, 28, 1])
What am I doing wrong here? I also checked the original implementation for additional details, but I can't find anything that explains it.
This is easy to work out once you look at the input and output shape of each layer; the top layers need some help.
[ Sample ]:
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Initialize
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=( 6144, )),
tf.keras.layers.Dense( 48 * 128, activation="linear" ),
tf.keras.layers.BatchNormalization( momentum=0.99, epsilon=0.00001 ),
tf.keras.layers.Reshape(target_shape=( 6, 4, 256 )),
tf.keras.layers.Conv2DTranspose(
filters=128, kernel_size=(5,5), strides=(2, 2), padding="same", activation="relu"
),
tf.keras.layers.Resizing( 11, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(11, 8, 128)),
tf.keras.layers.Conv2DTranspose(
filters=128, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 22, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(22, 8, 128)),
tf.keras.layers.Conv2DTranspose(
filters=64, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 22, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(22, 8, 64)),
tf.keras.layers.Conv2DTranspose(
filters=64, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 43, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(43, 8, 64)),
tf.keras.layers.Conv2DTranspose(
filters=32, kernel_size=(3,3), strides=(1, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 43, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(43, 8, 32)),
tf.keras.layers.Conv2DTranspose(
filters=32, kernel_size=(3,3), strides=(2, 1), padding="SAME", activation="relu"
),
tf.keras.layers.Resizing( 85, 8, interpolation='bilinear', crop_to_aspect_ratio=False ),
tf.keras.layers.Reshape(target_shape=(85, 8, 32)),
])
model.summary()
[ Output ]:
Model: "sequential"
_________________________________________________________________
Layer (type)                               Output Shape           Param #
=================================================================
dense (Dense)                              (None, 6144)           37754880
batch_normalization (BatchNormalization)   (None, 6144)           24576
reshape (Reshape)                          (None, 6, 4, 256)      0
conv2d_transpose (Conv2DTranspose)         (None, 12, 8, 128)     819328
resizing (Resizing)                        (None, 11, 8, 128)     0
reshape_1 (Reshape)                        (None, 11, 8, 128)     0
conv2d_transpose_1 (Conv2DTranspose)       (None, 22, 8, 128)     147584
resizing_1 (Resizing)                      (None, 22, 8, 128)     0
reshape_2 (Reshape)                        (None, 22, 8, 128)     0
conv2d_transpose_2 (Conv2DTranspose)       (None, 22, 8, 64)      73792
resizing_2 (Resizing)                      (None, 22, 8, 64)      0
reshape_3 (Reshape)                        (None, 22, 8, 64)      0
conv2d_transpose_3 (Conv2DTranspose)       (None, 44, 8, 64)      36928
resizing_3 (Resizing)                      (None, 43, 8, 64)      0
reshape_4 (Reshape)                        (None, 43, 8, 64)      0
conv2d_transpose_4 (Conv2DTranspose)       (None, 43, 8, 32)      18464
resizing_4 (Resizing)                      (None, 43, 8, 32)      0
reshape_5 (Reshape)                        (None, 43, 8, 32)      0
conv2d_transpose_5 (Conv2DTranspose)       (None, 86, 8, 32)      9248
resizing_5 (Resizing)                      (None, 85, 8, 32)      0
reshape_6 (Reshape)                        (None, 85, 8, 32)      0
=================================================================
Total params: 38,884,800
Trainable params: 38,872,512
Non-trainable params: 12,288
_________________________________________________________________
2022-04-03 03:37:10.354570: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
(1, 85, 8, 32)
1/1 [==============================] - 2s 2s/step - loss: 0.0000e+00 - accuracy: 1.0000 - val_loss: 0.0000e+00 - val_accuracy: 1.0000
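The shape bookkeeping in the summary above can be reproduced without TensorFlow. The following is a minimal sketch, assuming the usual rule that a Conv2DTranspose layer with padding="same" multiplies each spatial dimension by its stride, and that each Resizing layer then forces the size it was given. The names conv_transpose_same, STAGES and trace_shapes are hypothetical helpers introduced only for this illustration.

```python
def conv_transpose_same(size, strides):
    """Spatial size after a transposed conv with padding="same": input * stride."""
    return (size[0] * strides[0], size[1] * strides[1])


# (strides of each Conv2DTranspose, target size of the Resizing layer after it),
# copied from the Sequential model in the sample above.
STAGES = [
    ((2, 2), (11, 8)),
    ((2, 1), (22, 8)),
    ((1, 1), (22, 8)),
    ((2, 1), (43, 8)),
    ((1, 1), (43, 8)),
    ((2, 1), (85, 8)),
]


def trace_shapes(start=(6, 4)):
    """Return [(size after transposed conv, size after resizing), ...]."""
    size, trace = start, []
    for strides, resize_to in STAGES:
        after_conv = conv_transpose_same(size, strides)
        size = resize_to  # Resizing sets the spatial size regardless of its input
        trace.append((after_conv, size))
    return trace


for after_conv, after_resize in trace_shapes():
    print(after_conv, "->", after_resize)
```

Tracing the sizes this way before building the model makes it easier to see which strides and resize targets are needed to reach a given output size.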

How to change a saved model input shape in Tensorflow?

I want to make this repo https://github.com/ildoonet/tf-pose-estimation run on an Intel Movidius, so I tried to convert the pb model using mvNCCompile.
The problem is that mvNCCompile requires a fixed input shape, but the model I have uses a dynamic one.
I tried this:
graph_path = 'models/graph/mobilenet_thin/graph_opt.pb'
with tf.gfile.GFile(graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
graph = tf.get_default_graph()
tf.import_graph_def(graph_def, name='TfPoseEstimator')
x = graph.get_tensor_by_name('TfPoseEstimator/image:0')
x.set_shape([1, 368, 368, 3])
x = graph.get_tensor_by_name('TfPoseEstimator/MobilenetV1/Conv2d_0/Conv2D:0')
x.set_shape([1, 368, 368, 24])
and got this
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/weights:0' shape=(3, 3, 3, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/image:0' shape=(1, 368, 368, 3) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/Conv2D:0' shape=(1, 368, 368, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/Conv2D_bn_offset:0' shape=(24,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_0/Relu:0' shape=(?, ?, ?, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_depthwise/depthwise_weights:0' shape=(3, 3, 24, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/weights:0' shape=(1, 1, 24, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_depthwise/depthwise:0' shape=(?, ?, ?, 24) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/Conv2D:0' shape=(?, ?, ?, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/Conv2D_bn_offset:0' shape=(48,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_1_pointwise/Relu:0' shape=(?, ?, ?, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_depthwise/depthwise_weights:0' shape=(3, 3, 48, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/weights:0' shape=(1, 1, 48, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_depthwise/depthwise:0' shape=(?, ?, ?, 48) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/Conv2D:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/Conv2D_bn_offset:0' shape=(96,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_2_pointwise/Relu:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_depthwise/depthwise_weights:0' shape=(3, 3, 96, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/weights:0' shape=(1, 1, 96, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_depthwise/depthwise:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/Conv2D:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/Conv2D_bn_offset:0' shape=(96,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_3_pointwise/Relu:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_depthwise/depthwise_weights:0' shape=(3, 3, 96, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/weights:0' shape=(1, 1, 96, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_depthwise/depthwise:0' shape=(?, ?, ?, 96) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/Conv2D:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/Conv2D_bn_offset:0' shape=(192,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_4_pointwise/Relu:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_depthwise/depthwise_weights:0' shape=(3, 3, 192, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/weights:0' shape=(1, 1, 192, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_depthwise/depthwise:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/Conv2D:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/Conv2D_bn_offset:0' shape=(192,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_5_pointwise/Relu:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_depthwise/depthwise_weights:0' shape=(3, 3, 192, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/weights:0' shape=(1, 1, 192, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_depthwise/depthwise:0' shape=(?, ?, ?, 192) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_6_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_7_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_8_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_9_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_10_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_depthwise/depthwise_weights:0' shape=(3, 3, 384, 1) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/weights:0' shape=(1, 1, 384, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_depthwise/depthwise:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/Conv2D:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/Conv2D_bn_offset:0' shape=(384,) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/BatchNorm/FusedBatchNorm:0' shape=(?, ?, ?, 384) dtype=float32>,)
(<tf.Tensor 'TfPoseEstimator/MobilenetV1/Conv2d_11_pointwise/Relu:0' shape=(?, ?, ?, 384) dtype=float32>,)
The layers other than TfPoseEstimator/image:0 and TfPoseEstimator/MobilenetV1/Conv2d_0/Conv2D:0 still have ? in their shapes.
I'm very new to Tensorflow, so this might be a stupid question, but how do I change the input shape of a saved model?
I managed to solve this problem with the following:
import tensorflow as tf

if __name__ == '__main__':
    graph_path = 't/tf_model.pb'
    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    graph = tf.get_default_graph()
    tf_new_image = tf.placeholder(shape=(1, 368, 368, 3), dtype='float32', name='new_image')
    tf.import_graph_def(graph_def, name='TfPoseEstimator', input_map={"image:0": tf_new_image})
    tf.train.write_graph(graph, "t", "mobilenet_thin_model.pb", as_text=False)
With TF 2.x, I believe you can set the shape on the concrete function instead:
imported = tf.saved_model.load('/path/to/saved_model')
concrete_func = imported.signatures["serving_default"]
concrete_func.inputs[0].set_shape([1, 368, 368, 3])
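The input_map approach from the TF 1.x answer can be tried on a toy graph to see the fixed shape reach downstream tensors. This is only a sketch under assumptions: it uses tf.compat.v1 so that it also runs on TF 2.x, and the single relu op is a stand-in for the real frozen model. The tensor names "image", "out" and "new_image" here are illustrative, not taken from the tf-pose-estimation graph beyond the input name used in the answer.

```python
import tensorflow as tf

tf1 = tf.compat.v1

# Toy stand-in for the frozen model: one placeholder with unknown
# batch/height/width feeding a single op.
src_graph = tf.Graph()
with src_graph.as_default():
    image = tf1.placeholder(tf.float32, shape=(None, None, None, 3), name="image")
    tf.nn.relu(image, name="out")
graph_def = src_graph.as_graph_def()

# Re-import the GraphDef, mapping the dynamic input onto a fixed-shape placeholder.
dst_graph = tf.Graph()
with dst_graph.as_default():
    new_image = tf1.placeholder(tf.float32, shape=(1, 368, 368, 3), name="new_image")
    (out,) = tf1.import_graph_def(
        graph_def,
        name="TfPoseEstimator",
        input_map={"image:0": new_image},
        return_elements=["out:0"],
    )

# Static shape of the imported op's output, inferred from the mapped input.
fixed_shape = out.shape.as_list()
print(fixed_shape)
```

Because shapes are inferred as the nodes are imported, the downstream tensor picks up the static shape of the mapped placeholder, which is what calling set_shape on individual tensors after import did not achieve in the question.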

Why does fixing the network input size reduce the model's file size

I have a model whose inputs are variable in size during training, for generalisation.
In order to quantise, I have to fix the input size, so I just recreate the model with fixed input sizes, and copy across all the weights and biases, then save the model.
For some reason, though, the saved model shrinks to roughly a third of its original size (4.6 MB to 1.6 MB).
Note this is before quantisation or anything else, and parameters remain the same.
Two model summaries are below:
Model 1 = 4.6MB
old_model.summary(line_length=110)
______________________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==============================================================================================================
input_1 (InputLayer) (None, None, None, 4) 0
______________________________________________________________________________________________________________
gaussian_noise_1 (GaussianNoise) (None, None, None, 4) 0 input_1[0][0]
______________________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, None, None, 32) 1184 gaussian_noise_1[0][0]
______________________________________________________________________________________________________________
batch_normalization_1 (BatchNormali (None, None, None, 32) 128 conv2d_1[0][0]
______________________________________________________________________________________________________________
gaussian_noise_2 (GaussianNoise) (None, None, None, 32) 0 batch_normalization_1[0][0]
______________________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, None, None, 32) 9248 gaussian_noise_2[0][0]
______________________________________________________________________________________________________________
batch_normalization_2 (BatchNormali (None, None, None, 32) 128 conv2d_2[0][0]
______________________________________________________________________________________________________________
gaussian_noise_3 (GaussianNoise) (None, None, None, 32) 0 batch_normalization_2[0][0]
______________________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, None, None, 64) 18496 gaussian_noise_3[0][0]
______________________________________________________________________________________________________________
batch_normalization_3 (BatchNormali (None, None, None, 64) 256 conv2d_3[0][0]
______________________________________________________________________________________________________________
gaussian_noise_4 (GaussianNoise) (None, None, None, 64) 0 batch_normalization_3[0][0]
______________________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, None, None, 64) 36928 gaussian_noise_4[0][0]
______________________________________________________________________________________________________________
batch_normalization_4 (BatchNormali (None, None, None, 64) 256 conv2d_4[0][0]
______________________________________________________________________________________________________________
gaussian_noise_5 (GaussianNoise) (None, None, None, 64) 0 batch_normalization_4[0][0]
______________________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, None, None, 64) 0 gaussian_noise_5[0][0]
______________________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, None, None, 96) 0 up_sampling2d_1[0][0]
batch_normalization_1[0][0]
______________________________________________________________________________________________________________
gaussian_noise_6 (GaussianNoise) (None, None, None, 96) 0 concatenate_1[0][0]
______________________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, None, None, 64) 55360 gaussian_noise_6[0][0]
______________________________________________________________________________________________________________
batch_normalization_5 (BatchNormali (None, None, None, 64) 256 conv2d_5[0][0]
______________________________________________________________________________________________________________
gaussian_noise_7 (GaussianNoise) (None, None, None, 64) 0 batch_normalization_5[0][0]
______________________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, None, None, 64) 36928 gaussian_noise_7[0][0]
______________________________________________________________________________________________________________
batch_normalization_6 (BatchNormali (None, None, None, 64) 256 conv2d_6[0][0]
______________________________________________________________________________________________________________
gaussian_noise_8 (GaussianNoise) (None, None, None, 64) 0 batch_normalization_6[0][0]
______________________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, None, None, 64) 36928 gaussian_noise_8[0][0]
______________________________________________________________________________________________________________
batch_normalization_7 (BatchNormali (None, None, None, 64) 256 conv2d_7[0][0]
______________________________________________________________________________________________________________
gaussian_noise_9 (GaussianNoise) (None, None, None, 64) 0 batch_normalization_7[0][0]
______________________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, None, None, 64) 36928 gaussian_noise_9[0][0]
______________________________________________________________________________________________________________
batch_normalization_8 (BatchNormali (None, None, None, 64) 256 conv2d_8[0][0]
______________________________________________________________________________________________________________
gaussian_noise_10 (GaussianNoise) (None, None, None, 64) 0 batch_normalization_8[0][0]
______________________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, None, None, 64) 36928 gaussian_noise_10[0][0]
______________________________________________________________________________________________________________
batch_normalization_9 (BatchNormali (None, None, None, 64) 256 conv2d_9[0][0]
______________________________________________________________________________________________________________
gaussian_noise_11 (GaussianNoise) (None, None, None, 64) 0 batch_normalization_9[0][0]
______________________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, None, None, 64) 0 gaussian_noise_11[0][0]
______________________________________________________________________________________________________________
input_2 (InputLayer) (None, None, None, 3) 0
______________________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, None, None, 67) 0 up_sampling2d_2[0][0]
input_2[0][0]
______________________________________________________________________________________________________________
gaussian_noise_12 (GaussianNoise) (None, None, None, 67) 0 concatenate_2[0][0]
______________________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, None, None, 67) 40468 gaussian_noise_12[0][0]
______________________________________________________________________________________________________________
batch_normalization_10 (BatchNormal (None, None, None, 67) 268 conv2d_10[0][0]
______________________________________________________________________________________________________________
gaussian_noise_13 (GaussianNoise) (None, None, None, 67) 0 batch_normalization_10[0][0]
______________________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, None, None, 67) 40468 gaussian_noise_13[0][0]
______________________________________________________________________________________________________________
batch_normalization_11 (BatchNormal (None, None, None, 67) 268 conv2d_11[0][0]
______________________________________________________________________________________________________________
gaussian_noise_14 (GaussianNoise) (None, None, None, 67) 0 batch_normalization_11[0][0]
______________________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, None, None, 32) 19328 gaussian_noise_14[0][0]
______________________________________________________________________________________________________________
batch_normalization_12 (BatchNormal (None, None, None, 32) 128 conv2d_12[0][0]
______________________________________________________________________________________________________________
gaussian_noise_15 (GaussianNoise) (None, None, None, 32) 0 batch_normalization_12[0][0]
______________________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, None, None, 3) 867 gaussian_noise_15[0][0]
==============================================================================================================
Total params: 372,771
Trainable params: 371,415
Non-trainable params: 1,356
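As a sanity check on the 4.6 MB figure, the size of the weights alone can be estimated from the parameter total above. This is a rough sketch, assuming the weights are stored as float32 (4 bytes each; the summary does not state the dtype) and ignoring any file-format overhead.

```python
TOTAL_PARAMS = 372_771  # "Total params" from the summary above
BYTES_PER_PARAM = 4     # assumption: weights are stored as float32

weights_bytes = TOTAL_PARAMS * BYTES_PER_PARAM
weights_mb = weights_bytes / 1e6

print(f"{weights_bytes} bytes = {weights_mb:.2f} MB")
```

Under those assumptions the weights account for roughly 1.49 MB, so a 4.6 MB file would have to hold something besides a single float32 copy of the weights.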
Model 2 = 1.6MB
model.summary(line_length=110)
______________________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==============================================================================================================
input_1 (InputLayer) (None, 368, 256, 4) 0
______________________________________________________________________________________________________________
gaussian_noise_1 (GaussianNoise) (None, 368, 256, 4) 0 input_1[0][0]
______________________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 368, 256, 32) 1184 gaussian_noise_1[0][0]
______________________________________________________________________________________________________________
batch_normalization_1 (BatchNormali (None, 368, 256, 32) 128 conv2d_1[0][0]
______________________________________________________________________________________________________________
gaussian_noise_2 (GaussianNoise) (None, 368, 256, 32) 0 batch_normalization_1[0][0]
______________________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 184, 128, 32) 9248 gaussian_noise_2[0][0]
______________________________________________________________________________________________________________
batch_normalization_2 (BatchNormali (None, 184, 128, 32) 128 conv2d_2[0][0]
______________________________________________________________________________________________________________
gaussian_noise_3 (GaussianNoise) (None, 184, 128, 32) 0 batch_normalization_2[0][0]
______________________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 184, 128, 64) 18496 gaussian_noise_3[0][0]
______________________________________________________________________________________________________________
batch_normalization_3 (BatchNormali (None, 184, 128, 64) 256 conv2d_3[0][0]
______________________________________________________________________________________________________________
gaussian_noise_4 (GaussianNoise) (None, 184, 128, 64) 0 batch_normalization_3[0][0]
______________________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 184, 128, 64) 36928 gaussian_noise_4[0][0]
______________________________________________________________________________________________________________
batch_normalization_4 (BatchNormali (None, 184, 128, 64) 256 conv2d_4[0][0]
______________________________________________________________________________________________________________
gaussian_noise_5 (GaussianNoise) (None, 184, 128, 64) 0 batch_normalization_4[0][0]
______________________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 368, 256, 64) 0 gaussian_noise_5[0][0]
______________________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 368, 256, 96) 0 up_sampling2d_1[0][0]
batch_normalization_1[0][0]
______________________________________________________________________________________________________________
gaussian_noise_6 (GaussianNoise) (None, 368, 256, 96) 0 concatenate_1[0][0]
______________________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 368, 256, 64) 55360 gaussian_noise_6[0][0]
______________________________________________________________________________________________________________
batch_normalization_5 (BatchNormali (None, 368, 256, 64) 256 conv2d_5[0][0]
______________________________________________________________________________________________________________
gaussian_noise_7 (GaussianNoise) (None, 368, 256, 64) 0 batch_normalization_5[0][0]
______________________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 368, 256, 64) 36928 gaussian_noise_7[0][0]
______________________________________________________________________________________________________________
batch_normalization_6 (BatchNormali (None, 368, 256, 64) 256 conv2d_6[0][0]
______________________________________________________________________________________________________________
gaussian_noise_8 (GaussianNoise) (None, 368, 256, 64) 0 batch_normalization_6[0][0]
______________________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 368, 256, 64) 36928 gaussian_noise_8[0][0]
______________________________________________________________________________________________________________
batch_normalization_7 (BatchNormali (None, 368, 256, 64) 256 conv2d_7[0][0]
______________________________________________________________________________________________________________
gaussian_noise_9 (GaussianNoise) (None, 368, 256, 64) 0 batch_normalization_7[0][0]
______________________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 368, 256, 64) 36928 gaussian_noise_9[0][0]
______________________________________________________________________________________________________________
batch_normalization_8 (BatchNormali (None, 368, 256, 64) 256 conv2d_8[0][0]
______________________________________________________________________________________________________________
gaussian_noise_10 (GaussianNoise) (None, 368, 256, 64) 0 batch_normalization_8[0][0]
______________________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 368, 256, 64) 36928 gaussian_noise_10[0][0]
______________________________________________________________________________________________________________
batch_normalization_9 (BatchNormali (None, 368, 256, 64) 256 conv2d_9[0][0]
______________________________________________________________________________________________________________
gaussian_noise_11 (GaussianNoise) (None, 368, 256, 64) 0 batch_normalization_9[0][0]
______________________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 736, 512, 64) 0 gaussian_noise_11[0][0]
______________________________________________________________________________________________________________
input_2 (InputLayer) (None, 736, 512, 3) 0
______________________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 736, 512, 67) 0 up_sampling2d_2[0][0]
input_2[0][0]
______________________________________________________________________________________________________________
gaussian_noise_12 (GaussianNoise) (None, 736, 512, 67) 0 concatenate_2[0][0]
______________________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 736, 512, 67) 40468 gaussian_noise_12[0][0]
______________________________________________________________________________________________________________
batch_normalization_10 (BatchNormal (None, 736, 512, 67) 268 conv2d_10[0][0]
______________________________________________________________________________________________________________
gaussian_noise_13 (GaussianNoise) (None, 736, 512, 67) 0 batch_normalization_10[0][0]
______________________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 736, 512, 67) 40468 gaussian_noise_13[0][0]
______________________________________________________________________________________________________________
batch_normalization_11 (BatchNormal (None, 736, 512, 67) 268 conv2d_11[0][0]
______________________________________________________________________________________________________________
gaussian_noise_14 (GaussianNoise) (None, 736, 512, 67) 0 batch_normalization_11[0][0]
______________________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 736, 512, 32) 19328 gaussian_noise_14[0][0]
______________________________________________________________________________________________________________
batch_normalization_12 (BatchNormal (None, 736, 512, 32) 128 conv2d_12[0][0]
______________________________________________________________________________________________________________
gaussian_noise_15 (GaussianNoise) (None, 736, 512, 32) 0 batch_normalization_12[0][0]
______________________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 736, 512, 3) 867 gaussian_noise_15[0][0]
==============================================================================================================
Total params: 372,771
Trainable params: 371,415
Non-trainable params: 1,356
___________________________
#FCOS I think the difference arises because one of your models was trained and the other was not.
When you save a trained model, the file contains
the model architecture,
the weights and biases, and
the optimizer's configuration.
However, when you save a model that was never trained, the file will not contain the optimizer's configuration.
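As a rough back-of-the-envelope sketch of why optimizer data can change the file size, the snippet below estimates the bytes needed for the weights alone and for the weights plus optimizer slot variables. It uses the parameter counts from the summary above (372,771 total, 371,415 trainable). The float32 storage and the Adam-style optimizer with two slot values per trainable weight are assumptions; the post does not name the optimizer, and real files also carry some format overhead.

```python
# Rough size estimate; assumes float32 storage (4 bytes per value).
BYTES_PER_FLOAT32 = 4


def estimate_bytes(total_params, trainable_params, slots_per_weight=0):
    """Approximate bytes for the weights plus optimizer slot variables.

    slots_per_weight is hypothetical: 0 for a model saved without
    optimizer state, 2 for an Adam-style optimizer (m and v per weight).
    """
    weight_bytes = total_params * BYTES_PER_FLOAT32
    optimizer_bytes = trainable_params * slots_per_weight * BYTES_PER_FLOAT32
    return weight_bytes + optimizer_bytes


# Parameter counts taken from the model summary above.
total_params = 372771
trainable_params = 371415

untrained = estimate_bytes(total_params, trainable_params, slots_per_weight=0)
trained_adam = estimate_bytes(total_params, trainable_params, slots_per_weight=2)

print("weights only:", untrained, "bytes")
print("weights + Adam-style slots:", trained_adam, "bytes")
```

Under these assumptions the file with optimizer state would be roughly three times larger than the one without it, which is the kind of gap you would expect between a trained and an untrained save.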
To test the size difference, I created a simple model with and without a fixed input size and found that both models have exactly the same size, because the number of parameters is the same in both. Please check model1 and model2 below.
Here is model1
import tensorflow as tf
model1 = tf.keras.models.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(None, None, 784,)),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model1.save('mymodel1.h5',overwrite=True,include_optimizer=True)
model1.summary()
Model: "sequential_10"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_20 (Dense) (None, None, None, 128) 100480
_________________________________________________________________
dense_21 (Dense) (None, None, None, 256) 33024
_________________________________________________________________
dense_22 (Dense) (None, None, None, 512) 131584
_________________________________________________________________
dense_23 (Dense) (None, None, None, 256) 131328
_________________________________________________________________
dense_24 (Dense) (None, None, None, 128) 32896
_________________________________________________________________
dense_25 (Dense) (None, None, None, 64) 8256
_________________________________________________________________
dense_26 (Dense) (None, None, None, 10) 650
=================================================================
Total params: 438,218
Trainable params: 438,218
Non-trainable params: 0
Here is model2
model2 = tf.keras.models.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(300, 300, 784,)),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model2.save('mymodel2.h5',overwrite=True,include_optimizer=True)
model2.summary()
Model: "sequential_11"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_27 (Dense) (None, 300, 300, 128) 100480
_________________________________________________________________
dense_28 (Dense) (None, 300, 300, 256) 33024
_________________________________________________________________
dense_29 (Dense) (None, 300, 300, 512) 131584
_________________________________________________________________
dense_30 (Dense) (None, 300, 300, 256) 131328
_________________________________________________________________
dense_31 (Dense) (None, 300, 300, 128) 32896
_________________________________________________________________
dense_32 (Dense) (None, 300, 300, 64) 8256
_________________________________________________________________
dense_33 (Dense) (None, 300, 300, 10) 650
=================================================================
Total params: 438,218
Trainable params: 438,218
Non-trainable params: 0
_________________________________________________________________
The size of model1 and model2 is the same (1.7 MB).
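The parameter counts in both summaries can be reproduced by hand, which shows why the spatial input dimensions do not matter: a Dense layer only acts on the last axis, so its parameter count is inputs × units + units. The helper names below are illustrative, and the size estimate assumes float32 weights with no file overhead.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one Dense layer."""
    return n_in * n_out + n_out


def count_params(input_shape, units):
    """Total parameters of a stack of Dense layers.

    Only the last axis of input_shape is used; the other axes
    (None, 300, ...) have no effect on the parameter count.
    """
    n_in = input_shape[-1]
    total = 0
    for n_out in units:
        total += dense_params(n_in, n_out)
        n_in = n_out
    return total


layer_units = [128, 256, 512, 256, 128, 64, 10]

params_model1 = count_params((None, None, 784), layer_units)
params_model2 = count_params((300, 300, 784), layer_units)

# Assumes 4 bytes per parameter (float32), ignoring file format overhead.
approx_bytes = params_model1 * 4

print(params_model1, params_model2, approx_bytes)
```

Both counts come out to 438,218, and 438,218 × 4 bytes is about 1.75 MB, consistent with the 1.7 MB file size reported above.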
Please let us know if you have any comments. Thanks!

Tensorflow MultiRNNCell save and restore

It seems that either saving or restoring a model with MultiRNNCell is not working properly.
I was working on a classification problem using the code below:
stacked_rnn_cell = list()
for i in range(config.num_layers):
    rnn_cell = tf.contrib.rnn.LSTMCell(num_units=config.dim_hidden)
    stacked_rnn_cell.append(rnn_cell)
rnn_cell = tf.contrib.rnn.MultiRNNCell(cells=stacked_rnn_cell, state_is_tuple=True)
I used the stacked rnn_cell to train and then saved the model.
In my case, I trained two models: one with num_layers = 2, the other with num_layers = 3.
To restore, I first run the code above and then run the restore procedure to load the saved weight values into those variables.
It seems that only the first layer of rnn_cell is loaded, since loading with num_layers = 1 gives exactly the same result as the models with num_layers = 2 or num_layers = 3.
The model itself loads fine, so I can only conclude that the layers are not being saved or loaded properly.
=====
Edited: I ran it without loading any matching model and used the code below to inspect the variables:
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
Then the output is
[<tf.Variable 'global_step:0' shape=() dtype=int32_ref>,
<tf.Variable 'embedding_layer/w:0' shape=(11441, 200) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w1:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w2:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w3:0' shape=(400, 1) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/b2:0' shape=(400,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/concat_w:0' shape=(800, 400) dtype=float32_ref>,
<tf.Variable 'logits/w:0' shape=(800, 7) dtype=float32_ref>,
<tf.Variable 'logits/b:0' shape=(7,) dtype=float32_ref>,
<tf.Variable 'train_optimizer/beta1_power:0' shape=() dtype=float32_ref>,
<tf.Variable 'train_optimizer/beta2_power:0' shape=() dtype=float32_ref>,
<tf.Variable 'embedding_layer/w/Adam:0' shape=(11441, 200) dtype=float32_ref>,
<tf.Variable 'embedding_layer/w/Adam_1:0' shape=(11441, 200) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w1/Adam:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w1/Adam_1:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w2/Adam:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w2/Adam_1:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w3/Adam:0' shape=(400, 1) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w3/Adam_1:0' shape=(400, 1) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/b2/Adam:0' shape=(400,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/b2/Adam_1:0' shape=(400,) dtype=float32_ref>,
<tf.Variable 'logits/w/Adam:0' shape=(800, 7) dtype=float32_ref>,
<tf.Variable 'logits/w/Adam_1:0' shape=(800, 7) dtype=float32_ref>,
<tf.Variable 'logits/b/Adam:0' shape=(7,) dtype=float32_ref>,
<tf.Variable 'logits/b/Adam_1:0' shape=(7,) dtype=float32_ref>]
which means that it is stored correctly as intended, with three hidden layers of RNN cells. However, it seems that the weights are not matched to the model automatically when there are multiple hidden layers.
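The check done by eye on the listing above can be sketched in plain Python: group the variable names by direction and cell index to count the layers, and compare the kernel shape against what an LSTM cell is expected to create. The helper names are illustrative. The input depth of 200 is read from the embedding shape (11441, 200), and num_units = 200 is inferred from the 800-wide kernel (4 × num_units); both are inferences, not values stated in the post.

```python
import re
from collections import defaultdict

# A subset of the variable names from the listing above.
variable_names = [
    "embedding_layer/w:0",
    "dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel:0",
    "dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias:0",
    "dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel:0",
    "dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias:0",
    "dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel:0",
    "dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias:0",
    "dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel:0",
    "dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias:0",
    "dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel:0",
    "dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias:0",
    "dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel:0",
    "dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias:0",
    "logits/w:0",
]

CELL_PATTERN = re.compile(r"/(fw|bw)/multi_rnn_cell/cell_(\d+)/")


def layers_per_direction(names):
    """Count distinct cell indices for each direction (fw / bw)."""
    found = defaultdict(set)
    for name in names:
        match = CELL_PATTERN.search(name)
        if match:
            found[match.group(1)].add(int(match.group(2)))
    return {direction: len(indices) for direction, indices in found.items()}


def lstm_kernel_shape(input_depth, num_units):
    """Kernel shape of a basic LSTM cell: inputs and state concatenated,
    projected onto the four gates."""
    return (input_depth + num_units, 4 * num_units)


layer_counts = layers_per_direction(variable_names)
print(layer_counts)
print(lstm_kernel_shape(200, 200))
```

Because every layer takes a 200-wide input and has 200 units under these inferred values, all six kernels share the shape (400, 800), so a shape check alone cannot tell the layers apart. The same grouping could be applied to the names stored in a checkpoint to compare them with the graph.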
=====
I searched for examples of saving and restoring a deep RNN model but could not find any, so I am asking here for help.
Has anyone had the same problem and found a solution?
You can change "tf.contrib.rnn.MultiRNNCell" to "tf.nn.rnn_cell.MultiRNNCell"