From aece6e1b3461d11e63a243e6c64460c4b0dcea00 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 7 Feb 2020 14:40:18 +0100
Subject: [PATCH] improve the generator and biogenerator classes

---
 bob/learn/tensorflow/dataset/bio.py       | 114 ++++++++--------------
 bob/learn/tensorflow/dataset/generator.py |  22 +++--
 2 files changed, 52 insertions(+), 84 deletions(-)

diff --git a/bob/learn/tensorflow/dataset/bio.py b/bob/learn/tensorflow/dataset/bio.py
index 54729066..f4ed404c 100644
--- a/bob/learn/tensorflow/dataset/bio.py
+++ b/bob/learn/tensorflow/dataset/bio.py
@@ -1,12 +1,11 @@
-import six
-import tensorflow as tf
 from bob.bio.base import read_original_data
+from .generator import Generator
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class BioGenerator(object):
+class BioGenerator(Generator):
     """A generator class which wraps bob.bio.base databases so that they can
     be used with tf.data.Dataset.from_generator
 
@@ -15,44 +14,37 @@ class BioGenerator(object):
     biofile_to_label : :obj:`object`, optional
         A callable with the signature of
         ``label = biofile_to_label(biofile)``. By default -1 is returned as
         label.
-    biofiles : [:any:`bob.bio.base.database.BioFile`]
-        The list of the bio files .
     database : :any:`bob.bio.base.database.BioDatabase`
         The database that you want to use.
-    epoch : int
-        The number of epochs that have been passed so far.
-    keys : [str]
-        The keys of samples obtained by calling ``biofile.make_path("", "")``
-    labels : [int]
-        The labels obtained by calling ``label = biofile_to_label(biofile)``
     load_data : :obj:`object`, optional
         A callable with the signature of ``data = load_data(database, biofile)``.
         :any:`bob.bio.base.read_original_data` is wrapped to be used by
         default.
-    multiple_samples : :obj:`bool`, optional
-        If true, it assumes that the bio database's samples actually contain
-        multiple samples. This is useful for when you want to for example treat
-        video databases as image databases.
-    output_types : (object, object, object)
-        The types of the returned samples.
-    output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
-        The shapes of the returned samples.
+    biofiles : [:any:`bob.bio.base.database.BioFile`]
+        The list of the bio files .
+    keys : [str]
+        The keys of samples obtained by calling ``biofile.make_path("", "")``
+    labels : [int]
+        The labels obtained by calling ``label = biofile_to_label(biofile)``
     """
 
-    def __init__(self,
-                 database,
-                 biofiles,
-                 load_data=None,
-                 biofile_to_label=None,
-                 multiple_samples=False,
-                 **kwargs):
-        super(BioGenerator, self).__init__(**kwargs)
+    def __init__(
+        self,
+        database,
+        biofiles,
+        load_data=None,
+        biofile_to_label=None,
+        multiple_samples=False,
+        **kwargs
+    ):
+
         if load_data is None:
 
             def load_data(database, biofile):
-                data = read_original_data(biofile, database.original_directory,
-                                          database.original_extension)
+                data = read_original_data(
+                    biofile, database.original_directory, database.original_extension
+                )
                 return data
 
         if biofile_to_label is None:
@@ -61,29 +53,22 @@ class BioGenerator(object):
                 return -1
 
         self.database = database
-        self.biofiles = list(biofiles)
         self.load_data = load_data
         self.biofile_to_label = biofile_to_label
-        self.multiple_samples = multiple_samples
-        self.epoch = 0
-
-        # load one data to get its type and shape
-        data = load_data(database, biofiles[0])
-        if multiple_samples:
-            try:
-                data = data[0]
-            except TypeError:
-                # if the data is a generator
-                data = six.next(data)
-        data = tf.convert_to_tensor(data)
-        self._output_types = (data.dtype, tf.int64, tf.string)
-        self._output_shapes = (data.shape, tf.TensorShape([]),
-                               tf.TensorShape([]))
-
-        logger.info(
-            "Initializing a dataset with %d files and %s types "
-            "and %s shapes", len(self.biofiles), self.output_types,
-            self.output_shapes)
+
+        def reader(f):
+            label = int(self.biofile_to_label(f))
+            data = self.load_data(self.database, f)
+            key = str(f.make_path("", "")).encode("utf-8")
+            if self.multiple_samples:
+                for d in data:
+                    yield (d, label, key)
+            else:
+                yield (data, label, key)
+
+        super(BioGenerator, self).__init__(
+            biofiles, reader, multiple_samples=multiple_samples, **kwargs
+        )
 
     @property
     def labels(self):
@@ -93,34 +78,11 @@ class BioGenerator(object):
     @property
     def keys(self):
         for f in self.biofiles:
-            yield str(f.make_path("", "")).encode('utf-8')
-
-    @property
-    def output_types(self):
-        return self._output_types
+            yield str(f.make_path("", "")).encode("utf-8")
 
     @property
-    def output_shapes(self):
-        return self._output_shapes
+    def biofiles(self):
+        return self.samples
 
     def __len__(self):
         return len(self.biofiles)
-
-    def __call__(self):
-        """A generator function that when called will return the samples.
-
-        Yields
-        ------
-        (data, label, key) : tuple
-            A tuple containing the data, label, and the key.
-        """
-        for f, label, key in six.moves.zip(self.biofiles, self.labels,
-                                           self.keys):
-            data = self.load_data(self.database, f)
-            if self.multiple_samples:
-                for d in data:
-                    yield (d, label, key)
-            else:
-                yield (data, label, key)
-        self.epoch += 1
-        logger.info("Elapsed %d epoch(s)", self.epoch)
diff --git a/bob/learn/tensorflow/dataset/generator.py b/bob/learn/tensorflow/dataset/generator.py
index 32d797b7..cf2798ae 100644
--- a/bob/learn/tensorflow/dataset/generator.py
+++ b/bob/learn/tensorflow/dataset/generator.py
@@ -1,4 +1,3 @@
-import six
 import tensorflow as tf
 import random
 import logging
@@ -39,14 +38,21 @@
         self.epoch = 0
         self.shuffle_on_epoch_end = shuffle_on_epoch_end
-        # load one data to get its type and shape
-        dlk = self.reader(self.samples[0])
-        if self.multiple_samples:
+        # load samples until one of them is not empty
+        # this data is used to get the type and shape
+        for sample in self.samples:
             try:
-                dlk = dlk[0]
-            except TypeError:
-                # if the data is a generator
-                dlk = six.next(dlk)
+                dlk = self.reader(sample)
+                if self.multiple_samples:
+                    try:
+                        dlk = dlk[0]
+                    except TypeError:
+                        # if the data is a generator
+                        dlk = next(dlk)
+            except StopIteration:
+                continue
+            else:
+                break
 
         # Creating a "fake" dataset just to get the types and shapes
         dataset = tf.data.Dataset.from_tensors(dlk)
         self._output_types = dataset.output_types
--
GitLab