Skip to content
Snippets Groups Projects
Commit 97e81067 authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

Estimate the size of the tfrecords file that will be created

parent b598af9c
No related branches found
No related tags found
1 merge request!47Many changes
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
"""Converts Bio and PAD datasets to TFRecords file formats. """Converts Bio and PAD datasets to TFRecords file formats.
Usage: Usage:
%(prog)s [-v...] [options] <config_files>... %(prog)s [-v...] [--output=PATH|--size-estimate] [options] <config_files>...
%(prog)s --help %(prog)s --help
%(prog)s --version %(prog)s --version
...@@ -22,6 +22,8 @@ Options: ...@@ -22,6 +22,8 @@ Options:
--multiple-samples If provided, it means that the data provided by --multiple-samples If provided, it means that the data provided by
reader contains multiple samples with same reader contains multiple samples with same
label and path. label and path.
--size-estimate I provided, it will print the size estimate of
tfrecords instead of writing them.
-v, --verbose Increases the output verbosity level -v, --verbose Increases the output verbosity level
The best way to use this script is to send it to the io-big queue if you are at The best way to use this script is to send it to the io-big queue if you are at
...@@ -90,7 +92,11 @@ from __future__ import print_function ...@@ -90,7 +92,11 @@ from __future__ import print_function
import random import random
# import pkg_resources so that bob imports work properly: # import pkg_resources so that bob imports work properly:
import pkg_resources import pkg_resources
import tempfile
import os
import sys
import tensorflow as tf import tensorflow as tf
from docopt import docopt
from bob.io.base import create_directories_safe from bob.io.base import create_directories_safe
from bob.extension.config import load as read_config_file from bob.extension.config import load as read_config_file
from bob.learn.tensorflow.utils.commandline import \ from bob.learn.tensorflow.utils.commandline import \
...@@ -107,7 +113,8 @@ def int64_feature(value): ...@@ -107,7 +113,8 @@ def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def write_a_sample(writer, data, label, key, feature=None): def write_a_sample(writer, data, label, key, feature=None,
size_estimate=False):
if feature is None: if feature is None:
feature = { feature = {
'data': bytes_feature(data.tostring()), 'data': bytes_feature(data.tostring()),
...@@ -116,13 +123,44 @@ def write_a_sample(writer, data, label, key, feature=None): ...@@ -116,13 +123,44 @@ def write_a_sample(writer, data, label, key, feature=None):
} }
example = tf.train.Example(features=tf.train.Features(feature=feature)) example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString()) example = example.SerializeToString()
if not size_estimate:
writer.write(example)
return sys.getsizeof(example)
def _bytes2human(n, format='%(value).1f %(symbol)s', symbols='customary'):
"""Convert n bytes into a human readable string based on format.
From: https://code.activestate.com/recipes/578019-bytes-to-human-human-to-
bytes-converter/
Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
License: MIT
symbols can be either "customary", "customary_ext", "iec" or "iec_ext",
see: http://goo.gl/kTQMs
"""
SYMBOLS = {
'customary': ('B', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'),
'customary_ext': ('byte', 'kilo', 'mega', 'giga', 'tera', 'peta',
'exa', 'zetta', 'iotta'),
'iec': ('Bi', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi', 'Yi'),
'iec_ext': ('byte', 'kibi', 'mebi', 'gibi', 'tebi', 'pebi', 'exbi',
'zebi', 'yobi'),
}
n = int(n)
if n < 0:
raise ValueError("n < 0")
symbols = SYMBOLS[symbols]
prefix = {}
for i, s in enumerate(symbols[1:]):
prefix[s] = 1 << (i + 1) * 10
for symbol in reversed(symbols[1:]):
if n >= prefix[symbol]:
value = float(n) / prefix[symbol]
return format % locals()
return format % dict(symbol=symbols[0], value=n)
def main(argv=None): def main(argv=None):
from docopt import docopt
import os
import sys
docs = __doc__ % {'prog': os.path.basename(sys.argv[0])} docs = __doc__ % {'prog': os.path.basename(sys.argv[0])}
version = pkg_resources.require('bob.learn.tensorflow')[0].version version = pkg_resources.require('bob.learn.tensorflow')[0].version
defaults = docopt(docs, argv=[""]) defaults = docopt(docs, argv=[""])
...@@ -138,6 +176,8 @@ def main(argv=None): ...@@ -138,6 +176,8 @@ def main(argv=None):
multiple_samples = get_from_config_or_commandline( multiple_samples = get_from_config_or_commandline(
config, 'multiple_samples', args, defaults) config, 'multiple_samples', args, defaults)
shuffle = get_from_config_or_commandline(config, 'shuffle', args, defaults) shuffle = get_from_config_or_commandline(config, 'shuffle', args, defaults)
size_estimate = get_from_config_or_commandline(
config, 'size_estimate', args, defaults)
# Sets-up logging # Sets-up logging
set_verbosity_level(logger, verbosity) set_verbosity_level(logger, verbosity)
...@@ -145,16 +185,21 @@ def main(argv=None): ...@@ -145,16 +185,21 @@ def main(argv=None):
# required arguments # required arguments
samples = config.samples samples = config.samples
reader = config.reader reader = config.reader
output = get_from_config_or_commandline(config, 'output', args, defaults, if not size_estimate:
False) output = get_from_config_or_commandline(
config, 'output', args, defaults, False)
if not output.endswith(".tfrecords"):
output += ".tfrecords"
logger.info("Writing samples to `{}'".format(output))
else:
output = tempfile.NamedTemporaryFile(suffix='.tfrecords').name
if not output.endswith(".tfrecords"): total_size = 0
output += ".tfrecords"
create_directories_safe(os.path.dirname(output)) create_directories_safe(os.path.dirname(output))
n_samples = len(samples) n_samples = len(samples)
sample_counter = 0 sample_count = 0
with tf.python_io.TFRecordWriter(output) as writer: with tf.python_io.TFRecordWriter(output) as writer:
if shuffle: if shuffle:
logger.info("Shuffling the samples before writing ...") logger.info("Shuffling the samples before writing ...")
...@@ -174,13 +219,26 @@ def main(argv=None): ...@@ -174,13 +219,26 @@ def main(argv=None):
if multiple_samples: if multiple_samples:
for sample in data: for sample in data:
write_a_sample(writer, sample, label, key) total_size += write_a_sample(
sample_counter += 1 writer, sample, label, key,
size_estimate=size_estimate)
sample_count += 1
else: else:
write_a_sample(writer, data, label, key) total_size += write_a_sample(
sample_counter += 1 writer, data, label, key, size_estimate=size_estimate)
sample_count += 1
print("Wrote {} samples into the tfrecords file.".format(sample_counter))
if not size_estimate:
print("Wrote {} samples into the tfrecords file.".format(sample_count))
else:
# delete the empty tfrecords file
try:
os.remove(output)
except Exception:
pass
print("The total size of the tfrecords file will roughly be "
"{} bytes".format(_bytes2human(total_size)))
return total_size
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -29,3 +29,8 @@ def test_verify_and_tfrecords(): ...@@ -29,3 +29,8 @@ def test_verify_and_tfrecords():
finally: finally:
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
def test_tfrecords_size_estimate():
total_size = tfrecords([dummy_config, '--size-estimate'])
assert total_size == 2079170, total_size
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment