From aece6e1b3461d11e63a243e6c64460c4b0dcea00 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 7 Feb 2020 14:40:18 +0100
Subject: [PATCH] improve the generator and biogenerator classes

---
 bob/learn/tensorflow/dataset/bio.py       | 114 ++++++++--------------
 bob/learn/tensorflow/dataset/generator.py |  22 +++--
 2 files changed, 52 insertions(+), 84 deletions(-)

diff --git a/bob/learn/tensorflow/dataset/bio.py b/bob/learn/tensorflow/dataset/bio.py
index 54729066..f4ed404c 100644
--- a/bob/learn/tensorflow/dataset/bio.py
+++ b/bob/learn/tensorflow/dataset/bio.py
@@ -1,12 +1,11 @@
-import six
-import tensorflow as tf
 from bob.bio.base import read_original_data
+from .generator import Generator
 import logging
 
 logger = logging.getLogger(__name__)
 
 
-class BioGenerator(object):
+class BioGenerator(Generator):
     """A generator class which wraps bob.bio.base databases so that they can
     be used with tf.data.Dataset.from_generator
 
@@ -15,44 +14,37 @@ class BioGenerator(object):
     biofile_to_label : :obj:`object`, optional
         A callable with the signature of ``label = biofile_to_label(biofile)``.
         By default -1 is returned as label.
-    biofiles : [:any:`bob.bio.base.database.BioFile`]
-        The list of the bio files .
     database : :any:`bob.bio.base.database.BioDatabase`
         The database that you want to use.
-    epoch : int
-        The number of epochs that have been passed so far.
-    keys : [str]
-        The keys of samples obtained by calling ``biofile.make_path("", "")``
-    labels : [int]
-        The labels obtained by calling ``label = biofile_to_label(biofile)``
     load_data : :obj:`object`, optional
         A callable with the signature of
         ``data = load_data(database, biofile)``.
         :any:`bob.bio.base.read_original_data` is wrapped to be used by
         default.
-    multiple_samples : :obj:`bool`, optional
-        If true, it assumes that the bio database's samples actually contain
-        multiple samples. This is useful for when you want to for example treat
-        video databases as image databases.
-    output_types : (object, object, object)
-        The types of the returned samples.
-    output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
-        The shapes of the returned samples.
+    biofiles : [:any:`bob.bio.base.database.BioFile`]
+        The list of the bio files.
+    keys : [str]
+        The keys of samples obtained by calling ``biofile.make_path("", "")``
+    labels : [int]
+        The labels obtained by calling ``label = biofile_to_label(biofile)``
     """
 
-    def __init__(self,
-                 database,
-                 biofiles,
-                 load_data=None,
-                 biofile_to_label=None,
-                 multiple_samples=False,
-                 **kwargs):
-        super(BioGenerator, self).__init__(**kwargs)
+    def __init__(
+        self,
+        database,
+        biofiles,
+        load_data=None,
+        biofile_to_label=None,
+        multiple_samples=False,
+        **kwargs
+    ):
+
         if load_data is None:
 
             def load_data(database, biofile):
-                data = read_original_data(biofile, database.original_directory,
-                                          database.original_extension)
+                data = read_original_data(
+                    biofile, database.original_directory, database.original_extension
+                )
                 return data
 
         if biofile_to_label is None:
@@ -61,29 +53,22 @@ class BioGenerator(object):
                 return -1
 
         self.database = database
-        self.biofiles = list(biofiles)
         self.load_data = load_data
         self.biofile_to_label = biofile_to_label
-        self.multiple_samples = multiple_samples
-        self.epoch = 0
-
-        # load one data to get its type and shape
-        data = load_data(database, biofiles[0])
-        if multiple_samples:
-            try:
-                data = data[0]
-            except TypeError:
-                # if the data is a generator
-                data = six.next(data)
-        data = tf.convert_to_tensor(data)
-        self._output_types = (data.dtype, tf.int64, tf.string)
-        self._output_shapes = (data.shape, tf.TensorShape([]),
-                               tf.TensorShape([]))
-
-        logger.info(
-            "Initializing a dataset with %d files and %s types "
-            "and %s shapes", len(self.biofiles), self.output_types,
-            self.output_shapes)
+
+        def reader(f):
+            # Return (not yield): unless multiple_samples is set, Generator uses
+            # the result directly as a sample, so it must be the tuple itself.
+            label = int(self.biofile_to_label(f))
+            data = self.load_data(self.database, f)
+            key = str(f.make_path("", "")).encode("utf-8")
+            if self.multiple_samples:
+                return ((d, label, key) for d in data)
+            return (data, label, key)
+
+        super(BioGenerator, self).__init__(
+            biofiles, reader, multiple_samples=multiple_samples, **kwargs
+        )
 
     @property
     def labels(self):
@@ -93,34 +78,11 @@ class BioGenerator(object):
     @property
     def keys(self):
         for f in self.biofiles:
-            yield str(f.make_path("", "")).encode('utf-8')
-
-    @property
-    def output_types(self):
-        return self._output_types
+            yield str(f.make_path("", "")).encode("utf-8")
 
     @property
-    def output_shapes(self):
-        return self._output_shapes
+    def biofiles(self):
+        return self.samples
 
     def __len__(self):
         return len(self.biofiles)
-
-    def __call__(self):
-        """A generator function that when called will return the samples.
-
-        Yields
-        ------
-        (data, label, key) : tuple
-            A tuple containing the data, label, and the key.
-        """
-        for f, label, key in six.moves.zip(self.biofiles, self.labels,
-                                           self.keys):
-            data = self.load_data(self.database, f)
-            if self.multiple_samples:
-                for d in data:
-                    yield (d, label, key)
-            else:
-                yield (data, label, key)
-        self.epoch += 1
-        logger.info("Elapsed %d epoch(s)", self.epoch)
diff --git a/bob/learn/tensorflow/dataset/generator.py b/bob/learn/tensorflow/dataset/generator.py
index 32d797b7..cf2798ae 100644
--- a/bob/learn/tensorflow/dataset/generator.py
+++ b/bob/learn/tensorflow/dataset/generator.py
@@ -1,4 +1,3 @@
-import six
 import tensorflow as tf
 import random
 import logging
@@ -39,14 +38,21 @@ class Generator:
         self.epoch = 0
         self.shuffle_on_epoch_end = shuffle_on_epoch_end
 
-        # load one data to get its type and shape
-        dlk = self.reader(self.samples[0])
-        if self.multiple_samples:
+        # load samples until one is not empty; it gives the types and shapes
+        for sample in self.samples:
             try:
-                dlk = dlk[0]
-            except TypeError:
-                # if the data is a generator
-                dlk = six.next(dlk)
+                dlk = self.reader(sample)
+                if self.multiple_samples:
+                    try:
+                        dlk = dlk[0]
+                    except TypeError:
+                        # if the data is a generator
+                        dlk = next(dlk)
+            except StopIteration:
+                continue
+            break
+        else:
+            raise ValueError("no non-empty sample found to infer types and shapes")
         # Creating a "fake" dataset just to get the types and shapes
         dataset = tf.data.Dataset.from_tensors(dlk)
         self._output_types = dataset.output_types
-- 
GitLab