AttributeError: 'int' object has no attribute 'value' - tensorflow

I can't wrap my head around this problem. I am running TensorFlow 2 and I really don't see why this error appears. Is there something I am missing?
This is the relevant part of the code where the error appears:
from tensorflow.lite.experimental.examples.lstm.rnn import bidirectional_dynamic_rnn
from tensorflow.lite.experimental.examples.lstm.rnn_cell import TFLiteLSTMCell
...
lstm_cells = []
lstm_0 = TFLiteLSTMCell(num_units=256, forget_bias=0, name='rnn_0')
lstm_1 = TFLiteLSTMCell(num_units=256, forget_bias=0, name='rnn_1')
lstm_2 = TFLiteLSTMCell(num_units=128, forget_bias=0, name='rnn_2')
lstm_3 = TFLiteLSTMCell(num_units=128, forget_bias=0, name='rnn_3')
lstm_cells.append(lstm_0)
lstm_cells.append(lstm_1)
lstm_cells.append(lstm_2)
lstm_cells.append(lstm_3)
bi_LSTM_2 = layers.Lambda(buildLstmLayer, arguments={'layers' : lstm_cells})(fc_1)
...
This is the corresponding Lambda layer. I am creating the bidirectional RNNs here; I suspect the error is really about the TFLiteLSTMCell itself, although I think I am using it correctly.
def buildLstmLayer(inputs, layers):
    inputs = tf.transpose(inputs, [1, 0, 2])
    # inputs = tf.unstack(inputs, axis=1)
    inter_output, _ = bidirectional_dynamic_rnn(
        layers[0],
        layers[1],
        inputs,
        dtype='float32',
        time_major=True)
    inter_output = tf.concat(inter_output, 2)
    output, _ = bidirectional_dynamic_rnn(
        layers[2],
        layers[3],
        inter_output,
        dtype='float32',
        time_major=True)
    output = tf.concat(output, 2)
    # output = tf.stack(output, axis=1)
    output = tf.transpose(output, [1, 0, 2])
    return output
This is the traceback I am getting:
Traceback (most recent call last):
File "crnn_architecture.py", line 279, in <module>
model, base_model = CRNN_model(is_training=True)
File "crnn_architecture.py", line 108, in CRNN_model
bi_LSTM_2 = layers.Lambda(buildLstmLayer, arguments={'layers' : lstm_cells})(fc_1)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/base_layer.py", line 847, in __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/layers/core.py", line 795, in call
return self.function(inputs, **arguments)
File "crnn_architecture.py", line 146, in buildLstmLayer
time_major=True)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/lite/experimental/examples/lstm/rnn.py", line 379, in bidirectional_dynamic_rnn
scope=fw_scope)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/lite/experimental/examples/lstm/rnn.py", line 266, in dynamic_rnn
dtype=dtype)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/rnn.py", line 916, in _dynamic_rnn_loop
swap_memory=swap_memory)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/control_flow_ops.py", line 2675, in while_loop
back_prop=back_prop)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/while_v2.py", line 198, in while_loop
add_control_dependencies=add_control_dependencies)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py", line 915, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/while_v2.py", line 176, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/rnn.py", line 884, in _time_step
(output, new_state) = call_cell()
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/rnn.py", line 870, in <lambda>
call_cell = lambda: cell(input_t, state)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/base_layer.py", line 847, in __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/layers/recurrent.py", line 137, in call
inputs, states = cell.call(inputs, states, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/lite/experimental/examples/lstm/rnn_cell.py", line 440, in call
if input_size.value is None:
AttributeError: 'int' object has no attribute 'value'
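For context (this note is not part of the original post): the failing check input_size.value comes from TF1-era example code that expects tf.compat.v1.Dimension objects, while under TF2 an entry of a TensorShape is a plain Python int, so .value no longer exists on it. A minimal illustration of that difference, assuming TF2 shape semantics:

import tensorflow as tf

x = tf.zeros([4, 256])

# In TF2, indexing a TensorShape yields a plain int (or None), not a Dimension object.
dim = x.shape[-1]
print(type(dim))                           # <class 'int'>

# The TF1-style object that the TFLite example code expects still exists under compat:
print(tf.compat.v1.Dimension(256).value)   # 256

# dim.value  # would raise: AttributeError: 'int' object has no attribute 'value'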

Related

Reimplementing bert-style pooler throws shape error as if length-dimension were still needed

I have trained an off-the-shelf Transformer().
Now I want to use the encoder in order to build a classifier. For that I want to use only the first token's output (BERT-style CLS-token result) and run it through a dense layer.
What I do:
tl.Serial(encoder, tl.Fn('pooler', lambda x: (x[:, 0, :])), tl.Dense(7))
Shapes:
The encoder gives me shape (64, 50, 512)
with
64 = batch_size,
50 = seq_len,
512 = model_dim
The pooler gives me shape (64, 512) which is as expected and desired.
The dense layer is supposed to take the 512 dimensions for each batch member and classify over 7 classes. But I guess trax/jax still expects this to have length seq_len (50).
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].
What am I missing?
Full traceback:
Traceback (most recent call last):
File "mikado_classes.py", line 2054, in <module>
app.run(main)
File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "mikado_classes.py", line 1153, in main
loop_neu.run(2)
File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 361, in run
loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 483, in _run_one_step
batch, rng, step=step, learning_rate=learning_rate
File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 134, in one_step
(weights, self._slots), step, self._opt_params, batch, state, rng)
File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 173, in single_device_update_fn
batch, weights, state, rng)
File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", line 549, in pure_fn
self._caller, signature(x), trace) from None
jax._src.traceback_util.FilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/supervised/training.py, line 865
layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/mikado_classes.py, line 1134
layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Dense_7 (in pure_fn):
layer created in file [...]/mikado_classes.py, line 1133
layer input shapes: ShapeDtype{shape:(64, 512), dtype:float32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 95, in forward
return jnp.dot(x, w) + b # Affine map.
File [...]/_src/numpy/lax_numpy.py, line 3498, in dot
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
File [...]/_src/lax/lax.py, line 674, in dot_general
preferred_element_type=preferred_element_type)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/ad.py, line 285, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File [...]/jax/interpreters/ad.py, line 458, in standard_jvp
val_out = primitive.bind(*primals, **params)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 140, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 147, in default_process_primitive
return primitive.bind(*consts, **params)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 1058, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
File [...]/_src/lax/lax.py, line 3090, in _dot_general_shape_rule
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "mikado_classes.py", line 2054, in <module>
app.run(main)
File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 300, in run
_run_main(main, args)
File "/root/.local/lib/python3.7/site-packages/absl/app.py", line 251, in _run_main
sys.exit(main(argv))
File "mikado_classes.py", line 1153, in main
loop_neu.run(2)
File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 361, in run
loss, optimizer_metrics = self._run_one_step(task_index, task_changed)
File "/root/.local/lib/python3.7/site-packages/trax/supervised/training.py", line 483, in _run_one_step
batch, rng, step=step, learning_rate=learning_rate
File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 134, in one_step
(weights, self._slots), step, self._opt_params, batch, state, rng)
File "/root/.local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 398, in f_jitted
return cpp_jitted_f(context, *args, **kwargs)
File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 295, in cache_miss
donated_invars=donated_invars)
File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1275, in bind
return call_bind(self, fun, *args, **params)
File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1266, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 1278, in process
return trace.process_call(self, fun, tracers, params)
File "/root/.local/lib/python3.7/site-packages/jax/core.py", line 631, in process_call
return primitive.impl(f, *tracers, **params)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 581, in _xla_call_impl
*unsafe_map(arg_spec, args))
File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 260, in memoized_fun
ans = call(fun, *args)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/xla.py", line 656, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1216, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1196, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.7/site-packages/trax/optimizers/trainer.py", line 173, in single_device_update_fn
batch, weights, state, rng)
File "/root/.local/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 810, in value_and_grad_f
ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
File "/root/.local/lib/python3.7/site-packages/jax/api.py", line 1918, in _vjp
out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 116, in vjp
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/ad.py", line 101, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/root/.local/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 506, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/root/.local/lib/python3.7/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/.local/lib/python3.7/site-packages/trax/layers/base.py", line 549, in pure_fn
self._caller, signature(x), trace) from None
trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/trax/supervised/training.py, line 865
layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Serial (in pure_fn):
layer created in file [...]/mikado_classes.py, line 1134
layer input shapes: (ShapeDtype{shape:(64, 50), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32}, ShapeDtype{shape:(64, 1), dtype:int32})
File [...]/trax/layers/combinators.py, line 88, in forward
outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)
LayerError: Exception passing through layer Dense_7 (in pure_fn):
layer created in file [...]/mikado_classes.py, line 1133
layer input shapes: ShapeDtype{shape:(64, 512), dtype:float32}
File [...]/trax/layers/assert_shape.py, line 122, in forward_wrapper
y = forward(self, x, *args, **kwargs)
File [...]/trax/layers/core.py, line 95, in forward
return jnp.dot(x, w) + b # Affine map.
File [...]/_src/numpy/lax_numpy.py, line 3498, in dot
return lax.dot_general(a, b, (contract_dims, batch_dims), precision)
File [...]/_src/lax/lax.py, line 674, in dot_general
preferred_element_type=preferred_element_type)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/ad.py, line 285, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File [...]/jax/interpreters/ad.py, line 458, in standard_jvp
val_out = primitive.bind(*primals, **params)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 140, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 147, in default_process_primitive
return primitive.bind(*consts, **params)
File [...]/site-packages/jax/core.py, line 282, in bind
out = top_trace.process_primitive(self, tracers, params)
File [...]/jax/interpreters/partial_eval.py, line 1058, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File [...]/_src/lax/lax.py, line 1992, in standard_abstract_eval
shapes, dtypes = shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs)
File [...]/_src/lax/lax.py, line 3090, in _dot_general_shape_rule
raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
TypeError: dot_general requires contracting dimensions to have the same shape, got [512] and [50].
The mistake was not in the architecture. The problem was that my inputs were not shaped correctly.
The target should have been of shape (batch_size,), but I supplied (batch_size, 1). So a target array should have been, e.g.:
[1, 5, 99, 2, 1, 3, 2, 8]
but I produced
[[1], [5], [99], [2], [1], [3], [2], [8]].
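A minimal sketch of that fix (assuming the targets live in a NumPy array; the variable name is illustrative):

import numpy as np

targets = np.array([[1], [5], [99], [2], [1], [3], [2], [8]])  # shape (8, 1): what I produced
targets = np.squeeze(targets, axis=-1)                         # shape (8,): what the loss expects
print(targets.shape)  # (8,)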

Tensorflow: sample of dataset gets shape None in map

I have a dataset of spectrograms (images) of shape (128, 128, 1), and I want to do data augmentation on it. But when I try:
def augment(stft, label):
    print(stft.shape)
    stft = tf.image.random_brightness(stft, 0.2)
    print(stft.shape)
    stft = tf.keras.preprocessing.image.random_shift(x=stft, wrg=0.1, hrg=0.1, row_axis=0, col_axis=1, channel_axis=2, fill_mode='wrap')
    return stft, label

val_ds = (
    val_ds.map(augment, num_parallel_calls=config.AUTOTUNE).prefetch(config.AUTOTUNE))
I get the following output and error:
(None, 128, 1)
(None, 128, 1)
Traceback (most recent call last):
File "train.py", line 132, in <module>
val_ds.map(augment, num_parallel_calls=config.AUTOTUNE).prefetch(config.AUTOTUNE))
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1623, in map
return ParallelMapDataset(
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4016, in __init__
self._map_func = StructuredFunctionWrapper(
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3221, in __init__
self._function = wrapper_fn.get_concrete_function()
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/eager/function.py", line 2531, in get_concrete_function
graph_function = self._get_concrete_function_garbage_collected(
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/eager/function.py", line 2496, in _get_concrete_function_garbage_collected
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/eager/function.py", line 2657, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3214, in wrapper_fn
ret = _wrapper_helper(*args)
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3156, in _wrapper_helper
ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
File "/usr/local/lib64/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 265, in wrapper
raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:
train.py:128 augment *
stft = tf.keras.preprocessing.image.random_shift(x=stft, wrg=0.1, hrg=0.1, row_axis=0, col_axis=1, channel_axis=2, fill_mode='wrap')
/usr/local/lib64/python3.8/site-packages/keras_preprocessing/image/affine_transformations.py:85 random_shift *
tx = np.random.uniform(-hrg, hrg) * h
TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'
Why is the shape of the tensor passed to the mapped function (None, 128, 1), and how can I fix this problem?
Thanks.
You need to specify a batch size. Add
val_ds = val_ds.batch(32)
before your map.
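A minimal sketch of that suggestion, applied to the pipeline from the question (32 is just a placeholder batch size):

val_ds = val_ds.batch(32)  # give the elements a batch dimension first
val_ds = (
    val_ds.map(augment, num_parallel_calls=config.AUTOTUNE).prefetch(config.AUTOTUNE))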

Getting error while training yolov3: ValueError: tf.function-decorated function tried to create variables on non-first call

I am training a custom yolov3 model and getting the error "ValueError: tf.function-decorated function tried to create variables on non-first call." while fitting the model for training.
I am getting the error on the fit_generator statement. Could somebody please help?
train_generator = BatchGenerator(
    instances = train_ints,
    anchors = config['model']['anchors'],
    labels = labels,
    downsample = 32,  # ratio between network input's size and network output's size, 32 for YOLOv3
    max_box_per_image = max_box_per_image,
    batch_size = config['train']['batch_size'],
    min_net_size = config['model']['min_input_size'],
    max_net_size = config['model']['max_input_size'],
    shuffle = True,
    jitter = 0.3,
    norm = normalize
)

train_model, infer_model = create_model(
    nb_class = len(labels),
    anchors = config['model']['anchors'],
    max_box_per_image = max_box_per_image,
    max_grid = [config['model']['max_input_size'], config['model']['max_input_size']],
    batch_size = config['train']['batch_size'],
    warmup_batches = warmup_batches,
    ignore_thresh = config['train']['ignore_thresh'],
    multi_gpu = multi_gpu,
    saved_weights_name = config['train']['saved_weights_name'],
    lr = config['train']['learning_rate'],
    grid_scales = config['train']['grid_scales'],
    obj_scale = config['train']['obj_scale'],
    noobj_scale = config['train']['noobj_scale'],
    xywh_scale = config['train']['xywh_scale'],
    class_scale = config['train']['class_scale'],
)

###############################
# Kick off the training
###############################
callbacks = create_callbacks(config['train']['saved_weights_name'], config['train']['tensorboard_dir'], infer_model)

print("before kickoff", len(train_generator))
print("before kickoff", train_generator)

train_model.fit_generator(
    generator = train_generator,
    steps_per_epoch = len(train_generator) * config['train']['train_times'],
    epochs = config['train']['nb_epochs'] + config['train']['warmup_epochs'],
    # epochs = 1,
    verbose = 2 if config['train']['debug'] else 1,
    callbacks = callbacks,
    workers = 2,
    max_queue_size = 8
)

print("after kickoff")
print ("after kickoff")
The error I am getting is:
WARNING:tensorflow:Model failed to serialize as JSON. Ignoring... Layer YoloLayer has arguments in __init__ and therefore must override get_config.
Epoch 1/21
Traceback (most recent call last):
File "train.py", line 300, in
main(args)
File "train.py", line 269, in main
train_model.fit_generator(
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 324, in new_func
return func(*args, **kwargs)
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1815, in fit_generator
return self.fit(
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
tmp_logs = train_function(iterator)
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 780, in call
result = self._call(*args, **kwds)
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 840, in _call
return self._stateless_fn(*args, **kwds)
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2828, in call
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3065, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 600, in wrapped_fn
return weak_wrapped_fn().wrapped(*args, **kwds)
File "/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 973, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:806 train_function *
return step_function(self, iterator)
/Users/karthikeyan/Desktop/table/yolo.py:46 call *
batch_seen = tf.Variable(0.)
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:262 __call__ **
return cls._variable_v2_call(*args, **kwargs)
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:244 _variable_v2_call
return previous_getter(
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:67 getter
return captured_getter(captured_previous, **kwargs)
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2857 creator
return next_creator(**kwargs)
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:67 getter
return captured_getter(captured_previous, **kwargs)
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2857 creator
return next_creator(**kwargs)
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:67 getter
return captured_getter(captured_previous, **kwargs)
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2857 creator
return next_creator(**kwargs)
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/variables.py:67 getter
return captured_getter(captured_previous, **kwargs)
/Users/karthikeyan/opt/anaconda3/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py:701 invalid_creator_scope
raise ValueError(
ValueError: tf.function-decorated function tried to create variables on non-first call.
I was able to find an answer. Adding tf.config.experimental_run_functions_eagerly(True) after importing tensorflow resolved the issue.
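In code, that workaround looks roughly like this (in later TF 2.x releases the non-experimental tf.config.run_functions_eagerly(True) serves the same purpose):

import tensorflow as tf

# Force tf.function-decorated code (including the Keras train_function) to run eagerly,
# so the variable creation inside YoloLayer.call no longer trips the
# "tried to create variables on non-first call" check.
tf.config.experimental_run_functions_eagerly(True)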

InvalidArgumentError : ConcatOp : Dimensions of inputs should match

TensorFlow 1.7, using dynamic_rnn. It runs fine at first, but at around the 32nd step (the exact step changes between runs) the error appears. When I used a smaller batch, the code seemed to run longer, but the error still popped up. I just can't figure out what's wrong.
from mapping import *

def my_input_fn(features, targets, batch_size=20, shuffle=True, num_epochs=None, sequece_lenth=None):
    ds = tf.data.Dataset.from_tensor_slices(
        (features, targets, sequece_lenth))  # warning: 2GB limit
    ds = ds.batch(batch_size).repeat(num_epochs)
    if shuffle:
        ds = ds.shuffle(10000)
    features, labels, sequence = ds.make_one_shot_iterator().get_next()
    return features, labels, sequence

def lstm_cell(lstm_size=50):
    return tf.contrib.rnn.BasicLSTMCell(lstm_size)

class RnnModel:
    def __init__(self,
                 batch_size,
                 hidden_units,
                 time_steps,
                 num_features
                 ):
        self.batch_size = batch_size
        self.hidden_units = hidden_units
        stacked_lstm = tf.contrib.rnn.MultiRNNCell(
            [lstm_cell(i) for i in self.hidden_units])
        self.initial_state = stacked_lstm.zero_state(batch_size, tf.float32)
        self.model = stacked_lstm
        self.state = self.initial_state
        self.time_steps = time_steps
        self.num_features = num_features

    def loss_mean_squre(self, outputs, targets):
        pos = tf.add(outputs, tf.ones(self.batch_size))
        eve = tf.div(pos, 2)
        error = tf.subtract(eve, targets)
        return tf.reduce_mean(tf.square(error))

    def train(self,
              num_steps,
              learningRate,
              input_fn,
              inputs,
              targets,
              sequenceLenth):
        periods = 10
        step_per_periods = int(num_steps / periods)
        input, target, sequence = input_fn(inputs, targets, self.batch_size, shuffle=True, sequece_lenth=sequenceLenth)
        initial_state = self.model.zero_state(self.batch_size, tf.float32)
        outputs, state = tf.nn.dynamic_rnn(self.model, input, initial_state=initial_state)
        loss = self.loss_mean_squre(tf.reshape(outputs, [self.time_steps, self.batch_size])[-1], target)
        optimizer = tf.train.AdamOptimizer(learning_rate=learningRate)
        grads_and_vars = optimizer.compute_gradients(loss, self.model.variables)
        optimizer.apply_gradients(grads_and_vars)
        init_op = tf.global_variables_initializer()
        with tf.Session() as sess:
            for i in range(num_steps):
                sess.run(init_op)
                state2, current_loss = sess.run([state, loss])
                if i % step_per_periods == 0:
                    print("period " + str(int(i / step_per_periods)) + ":" + str(current_loss))
        return self.model, self.state

def processFeature(df):
    df = df.drop('class', 1)
    features = []
    for i in range(len(df["vecs"])):
        features.append(df["vecs"][i])
    aa = pd.Series(features).tolist()  # transform into list
    featuresList = []
    for i in features:
        p1 = []
        for k in i:
            p1.append(list(k))
        featuresList.append(p1)
    return featuresList

def processTargets(df):
    selected_features = df["class"]
    processed_features = selected_features.copy()
    return tf.convert_to_tensor(processed_features.astype(float).tolist())

if __name__ == '__main__':
    dividNumber = 30
    """
    some code here to modify my data to input
    it looks like this:
    inputs before use input function : [fullLenth, charactorLenth, embeddinglenth]
    """
    model = RnnModel(15, [100, 80, 80, 1], time_steps=dividNumber, num_features=25)
    model.train(5000, 0.0001, my_input_fn, training_examples, training_targets, sequenceLenth=trainSequenceL)
And the error is below:
Traceback (most recent call last):
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\client\session.py", line 1330, in _do_call
return fn(*args)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\client\session.py", line 1315, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\client\session.py", line 1423, in _call_tf_sessionrun
status, run_metadata)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\framework\errors_impl.py", line 516, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [20,25] vs. shape[1] = [30,100]
[[Node: rnn/while/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/while/TensorArrayReadV3, rnn/while/Switch_4:1, rnn/while/rnn/multi_rnn_cell/cell_3/basic_lstm_cell/Const)]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "D:/programming/mlwords/dnn_gragh.py", line 198, in <module>
model.train(5000, 0.0001, my_input_fn, training_examples, training_targets, sequenceLenth=trainSequenceL)
File "D:/programming/mlwords/dnn_gragh.py", line 124, in train
state2, current_loss, nowAccuracy = sess.run([state, loss, accuracy])
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\client\session.py", line 908, in run
run_metadata_ptr)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\client\session.py", line 1143, in _run
feed_dict_tensor, options, run_metadata)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\client\session.py", line 1324, in _do_run
run_metadata)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\client\session.py", line 1343, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [20,25] vs. shape[1] = [30,100]
[[Node: rnn/while/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/while/TensorArrayReadV3, rnn/while/Switch_4:1, rnn/while/rnn/multi_rnn_cell/cell_3/basic_lstm_cell/Const)]]
Caused by op 'rnn/while/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/concat', defined at:
File "D:/programming/mlwords/dnn_gragh.py", line 198, in <module>
model.train(5000, 0.0001, my_input_fn, training_examples, training_targets, sequenceLenth=trainSequenceL)
File "D:/programming/mlwords/dnn_gragh.py", line 95, in train
outputs, state = tf.nn.dynamic_rnn(self.model, input, initial_state=initial_state)#,sequence_length=sequence
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\rnn.py", line 627, in dynamic_rnn
dtype=dtype)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\rnn.py", line 824, in _dynamic_rnn_loop
swap_memory=swap_memory)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3205, in while_loop
result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2943, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 2880, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3181, in <lambda>
body = lambda i, lv: (i + 1, orig_body(*lv))
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\rnn.py", line 795, in _time_step
(output, new_state) = call_cell()
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\rnn.py", line 781, in <lambda>
call_cell = lambda: cell(input_t, state)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 232, in __call__
return super(RNNCell, self).__call__(inputs, state)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\layers\base.py", line 714, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 1283, in call
cur_inp, new_state = cell(cur_inp, cur_state)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 339, in __call__
*args, **kwargs)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\layers\base.py", line 714, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 620, in call
array_ops.concat([inputs, h], 1), self._kernel)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\array_ops.py", line 1181, in concat
return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 1101, in concat_v2
"ConcatV2", values=values, axis=axis, name=name)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\framework\ops.py", line 3309, in create_op
op_def=op_def)
File "D:\Anaconda3\envs\tensorflow-cpu\lib\site-packages\tensorflow\python\framework\ops.py", line 1669, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [20,25] vs. shape[1] = [30,100]
[[Node: rnn/while/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"](rnn/while/TensorArrayReadV3, rnn/while/Switch_4:1, rnn/while/rnn/multi_rnn_cell/cell_3/basic_lstm_cell/Const)]]
This is the code I used to check my input:
def checkData(inputs, targets, sequencelence):
    batch_size = 20
    features, target, sequece = my_input_fn(inputs, targets, batch_size=batch_size, shuffle=True, num_epochs=None,
                                            sequece_lenth=sequencelence)
    with tf.Session() as sess:
        for i in range(1000):
            features1, target1, sequece1 = sess.run([features, target, sequece])
            assert len(features1) == batch_size
            for sentence in features1:
                assert len(sentence) == 30
                for word in sentence:
                    assert len(word) == 25
            assert len(target1) == batch_size
            assert len(sequece1) == batch_size
            print(target1)
    print("OK")
The error is coming from the LSTMCell's call method. There we are trying to tf.concat([inputs, h], 1), meaning that we want to concatenate the next input with the current hidden state before matmul'ing with the kernel variable matrix. The error says you can't do it because the batch (0th) dimensions don't match up: your input is shaped [20, 25] and your hidden state is shaped [30, 100].
For some reason, on your 32nd iteration (or whenever you see the error), the input is not batched to 30 but only to 20. This usually happens at the end of your training data, when the total number of training examples does not evenly divide your batch size. This hypothesis is also consistent with your "When I used a smaller batch, it seems the code can run longer" statement.
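If the leftover partial batch is indeed the culprit, one way to rule it out is to drop the incomplete final batch in the input function. A sketch of the question's my_input_fn with that change (note: the drop_remainder argument requires a newer TensorFlow than the 1.7 used here; older 1.x releases provided tf.contrib.data.batch_and_drop_remainder for the same purpose):

def my_input_fn(features, targets, batch_size=20, shuffle=True, num_epochs=None, sequece_lenth=None):
    ds = tf.data.Dataset.from_tensor_slices((features, targets, sequece_lenth))
    # Drop the incomplete final batch so the RNN always sees exactly batch_size examples.
    ds = ds.batch(batch_size, drop_remainder=True).repeat(num_epochs)
    if shuffle:
        ds = ds.shuffle(10000)
    return ds.make_one_shot_iterator().get_next()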
I had the same issue. When I corrected the image input size to match the input shape, it ran without errors.

Tensorflow Estimator API: Remember LSTM state from previous batch for next batch with dynamic batch_size

I know that a similar question has already been asked several times here on Stack Overflow and across the Internet, but I am just not able to find a solution to the following problem: I am trying to build a stateful LSTM model in TensorFlow with its Estimator API.
I tried the solution from Tensorflow, best way to save state in RNNs?, which works as long as I am using a static batch_size. Having a dynamic batch_size causes the following problem:
ValueError: initial_value must have a shape specified:
Tensor("DropoutWrapperZeroState/MultiRNNCellZeroState/DropoutWrapperZeroState/LSTMCellZeroState/zeros:0",
shape=(?, 200), dtype=float32)
Setting tf.Variable(...., validate_shape=False) just moves the problem further down the Graph:
Traceback (most recent call last):
File "model.py", line 576, in <module>
tf.app.run(main=run_experiment)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "model.py", line 137, in run_experiment
hparams=params # HParams
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/learn_runner.py", line 210, in run
return _execute_schedule(experiment, schedule)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/learn_runner.py", line 47, in _execute_schedule
return task()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/experiment.py", line 495, in train_and_evaluate
self.train(delay_secs=0)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/experiment.py", line 275, in train
hooks=self._train_monitors + extra_hooks)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/experiment.py", line 660, in _call_train
hooks=hooks)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 241, in train
loss = self._train_model(input_fn=input_fn, hooks=hooks)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 560, in _train_model
model_fn_lib.ModeKeys.TRAIN)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 545, in _call_model_fn
features=features, labels=labels, **kwargs)
File "model.py", line 218, in model_fn
output, state = get_model(features, params)
File "model.py", line 567, in get_model
model = lstm(inputs, params)
File "model.py", line 377, in lstm
output, new_states = tf.nn.dynamic_rnn(multicell, inputs=inputs, initial_state = states)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 574, in dynamic_rnn
dtype=dtype)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 737, in _dynamic_rnn_loop
swap_memory=swap_memory)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2770, in while_loop
result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2599, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2549, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 722, in _time_step
(output, new_state) = call_cell()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn.py", line 708, in <lambda>
call_cell = lambda: cell(input_t, state)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 752, in __call__
output, new_state = self._cell(inputs, state, scope)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
return super(RNNCell, self).__call__(inputs, state)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/layers/base.py", line 441, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 916, in call
cur_inp, new_state = cell(cur_inp, cur_state)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 752, in __call__
output, new_state = self._cell(inputs, state, scope)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__
return super(RNNCell, self).__call__(inputs, state)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/layers/base.py", line 441, in __call__
outputs = self.call(inputs, *args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 542, in call
lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1002, in _linear
raise ValueError("linear is expecting 2D arguments: %s" % shapes)
ValueError: linear is expecting 2D arguments: [TensorShape([Dimension(None), Dimension(62)]), TensorShape(None)]
According to github issue 2838 it is NOT recommended to use non-trainable variables anyway(???), which is why I continued looking for other solutions.
Now I use placeholders and something like that (also suggested in the github thread) in my model_fn:
def rnn_placeholders(state):
    """Convert RNN state tensors to placeholders with the zero state as default."""
    if isinstance(state, tf.contrib.rnn.LSTMStateTuple):
        c, h = state
        c = tf.placeholder_with_default(c, c.shape, c.op.name)
        h = tf.placeholder_with_default(h, h.shape, h.op.name)
        return tf.contrib.rnn.LSTMStateTuple(c, h)
    elif isinstance(state, tf.Tensor):
        h = state
        h = tf.placeholder_with_default(h, h.shape, h.op.name)
        return h
    else:
        structure = [rnn_placeholders(x) for x in state]
        return tuple(structure)

state = rnn_placeholders(cell.zero_state(batch_size, tf.float32))
for tensor in flatten(state):
    tf.add_to_collection('rnn_state_input', tensor)
x, new_state = tf.nn.dynamic_rnn(...)
for tensor in flatten(new_state):
    tf.add_to_collection('rnn_state_output', tensor)
But unfortunately I do not know how to feed the values of new_state back into the placeholder state on every iteration when using the tf.Estimator API. Since I am quite new to TensorFlow, I think I lack some conceptual knowledge here. Might it be possible to use a custom SessionRunHook?
class UpdateHook(tf.train.SessionRunHook):
    def before_run(self, run_context):
        run_args = super(UpdateHook, self).before_run(run_context)
        run_args = tf.train.SessionRunArgs(new_state)
        # print(run_args)
        return run_args

    def after_run(self, run_context, run_values):
        # run_values gives the actual value of new_state.
        # How do I update the state placeholder now?
        pass
Is there anyone who has an idea how to solve that problem? Tips and tricks are highly appreciated!!!
Thanks a lot!
PS: If something is unclear let me know ;)
EDIT: Unfortunately I am using the new tf.data API and cannot use StateSavingRNNEstimator as Eugene suggested.
This answer might be late.
I had a similar problem some months ago.
I solved it using a customised SessionRunHook. It might not be perfect in terms of performance, but you can give it a try.
class LSTMStateHook(tf.train.SessionRunHook):

    def __init__(self, params):
        self.init_states = None
        self.current_state = np.zeros((params.rnn_layers, 2, params.batch_size, params.state_size))

    def before_run(self, run_context):
        run_args = tf.train.SessionRunArgs(
            [tf.get_default_graph().get_tensor_by_name('LSTM/output_states:0')],
            {self.init_states: self.current_state},
        )
        return run_args

    def after_run(self, run_context, run_values):
        self.current_state = run_values[0][0]  # depends on your session run arguments!!!!!!!

    def begin(self):
        self.init_states = tf.get_default_graph().get_tensor_by_name('LSTM/init_states:0')
In the code where you define your LSTM graph you need something like this:
if self.stateful is True:
    init_states = multicell.zero_state(self.batch_size, tf.float32)
    init_states = tf.identity(init_states, "init_states")
    l = tf.unstack(init_states, axis=0)
    rnn_tuple_state = tuple([tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1]) for idx in range(self.rnn_layers)])
else:
    rnn_tuple_state = multicell.zero_state(self.batch_size, tf.float32)

# Unroll RNN
output, output_states = tf.nn.dynamic_rnn(multicell, inputs=inputs, initial_state=rnn_tuple_state)

if self.stateful is True:
    output_states = tf.identity(output_states, "output_states")

return output
There is an estimator you can base your code on that uses batch_sequences_with_states. It is called StateSavingRNNEstimator. Unless you are using the new tf.contrib.data / tf.data API, it should be enough to get you started.