I am learning TensorFlow (TF), and it's been just one day, so I apologize in advance if my question is too basic.
I was studying the linear classification example on the official TF website.
The authors defined a function called input_fn to read the data. The function is as follows:
def input_fn(data_file, num_epochs, shuffle, batch_size):
    """Generate an input function for the Estimator."""
    assert tf.gfile.Exists(data_file), (
        '%s not found. Please make sure you have either run data_download.py or '
        'set both arguments --train_data and --test_data.' % data_file)
    def parse_csv(value):
        print('Parsing', data_file)
        columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))
        labels = features.pop('income_bracket')
        return features, tf.equal(labels, '>50K')
    # Extract lines from input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])
    dataset = dataset.map(parse_csv, num_parallel_calls=5)
    # We call repeat after shuffling, rather than before, to prevent separate
    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels
I am not able to understand the second-to-last line. get_next() is called on the one-shot iterator only once, but shouldn't it iterate over the data multiple times (i.e., once per row) to extract the rows, like this example here?
Here, get_next() is basically a dequeue op. The data is in a queue; when you consume (use/run) the element returned by get_next(), it is removed from the queue, and the next image/label moves into its place, to be dequeued the next time you run the op.
So this function only returns the TensorFlow op for dequeuing elements; you can run that op repeatedly in your training loop to consume the data.
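A plain-Python analogy may make this concrete. It is not TensorFlow's implementation, and the class and exception names are hypothetical: get_next() is called once, while building the model, and returns an op-like callable; each run of that callable dequeues the next element, and running it past the end raises an end-of-data error, much as sess.run() would raise tf.errors.OutOfRangeError.

```python
from collections import deque


class EndOfData(Exception):
    """Stands in for tf.errors.OutOfRangeError."""


class OneShotIteratorModel:
    """Hypothetical stand-in for a TF 1.x one-shot iterator."""

    def __init__(self, elements):
        self._queue = deque(elements)  # elements waiting to be consumed

    def get_next(self):
        # Called ONCE while building the model; returns an op-like callable.
        def dequeue_op():
            # Every run removes and returns the next element.
            if not self._queue:
                raise EndOfData("end of dataset")
            return self._queue.popleft()
        return dequeue_op


def run_training_loop(rows):
    next_element = OneShotIteratorModel(rows).get_next()  # built once
    consumed = []
    while True:
        try:
            consumed.append(next_element())  # like sess.run(next_element)
        except EndOfData:
            break
    return consumed


print(run_training_loop(["row1", "row2", "row3"]))  # ['row1', 'row2', 'row3']
```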
I am following Google's Machine Learning Crash Course, but it uses TensorFlow 1.x, so I was planning to change the exercises so that they run in TensorFlow 2.0. However, I am stuck on this exercise:
https://colab.research.google.com/notebooks/mlcc/first_steps_with_tensor_flow.ipynb?utm_source=mlcc&utm_campaign=colab-external&utm_medium=referral&utm_content=firststeps-colab&hl=es#scrollTo=7UwqGbbxP53O
Specifically the code:
def my_input_fn(features, targets, batch_size=1, shuffle=True, num_epochs=None):
    """Trains a linear regression model of one feature.
    Args:
      features: pandas DataFrame of features
      targets: pandas DataFrame of targets
      batch_size: Size of batches to be passed to the model
      shuffle: True or False. Whether to shuffle the data.
      num_epochs: Number of epochs for which data should be repeated. None = repeat indefinitely
    Returns:
      Tuple of (features, labels) for next data batch
    """
    # Convert pandas data into a dict of np arrays.
    features = {key: np.array(value) for key, value in dict(features).items()}
    # Construct a dataset, and configure batching/repeating.
    ds = Dataset.from_tensor_slices((features, targets))  # warning: 2GB limit
    ds = ds.batch(batch_size).repeat(num_epochs)
    # Shuffle the data, if specified.
    if shuffle:
        ds = ds.shuffle(buffer_size=10000)
    # Return the next batch of data.
    features, labels = ds.make_one_shot_iterator().get_next()
    return features, labels
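To see what the dataset in my_input_fn yields, here is a pure-Python model of from_tensor_slices() followed by batch() and repeat(). The function names mirror the tf.data methods but are hypothetical stand-ins that work on plain lists, and the 'total_rooms' column is made-up data.

```python
def from_tensor_slices(features, targets):
    """Yield one (features_dict, target) pair per row, slicing every
    column along its first dimension."""
    for i in range(len(targets)):
        yield {key: column[i] for key, column in features.items()}, targets[i]


def _stack(pairs):
    """Combine (features_dict, target) pairs into one batched pair."""
    keys = pairs[0][0].keys()
    batched_features = {k: [f[k] for f, _ in pairs] for k in keys}
    return batched_features, [t for _, t in pairs]


def batch(elements, batch_size):
    """Group consecutive elements; the final batch may be smaller."""
    pending = []
    for el in elements:
        pending.append(el)
        if len(pending) == batch_size:
            yield _stack(pending)
            pending = []
    if pending:
        yield _stack(pending)


def repeat(make_epoch, num_epochs=None):
    """Re-run the upstream pipeline num_epochs times (forever if None)."""
    epoch = 0
    while num_epochs is None or epoch < num_epochs:
        yield from make_epoch()
        epoch += 1


# Hypothetical data: one feature column with three rows.
features = {"total_rooms": [10, 20, 30]}
targets = [1.0, 2.0, 3.0]

pipeline = repeat(lambda: batch(from_tensor_slices(features, targets), 2), num_epochs=2)
batches = list(pipeline)
for batched_features, batched_labels in batches:
    print(batched_features, batched_labels)
# {'total_rooms': [10, 20]} [1.0, 2.0]
# {'total_rooms': [30]} [3.0]
# (then the same two batches again for the second epoch)
```

Because batch() comes before repeat() in my_input_fn, every epoch can end with a short batch, and because shuffle() is applied last there, it reorders whole batches rather than individual rows (with the default batch_size=1 the two coincide).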
I have replaced features, labels = ds.make_one_shot_iterator().get_next() with features, labels = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()
and it seems to work, but make_one_shot_iterator() is deprecated, so how can I replace it?
Also, following https://github.com/tensorflow/tensorflow/issues/29252, I have tried:
features, labels = ds.__iter__()
next(ds.__iter__())
return features, labels
but it raises the error "__iter__() is only supported inside of tf.function or when eager execution is enabled."
I am quite inexperienced in Python and follow the course as a hobbyist. Any ideas on how to solve it? Thank you.
After several tests, I found that the Python hang was a local problem.
To replace features, labels = ds.make_one_shot_iterator().get_next(), I have tried several things:
features, labels = ds.__iter__().get_next()
iterator = ds.__iter__()
features, labels = iterator.get_next()
it = iter(ds)
features, labels = next(it)
All three cases raise "__iter__() is only supported inside of tf.function or when eager execution is enabled.", so I tried:
features, labels = ds
return ds
And also just:
return features, labels
Both raise the same error. Finally, I tried:
return ds
Mysteriously, it worked. I have no idea why, but it did.
1) I doubt that you've really got what you wanted. If your input really needs to be multi-input, then your ds is unlikely to suit; you need to take the individual parts out of the element instead, something like this:
features = tf.compat.v1.data.make_one_shot_iterator(train_dataset).get_next()
image, label = features['image'], features['label']
2) Concerning Iterator: it now belongs to tf.data, with the tf.data.Iterator.get_next() method, as opposed to the previous tf.data.Dataset.make_one_shot_iterator(). Perhaps 'Dependency Inversion' (the D in the SOLID design principles) was applied as part of a refactoring...
The new Iterator can now also be used with tf.data.Dataset.from_generator() datasets, which are fed asynchronously with each chunk of data that the generator function yields; here is an example of overriding a custom tfds.core.GeneratorBasedBuilder...
I think the overall architecture of the TF library was refactored a little, because model inputs now consume the data batch by batch on their own (as implemented by the developers), so applying make_one_shot_iterator to a Dataset is no longer needed. Even for debugging there is .as_numpy_iterator(), so the developers no longer consider make_one_shot_iterator necessary.
Though sometimes people still use:
iterator = iter(batched_dataset)
next_element = iterator.get_next()
I can't yet think of a case where this would be needed.
P.S. By the way, as I remember from the debugger, if your container is hashable or not iterable (correct me if I'm wrong), you can try:
iterator = iter(dataset)
# batch_features, batch_labels = iterator.get_next()
el = iterator.get_next()
batch_features = el[0]  # first component of the (features, labels) pair
print(batch_features)
batch_labels = el[-1]  # last component: the labels
print(batch_labels)
This works.
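Whichever way the element is obtained, note how slicing and indexing differ on a (features, labels) pair: slicing returns tuples, while indexing returns the parts themselves. A plain-Python illustration with a made-up element (no TensorFlow involved):

```python
# A made-up element shaped like what a (features, labels) dataset yields.
el = ({"image": [[0, 1], [2, 3]]}, [7, 9])

print(el[:] == el)  # True: slicing with [:] copies the whole pair
print(el[:-1])      # ({'image': [[0, 1], [2, 3]]},): a 1-tuple, not the labels

batch_features = el[0]  # the features dict
batch_labels = el[-1]   # the labels (same as el[1] for a pair)

# Tuple unpacking does the same in one line.
features, labels = el
print(labels)  # [7, 9]
```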
I created a dataset in TFRecord format for testing. Every entry contains 200 columns, named C1 - C199, each being a list of strings, and a label column to denote the labels. The code to create the data can be found here: https://github.com/codescv/tf-dist/blob/8bb3c44f55939fc66b3727a730c57887113e899c/src/gen_data.py#L25
Then I used a linear model to train the data. The first approach looks like this:
dataset = tf.data.TFRecordDataset(data_file)
dataset = dataset.prefetch(buffer_size=batch_size*10)
dataset = dataset.map(parse_tfrecord, num_parallel_calls=5)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
features, labels = dataset.make_one_shot_iterator().get_next()
logits = tf.feature_column.linear_model(features=features, feature_columns=columns, cols_to_vars=cols_to_vars)
train_op = ...
with tf.Session() as sess:
    sess.run(train_op)
The full code can be found here: https://github.com/codescv/tf-dist/blob/master/src/lr_single.py
When I run the code above, I get 0.85 steps/sec (batch size being 1024).
In the second approach, I manually get batches from the Dataset into Python, then feed them to a placeholder, like this:
example = tf.placeholder(dtype=tf.string, shape=[None])
features = tf.parse_example(example, features=tf.feature_column.make_parse_example_spec(columns+[tf.feature_column.numeric_column('label', dtype=tf.float32, default_value=0)]))
labels = features.pop('label')
train_op = ...
dataset = tf.data.TFRecordDataset(data_file).repeat().batch(batch_size)
next_batch = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    data_batch = sess.run(next_batch)
    sess.run(train_op, feed_dict={example: data_batch})
The full code can be found here: https://github.com/codescv/tf-dist/blob/master/src/lr_single_feed.py
When I run the code above, I get 5 steps/sec. That is almost 6x faster than the first approach (5 vs. 0.85 steps/sec). This is what I do not understand, because theoretically the second should be slower due to the extra serialization/deserialization of data batches.
Thanks!
There is currently (as of TensorFlow 1.9) a performance issue when using tf.data to map and batch tensors that have a large number of features with a small amount of data in each. The issue has two causes:
The dataset.map(parse_tfrecord, ...) transformation will execute O(batch_size * num_columns) small operations to create a batch. By contrast, feeding a tf.placeholder() to tf.parse_example() will execute O(1) operations to create the same batch.
Batching many tf.SparseTensor objects using dataset.batch() is much slower than directly creating the same tf.SparseTensor as the output of tf.parse_example().
Improvements to both these issues are underway, and should be available in a future version of TensorFlow. In the meantime, you can improve the performance of the tf.data-based pipeline by switching the order of the dataset.map() and dataset.batch() and rewriting the dataset.map() to work on a vector of strings, like the feeding based version:
dataset = tf.data.TFRecordDataset(data_file)
dataset = dataset.prefetch(buffer_size=batch_size*10)
dataset = dataset.repeat(num_epochs)
# Batch first to create a vector of strings as input to the map().
dataset = dataset.batch(batch_size)
def parse_tfrecord_batch(record_batch):
    features = tf.parse_example(
        record_batch,
        features=tf.feature_column.make_parse_example_spec(
            columns + [
                tf.feature_column.numeric_column(
                    'label', dtype=tf.float32, default_value=0)]))
    labels = features.pop('label')
    return features, labels
# NOTE: Parallelism might not be as useful, because the individual map function now does
# more work per invocation, but you might want to experiment with this.
dataset = dataset.map(parse_tfrecord_batch)
# Add a prefetch at the end to pipeline execution.
dataset = dataset.prefetch(1)
features, labels = dataset.make_one_shot_iterator().get_next()
# ...
EDIT (2018/6/18): To answer your questions from the comments:
Why is dataset.map(parse_tfrecord, ...) O(batch_size * num_columns), not O(batch_size)? If parsing requires enumeration of the columns, why doesn't parse_example take O(num_columns)?
When you wrap TensorFlow code in a Dataset.map() (or other functional transformation) a constant number of extra operations per output are added to "return" values from the function and (in the case of tf.SparseTensor values) "convert" them to a standard format. When you directly pass the outputs of tf.parse_example() to the input of your model, these operations aren't added. While they are very small operations, executing so many of them can become a bottleneck. (Technically the parsing does take O(batch_size * num_columns) time, but the constants involved in parsing are much smaller than executing an operation.)
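To make the operation-count argument concrete, here is a toy count in plain Python. It assumes, purely for illustration, that each map() invocation runs one parsing op plus one extra "return/convert" op per output column (the real constant is not specified above), and it counts op executions, not time.

```python
def op_executions_per_batch(batch_size, num_columns, map_per_record,
                            extra_ops_per_output=1):
    """Toy count of op executions needed to build one batch.

    Assumes each map() invocation runs one parsing op plus
    `extra_ops_per_output` "return/convert" ops for every output column.
    """
    invocations = batch_size if map_per_record else 1
    ops_per_invocation = 1 + num_columns * extra_ops_per_output
    return invocations * ops_per_invocation


# Sizes taken from the question: batches of 1024 records, ~200 columns.
per_record = op_executions_per_batch(1024, 200, map_per_record=True)
per_batch = op_executions_per_batch(1024, 200, map_per_record=False)
print(per_record, per_batch)  # 205824 201
```

The ratio is simply batch_size; the parsing work itself still grows with batch_size * num_columns in both cases, but its constants are much smaller than those of executing an op.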
Why do you add a prefetch at the end of the pipeline?
When you're interested in performance, this is almost always the best thing to do, and it should improve the overall performance of your pipeline. For more information about best practices, see the performance guide for tf.data.
I'm switching my old data layer (using queues) to the "new" and recommended Dataset API. I'm using it for the first time, so I'm providing code examples in case I got something fundamentally wrong.
I create my Dataset from a generator (that will read a file, and provide n samples). It's a small dataset and n_iterations >> n_samples, so I simply want to read this dataset over and over again, ideally shuffled.
sample_set = tf.data.Dataset.from_generator( data_generator(filename),
(tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1]))
)
with data_generator defined as:
class data_generator:
    def __init__(self, filename):
        self.filename = filename
    def __call__(self):
        with self.filename.open() as f:
            for idx in f:
                yield img[idx], label[idx]
To actually use the data, I understood that I need to define an Iterator:
sample = sample_set.make_one_shot_iterator().get_next()
and then we are ready to read data:
while True:
    try:
        my_sample = sess.run(sample)
    except tf.errors.OutOfRangeError:
        break  # this happens after the dataset has been read once
But all available Iterators seem to be "finite", in the sense that they read a dataset only once.
Is there a simple way to make reading from the Dataset endless?
Datasets have repeat and shuffle methods.
BUF_SIZE = 100 # choose it depending on your data
sample_set = tf.data.Dataset.from_generator( data_generator(filename),
(tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]),
tf.TensorShape([256,256,1]))
).repeat().shuffle(BUF_SIZE)
The Dataset.repeat() transformation will repeat a dataset endlessly if you don't pass an explicit count to it:
sample_set = tf.data.Dataset.from_generator(
data_generator(filename), (tf.uint8, tf.uint8),
(tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1])))
# Repeats `sample_set` endlessly.
sample_set = sample_set.repeat()
sample = sample_set.make_one_shot_iterator().get_next()
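This works with a generator-backed dataset because from_generator() takes a callable (here, the data_generator instance), and each time the dataset is iterated again that callable is invoked to get a fresh iterator. A plain-Python model of that behavior (the names are hypothetical, and this is not TensorFlow's implementation):

```python
import itertools


class CountingGenerator:
    """Hypothetical stand-in for data_generator: a callable that returns a
    fresh iterator over a small, finite set of samples on every call."""

    def __init__(self, samples):
        self.samples = samples
        self.calls = 0

    def __call__(self):
        self.calls += 1  # runs when the new generator starts producing
        for sample in self.samples:
            yield sample


def from_generator_repeated(generator_fn, count=None):
    """Model of Dataset.from_generator(generator_fn).repeat(count):
    each epoch calls generator_fn() again; count=None repeats forever."""
    epoch = 0
    while count is None or epoch < count:
        yield from generator_fn()
        epoch += 1


gen = CountingGenerator(["s0", "s1", "s2"])
first_seven = list(itertools.islice(from_generator_repeated(gen), 7))
print(first_seven)  # ['s0', 's1', 's2', 's0', 's1', 's2', 's0']
print(gen.calls)    # 3
```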
A reinitializable Iterator can also be re-initialized on the same dataset, so the following code reads the same dataset over and over again:
sample_it = tf.data.Iterator.from_structure(sample_set.output_types,
                                            sample_set.output_shapes)
sample = sample_it.get_next()
sample_set_init_op = sample_it.make_initializer(sample_set)  # create initialize op
with tf.Session(config=config) as sess:
    sess.run(sample_set_init_op)  # initialize in the beginning
    while True:
        try:
            my_sample = sess.run(sample)
        except tf.errors.OutOfRangeError:
            sess.run(sample_set_init_op)  # re-initialize on same dataset
I have a GCMLE experiment and I am trying to upgrade my input_fn to use the new tf.data functionality. I have created the following input_fn based on this sample:
def input_fn(...):
    dataset = tf.data.Dataset.list_files(filenames).shuffle(num_shards)  # shuffle up the list of input files
    # Mix together records from cycle_length number of shards.
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1).map(
            lambda row: parse_csv(row, hparams)),
        cycle_length=5)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()
    labels = features.pop(LABEL_COLUMN)
    return features, labels
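Dataset.interleave() with cycle_length=5 keeps up to five shard files open at once and, with its default block_length of 1, takes one record from each open file in turn, opening the next file when one is exhausted. A simplified pure-Python model of that ordering (hypothetical code, not TensorFlow's implementation):

```python
_NO_MORE = object()  # sentinel for "nothing left"


def interleave(sources, cycle_length, block_length=1):
    """Toy model of Dataset.interleave(): keep up to `cycle_length` sources
    open and take `block_length` elements from each one in turn."""
    inputs = iter(sources)
    slots = [None] * cycle_length  # one open iterator per cycle position
    num_open = 0
    end_of_input = False
    index = 0  # current cycle position
    taken = 0  # elements taken from this position in the current block

    def advance():
        nonlocal index, taken
        index = (index + 1) % cycle_length
        taken = 0

    while not end_of_input or num_open > 0:
        current = slots[index]
        if current is not None:
            value = next(current, _NO_MORE)
            if value is _NO_MORE:
                # This source is exhausted: free its slot and move on.
                slots[index] = None
                num_open -= 1
                advance()
                continue
            yield value
            taken += 1
            if taken == block_length:
                advance()
        elif not end_of_input:
            # Fill the empty slot with the next source, if there is one.
            source = next(inputs, _NO_MORE)
            if source is _NO_MORE:
                end_of_input = True
            else:
                slots[index] = iter(source)
                num_open += 1
        else:
            advance()


# Hypothetical "files", each already stripped of its header row.
shards = [["a1", "a2"], ["b1", "b2"], ["c1", "c2"]]
print(list(interleave(shards, cycle_length=5)))
# ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
print(list(interleave([[1, 2], [3, 4], [5, 6]], cycle_length=2)))
# [1, 3, 2, 4, 5, 6]
```

The sloppy variant discussed in this thread gives up this fixed order so that it does not have to wait on a slow file.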
My parse_csv() is the same as the one I used previously, but it is not working now. I can fix some of the issues, but I don't fully understand why I am having them. Here is the start of my parse_csv() function:
def parse_csv(..):
    columns = tf.decode_csv(rows, record_defaults=CSV_COLUMN_DEFAULTS)
    raw_features = dict(zip(FIELDNAMES, columns))
    words = tf.string_split(raw_features['sentences'])  # splitting words
    vocab_table = tf.contrib.lookup.index_table_from_file(vocabulary_file=hparams.vocab_file,
                                                          default_value=0)
    ....
Right away, tf.string_split() stops working with the error ValueError: Shape must be rank 1 but is rank 0 for 'csv_preprocessing/input_sequence_generation/StringSplit' (op: 'StringSplit') with input shapes: [], []. This is easily solved by packing raw_features['sentences'] into a tensor via [raw_features['sentences']], but I do not understand why this is needed with the Dataset approach. How come this worked fine in the old version? For the shapes to match up with the rest of my model, I end up needing to remove this extra dimension at the end via words = tf.squeeze(words, 0), because I added this "unnecessary" dimension to the tensor.
For whatever reason, I am also getting an error that the table is not initialized: tensorflow.python.framework.errors_impl.FailedPreconditionError: Table not initialized. However, this code works completely fine with my old input_fn() (see below), so I don't know why I would now need to initialize the tables. I have not figured out a solution to this part. Is there anything I am missing to be able to use tf.contrib.lookup.index_table_from_file within my parse_csv function?
For reference, this is my old input_fn() that still does work:
def input_fn(...):
    filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once(filenames),
                                                    num_epochs=num_epochs, shuffle=shuffle, capacity=32)
    reader = tf.TextLineReader(skip_header_lines=skip_header_lines)
    _, rows = reader.read_up_to(filename_queue, num_records=batch_size)
    features = parse_csv(rows, hparams)
    if shuffle:
        features = tf.train.shuffle_batch(
            features,
            batch_size,
            min_after_dequeue=2 * batch_size + 1,
            capacity=batch_size * 10,
            num_threads=multiprocessing.cpu_count(),
            enqueue_many=True,
            allow_smaller_final_batch=True
        )
    else:
        features = tf.train.batch(
            features,
            batch_size,
            capacity=batch_size * 10,
            num_threads=multiprocessing.cpu_count(),
            enqueue_many=True,
            allow_smaller_final_batch=True
        )
    labels = features.pop(LABEL_COLUMN)
    return features, labels
UPDATE TF 1.7
I am revisiting this with TF 1.7 (which should have all of the TF 1.6 features mentioned in @mrry's answer), but I'm still unable to replicate the behavior. With my old input_fn() I get around 13 steps/sec. The new function that I am using is as follows:
def input_fn(...):
    files = tf.data.Dataset.list_files(filenames).shuffle(num_shards)
    dataset = files.apply(tf.contrib.data.parallel_interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=num_shards))
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        lambda row: parse_csv_dataset(row, hparams=hparams),
        batch_size=batch_size,
        num_parallel_batches=multiprocessing.cpu_count()))
    dataset = dataset.prefetch(1)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_initializable_iterator()
    features = iterator.get_next()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
    labels = {key: features.pop(key) for key in LABEL_COLUMNS}
    return features, labels
I believe that I am following all of the performance guidelines, such as 1) use prefetch, 2) use map_and_batch with num_parallel_batches = cores, 3) use parallel_interleave, and 4) apply shuffle before repeat. The only suggestions I am not using are caching, which I would expect to help only for epochs after the first, and "applying interleave, prefetch and shuffle first"; I found that having prefetch and shuffle after map_and_batch gave a ~10% speedup.
BUFFER ISSUE
The first performance issue I am noticing: with my old input_fn() it took about 13 wall-clock minutes to get through 20k steps, and yet even with a buffer_size of 10,000 (which I take to mean we are waiting until 10,000 batches have been processed) I am still waiting more than 40 minutes for the buffer to fill. Does it make sense for this to take so long? If I know that my sharded .csv files on GCS are already randomized, is it acceptable to make this shuffle buffer smaller? I am trying to replicate the behavior of tf.train.shuffle_batch(); however, it seems that at worst, filling the buffer should take the same 13 minutes it took to reach 10k steps?
STEPS/SEC
Even once the buffer has filled up, global steps/sec tops out around 3 steps/sec (often as low as 2 steps/sec) on the same model that gets ~13 steps/sec with the previous input_fn().
SLOPPY INTERLEAVE
I finally tried replacing parallel_interleave() with sloppy_interleave(), as this is another suggestion from @mrry. When I switched to sloppy_interleave I got 14 steps/sec! I know this means it is not deterministic, but that should really just mean it is not deterministic from one run (or epoch) to the next? Or are there larger implications? Should I be concerned about any real difference between the old shuffle_batch() method and sloppy_interleave? Does the fact that this results in a 4-5x improvement suggest what the previous blocking factor was?
In TF 1.4 (currently the latest version of TF that works with GCMLE), you will not be able to use make_one_shot_iterator() with lookup tables (see relevant post). Instead, use Dataset.make_initializable_iterator() and add iterator.initializer to the default TABLE_INITIALIZERS collection (from this post), so that it runs along with the table initializers. Here is what the input_fn() should look like:
def input_fn(...):
    dataset = tf.data.Dataset.list_files(filenames).shuffle(num_shards)
    # Define `vocab_table` outside the map function and use it in `parse_csv()`.
    vocab_table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=hparams.vocab_file, default_value=0)
    dataset = dataset.interleave(
        lambda filename: (tf.data.TextLineDataset(filename)
                          .skip(1)
                          .map(lambda row: parse_csv(row, hparams),
                               num_parallel_calls=multiprocessing.cpu_count())),
        cycle_length=5)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    features = iterator.get_next()
    # Add iterator.initializer to be handled by the default table initializers.
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
    labels = features.pop(LABEL_COLUMN)
    return features, labels
When you use tf.data.TextLineDataset, each element is a scalar string. In this respect, it is more similar to using tf.TextLineReader.read(), rather than the batch version tf.TextLineReader.read_up_to(), which returns a vector of strings. Unfortunately the tf.string_split() op demands a vector input (although this could potentially be changed in future), so the shape manipulation is currently necessary.
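The rank mismatch can be illustrated without TensorFlow. Below, a hypothetical string_split stand-in accepts only a list of strings (rank 1), as tf.string_split() does; a single TextLineDataset-style element is a scalar, so it has to be wrapped in a list and the extra dimension taken off again afterwards. (The real op returns a SparseTensor rather than nested lists.)

```python
def string_split(sentences):
    """Toy stand-in for tf.string_split(): only accepts a rank-1 input
    (a list of strings) and splits each string on whitespace."""
    if not isinstance(sentences, list):
        raise ValueError("Shape must be rank 1 but is rank 0")
    return [s.split() for s in sentences]


sentence = "the quick brown fox"  # one TextLineDataset element: a scalar

try:
    string_split(sentence)  # rank 0, so it is rejected
except ValueError as err:
    print("rejected:", err)

words = string_split([sentence])  # wrap it: rank 1 with a single row
words = words[0]  # drop the extra dimension again, like tf.squeeze(words, 0)
print(words)  # ['the', 'quick', 'brown', 'fox']
```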
Lookup tables interact a little differently with the functions in tf.data. The intuition is that you should declare the lookup table once outside the Dataset.map() call (so that it will be initialized once) and then capture it inside the parse_csv() function to call vocab_table.lookup(). Something like the following should work:
def input_fn(...):
    dataset = tf.data.Dataset.list_files(filenames).shuffle(num_shards)
    # Define `vocab_table` outside the map function and use it in `parse_csv()`.
    vocab_table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=hparams.vocab_file, default_value=0)
    def parse_csv(...):
        columns = tf.decode_csv(rows, record_defaults=CSV_COLUMN_DEFAULTS)
        raw_features = dict(zip(FIELDNAMES, columns))
        words = tf.string_split([raw_features['sentences']])  # splitting words
        # Use the captured `vocab_table` here.
        word_indices = vocab_table.lookup(words)
        # ...
        features = ...
        # NOTE: Structure the output here so that you can simply return
        # the dataset from `input_fn()`.
        labels = features.pop(LABEL_COLUMN)
        return features, labels
    # NOTE: Consider using `tf.contrib.data.parallel_interleave()` to perform
    # the reads in parallel.
    dataset = dataset.interleave(
        lambda filename: (tf.data.TextLineDataset(filename)
                          .skip(1)
                          .map(lambda row: parse_csv(row, hparams),
                               num_parallel_calls=multiprocessing.cpu_count())),
        cycle_length=5)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    # NOTE: Add prefetching here to run the input pipeline in the background.
    dataset = dataset.prefetch(1)
    # NOTE: This requires TensorFlow 1.5 or later, but this change simplifies the
    # initialization of the lookup table.
    return dataset
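The "build once, capture in the closure" pattern can be shown in plain Python. This toy VocabTable only mimics the mapping that index_table_from_file() produces (vocabulary line i becomes id i, unknown words get default_value); it does not model TensorFlow's table initialization, and all names here are hypothetical.

```python
class VocabTable:
    """Toy analogue of the table built by index_table_from_file():
    vocabulary line i maps to id i; unknown words map to default_value."""

    instances = 0  # how many tables have been built so far

    def __init__(self, vocab_lines, default_value=0):
        VocabTable.instances += 1
        self._ids = {word: i for i, word in enumerate(vocab_lines)}
        self._default = default_value

    def lookup(self, words):
        return [self._ids.get(w, self._default) for w in words]


def input_fn(rows, vocab_lines):
    # Build the table once, outside the per-row parse function...
    vocab_table = VocabTable(vocab_lines, default_value=0)

    def parse_csv(row):
        # ...and capture it here instead of creating a table per row.
        return vocab_table.lookup(row.split())

    return [parse_csv(row) for row in rows]  # stands in for dataset.map()


result = input_fn(["the cat", "a dog"], ["<unk>", "the", "cat", "dog"])
print(result)                # [[1, 2], [0, 3]]
print(VocabTable.instances)  # 1
```

With default_value=0, unknown words share id 0 with the first vocabulary entry, which is why this made-up vocabulary reserves that line for <unk>.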
I'd like to speed up my training routine, which uses the Estimator API with an input_fn written using tf.data.Dataset.
My implementation takes 2 seconds to prepare a batch of data, then runs training on the GPU for 1 second, and then starts over preparing a batch, which is really inefficient.
I'm looking for a way to prepare the batches asynchronously and upload them to the GPU to speed up training, or alternatively a way to cache datasets between invocations of input_fn (dataset.cache() doesn't seem to be a good choice, as the dataset has to be recreated on each input_fn invocation).
Here is a simplified version of my code:
def input_fn(filenames, labels, epochs):
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(labels))
    dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
    dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
    dataset = dataset.batch(128)
    dataset = dataset.repeat(epochs)  # to iterate over the training set forever
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels
train_input_fn = lambda: input_fn(train_files, train_labels, None)
eval_input_fn = lambda: input_fn(eval_files, eval_labels, 1)
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
I've noticed that the Estimator API is under active development and in the master branch of tensorflow the input_fn can return datasets already, so maybe I'm asking too early and this feature isn't ready yet. But if so, please provide a ticket where this implementation can be tracked.
Using tf.data.Dataset.cache() is indeed not a good choice since it will cache the whole dataset into memory, which takes time and might overflow your memory.
The way to go is to use tf.data.Dataset.prefetch() at the end of your pipeline, which will always make sure that the data pipeline holds buffer_size elements. It is usually enough to have buffer_size = 1 at the end:
dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1) # prefetch one batch
As explained by @mrry in this answer, you can also try to increase the number of prefetched batches a bit:
Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary.
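Here is a toy model in plain Python of why a prefetch buffer at the end helps: while the consumer runs a "training step", a background thread is already preparing the next batch. It uses a thread and a bounded queue, which is not how tf.data is implemented, and the sleep durations are made-up stand-ins for the preparation and training times in the question.

```python
import queue
import threading
import time


def prefetch(elements, buffer_size=1):
    """Toy model of Dataset.prefetch(buffer_size): a background thread
    keeps preparing elements into a bounded queue while the consumer works."""
    buf = queue.Queue(maxsize=buffer_size)
    end = object()  # marks the end of the data

    def producer():
        for el in elements:
            buf.put(el)  # blocks while the buffer is full
        buf.put(end)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        el = buf.get()
        if el is end:
            return
        yield el


def prepare_batches(n, seconds):
    for i in range(n):
        time.sleep(seconds)  # stands in for reading and decoding a batch
        yield i


def train(batches, seconds):
    seen = []
    for b in batches:
        time.sleep(seconds)  # stands in for one training step
        seen.append(b)
    return seen


start = time.perf_counter()
sequential = train(prepare_batches(4, 0.05), 0.05)
sequential_time = time.perf_counter() - start

start = time.perf_counter()
overlapped = train(prefetch(prepare_batches(4, 0.05), buffer_size=1), 0.05)
overlapped_time = time.perf_counter() - start

print(f"sequential: {sequential_time:.2f}s, with prefetch: {overlapped_time:.2f}s")
```

Prefetching only overlaps the two stages; when preparing a batch is slower than training on it, the step time is still bounded by preparation, which is where num_parallel_calls comes in.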
If you still have a slow input pipeline compared to your GPU computations, you need to increase the number of threads working in parallel using the num_parallel_calls argument of tf.data.Dataset.map().
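A toy thread-pool model of map() with num_parallel_calls: several calls are in flight at once, while the output keeps the input order. It only speeds up here because the simulated work sleeps and releases Python's GIL; TensorFlow runs the map function's ops in its own thread pool. The file names are hypothetical.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def parallel_map(fn, elements, num_parallel_calls):
    """Toy model of Dataset.map(fn, num_parallel_calls=N): up to N calls run
    at once, and results are returned in input order."""
    with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
        # Executor.map yields results in the order of `elements`.
        return list(pool.map(fn, elements))


def read_and_process(path):
    time.sleep(0.05)  # stands in for I/O-bound work like _read_wav + _post_process
    return path.upper()


paths = ["a.wav", "b.wav", "c.wav", "d.wav"]  # hypothetical file names
start = time.perf_counter()
results = parallel_map(read_and_process, paths, num_parallel_calls=4)
elapsed = time.perf_counter() - start
print(results)  # ['A.WAV', 'B.WAV', 'C.WAV', 'D.WAV']
```

Unlike this sketch, which submits every element up front, tf.data keeps only a bounded number of calls in flight.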
A few points to add to Olivier's answer, mostly from this post:
repeat before shuffle is slightly faster, at the cost of blurred epoch boundaries. This may be significant in rare cases, but I doubt it.
shuffle before mapping - this reduces the memory foot print of your shuffle buffer size, since it only needs to buffer the filenames rather than the file contents.
it makes more sense to me to apply the third map transform to the output of get_next() rather than the dataset - not sure if that affects speed much. You could also consider putting both other map calls in the same one to reduce scheduling issues.
experiment with repeat before batching. It probably won't make a difference, or only a minor one. If you repeat before shuffle as mentioned above, you'll have to.
as mentioned by Olivier, use prefetch.
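The effect of the repeat/shuffle order on epoch boundaries can be checked with a toy model of the shuffle buffer in plain Python. Elements are tagged (epoch, index) so we can see whether a second-epoch element leaks into the first n outputs. This models the buffer behavior of Dataset.shuffle(); it is not TensorFlow's code.

```python
import random


def shuffle(elements, buffer_size, rng):
    """Toy model of Dataset.shuffle(buffer_size): fill a buffer, then emit a
    random buffered element and put the next input in its place."""
    buf = []
    for el in elements:
        if len(buf) < buffer_size:
            buf.append(el)
            continue
        i = rng.randrange(buffer_size)
        out, buf[i] = buf[i], el
        yield out
    rng.shuffle(buf)  # drain whatever is left in random order
    yield from buf


def shuffle_then_repeat(n, num_epochs, buffer_size, rng):
    # Each repetition re-runs (and reshuffles) one epoch's worth of data.
    for epoch in range(num_epochs):
        yield from shuffle(((epoch, i) for i in range(n)), buffer_size, rng)


def repeat_then_shuffle(n, num_epochs, buffer_size, rng):
    # The shuffle buffer sees one long stream that spans every epoch.
    stream = ((epoch, i) for epoch in range(num_epochs) for i in range(n))
    yield from shuffle(stream, buffer_size, rng)


def first_epoch_is_clean(order, n):
    """True if the first n outputs all come from epoch 0."""
    return all(epoch == 0 for epoch, _ in order[:n])


n, num_epochs, buffer_size = 4, 2, 4
clean = list(shuffle_then_repeat(n, num_epochs, buffer_size, random.Random(0)))
print(first_epoch_is_clean(clean, n))  # True for every seed

blended = [
    not first_epoch_is_clean(
        list(repeat_then_shuffle(n, num_epochs, buffer_size, random.Random(seed))), n)
    for seed in range(20)
]
print(sum(blended), "of 20 seeds mixed epoch 1 into the first 4 outputs")
```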
Code with modifications:
def input_fn(filenames, labels, epochs):
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.repeat(epochs)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(labels))
    def combined_map_fn(*args):
        # _read_wav returns a tuple (as in the question's pipeline); unpack it
        # into separate arguments for _post_process, as map() would.
        return _post_process(*_read_wav(*args))
    dataset = dataset.map(combined_map_fn, num_parallel_calls=num_map_threads)
    dataset = dataset.batch(128)
    dataset = dataset.prefetch(1)
    iterator = dataset.make_one_shot_iterator()
    wavs, labels = iterator.get_next()
    features = {'wav': wavs}
    return features, labels