diff --git a/bob/learn/tensorflow/dataset/bio.py b/bob/learn/tensorflow/dataset/bio.py index f4ed404c59169333c6fea4f0740e442419c8574a..4a9b66bcc91f8656367739b59bfaf580a355c381 100644 --- a/bob/learn/tensorflow/dataset/bio.py +++ b/bob/learn/tensorflow/dataset/bio.py @@ -56,15 +56,20 @@ class BioGenerator(Generator): self.load_data = load_data self.biofile_to_label = biofile_to_label - def reader(f): + def _reader(f): label = int(self.biofile_to_label(f)) data = self.load_data(self.database, f) key = str(f.make_path("", "")).encode("utf-8") - if self.multiple_samples: + return data, label, key + + if multiple_samples: + def reader(f): + data, label, key = _reader(f) for d in data: yield (d, label, key) - else: - yield (data, label, key) + else: + def reader(f): + return _reader(f) super(BioGenerator, self).__init__( biofiles, reader, multiple_samples=multiple_samples, **kwargs