Skip to content
Snippets Groups Projects
Commit aa07a44f authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[generator] remove use of deprecated arguments

parent 7a498fd9
Branches
Tags
No related merge requests found
......@@ -37,7 +37,7 @@ class Generator:
reader,
multiple_samples=False,
shuffle_on_epoch_end=False,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
self.reader = reader
......@@ -65,6 +65,7 @@ class Generator:
dataset = tf.data.Dataset.from_tensors(dlk)
self._output_types = tf.compat.v1.data.get_output_types(dataset)
self._output_shapes = tf.compat.v1.data.get_output_shapes(dataset)
self._element_spec = dataset.element_spec
logger.info(
"Initializing a dataset with %d %s and %s types and %s shapes",
......@@ -84,6 +85,11 @@ class Generator:
"The shapes of the returned samples"
return self._output_shapes
@property
def element_spec(self):
"The type specification of an element of the dataset"
return self._element_spec
def __call__(self):
"""A generator function that when called will yield the samples.
......@@ -106,7 +112,13 @@ class Generator:
random.shuffle(self.samples)
def dataset_using_generator(samples, reader, **kwargs):
def dataset_using_generator(
samples,
reader,
multiple_samples=False,
shuffle_on_epoch_end=False,
**kwargs,
):
"""
A generator class which wraps samples so that they can
be used with tf.data.Dataset.from_generator
......@@ -128,8 +140,14 @@ def dataset_using_generator(samples, reader, **kwargs):
A tf.data.Dataset
"""
generator = Generator(samples, reader, **kwargs)
generator = Generator(
samples,
reader,
multiple_samples=multiple_samples,
shuffle_on_epoch_end=shuffle_on_epoch_end,
**kwargs,
)
dataset = tf.data.Dataset.from_generator(
generator, generator.output_types, generator.output_shapes
generator, output_signature=generator.element_spec
)
return dataset
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment