Commit c22532dd authored by Amir MOHAMMADI

Fix autotune parameters and remove extra arguments

parent af806b15
Merge request !125: Add support for distributed/multi-gpu training
@@ -67,8 +67,6 @@ def prepare_dataset(
     output_shape,
     shuffle=False,
     augment=False,
-    autotune=tf.data.experimental.AUTOTUNE,
-    n_cpus=cpu_count(),
     shuffle_buffer=int(2e4),
     ctx=None,
 ):
@@ -89,15 +87,10 @@ def prepare_dataset(
     augment: bool
-    autotune: int
-    n_cpus: int
     shuffle_buffer: int
     ctx: ``tf.distribute.InputContext``
     """
-    logger.debug(f"Using {n_cpus} cpus to prepare the tensorflow dataset")
     ds = tf.data.Dataset.list_files(
         tf_record_paths, shuffle=shuffle if ctx is None else False
@@ -117,7 +110,10 @@ def prepare_dataset(
     ignore_order.experimental_deterministic = False
     ds = ds.with_options(ignore_order)
-    ds = ds.map(partial(decode_tfrecords, data_shape=data_shape, num_parallel_calls=tf.data.AUTOTUNE))
+    ds = ds.map(
+        partial(decode_tfrecords, data_shape=data_shape),
+        num_parallel_calls=tf.data.AUTOTUNE,
+    )
     if shuffle:
         ds = ds.shuffle(shuffle_buffer)
@@ -126,8 +122,8 @@ def prepare_dataset(
     preprocessor = get_preprocessor(output_shape)
     ds = ds.batch(batch_size).map(
         partial(preprocess, preprocessor, augment=augment),
-        num_parallel_calls=autotune,
+        num_parallel_calls=tf.data.AUTOTUNE,
     )
     # Use buffered prefecting on all datasets
-    return ds.prefetch(buffer_size=autotune)
+    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)
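
For context, the change does two things: it drops the user-facing autotune and n_cpus arguments in favor of using tf.data.AUTOTUNE directly, and it fixes a bug in the decode step, where num_parallel_calls was being passed to functools.partial (and therefore forwarded to decode_tfrecords as a keyword argument) instead of to Dataset.map. Below is a minimal, self-contained sketch of the corrected pattern; decode_example, its feature spec, and the argument names are hypothetical stand-ins for the repository's decode_tfrecords helper, not the actual implementation.

from functools import partial

import tensorflow as tf

def decode_example(example, data_shape):
    # Hypothetical stand-in for decode_tfrecords: parse one serialized
    # tf.train.Example into a data tensor and an integer label.
    features = tf.io.parse_single_example(
        example,
        {
            "data": tf.io.FixedLenFeature(data_shape, tf.float32),
            "label": tf.io.FixedLenFeature([], tf.int64),
        },
    )
    return features["data"], features["label"]

def prepare_dataset(tf_record_paths, batch_size, data_shape):
    ds = tf.data.TFRecordDataset(tf_record_paths)
    # num_parallel_calls belongs to Dataset.map, not to partial(): putting
    # it inside partial() would forward it to the decode function as a
    # keyword argument instead of parallelizing the map.
    ds = ds.map(
        partial(decode_example, data_shape=data_shape),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    ds = ds.batch(batch_size)
    # Buffered prefetching; AUTOTUNE lets the tf.data runtime tune the
    # buffer size instead of relying on a user-supplied parameter.
    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)

Note that the old form would not necessarily fail at the call site: depending on the signature of the decode function, it would either raise a TypeError when the map traces the function or silently run the decode step without parallelism, which is why the fix matters.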