Commit aece6e1b authored by Amir MOHAMMADI

improve the generator and biogenerator classes

parent f6e1bb57
import six
import tensorflow as tf
from bob.bio.base import read_original_data
from .generator import Generator
import logging
logger = logging.getLogger(__name__)
class BioGenerator(object):
class BioGenerator(Generator):
"""A generator class which wraps bob.bio.base databases so that they can
be used with tf.data.Dataset.from_generator
......@@ -15,44 +14,37 @@ class BioGenerator(object):
biofile_to_label : :obj:`object`, optional
A callable with the signature of ``label = biofile_to_label(biofile)``.
By default -1 is returned as label.
biofiles : [:any:`bob.bio.base.database.BioFile`]
The list of the bio files.
database : :any:`bob.bio.base.database.BioDatabase`
The database that you want to use.
epoch : int
The number of epochs that have been passed so far.
keys : [str]
The keys of samples obtained by calling ``biofile.make_path("", "")``
labels : [int]
The labels obtained by calling ``label = biofile_to_label(biofile)``
load_data : :obj:`object`, optional
A callable with the signature of
``data = load_data(database, biofile)``.
:any:`bob.bio.base.read_original_data` is wrapped to be used by
default.
multiple_samples : :obj:`bool`, optional
If true, it assumes that the bio database's samples actually contain
multiple samples. This is useful when you want to treat, for example,
video databases as image databases.
output_types : (object, object, object)
The types of the returned samples.
output_shapes : ``(tf.TensorShape, tf.TensorShape, tf.TensorShape)``
The shapes of the returned samples.
biofiles : [:any:`bob.bio.base.database.BioFile`]
The list of the bio files.
keys : [str]
The keys of samples obtained by calling ``biofile.make_path("", "")``
labels : [int]
The labels obtained by calling ``label = biofile_to_label(biofile)``
"""
def __init__(self,
database,
biofiles,
load_data=None,
biofile_to_label=None,
multiple_samples=False,
**kwargs):
super(BioGenerator, self).__init__(**kwargs)
def __init__(
self,
database,
biofiles,
load_data=None,
biofile_to_label=None,
multiple_samples=False,
**kwargs
):
if load_data is None:
def load_data(database, biofile):
data = read_original_data(biofile, database.original_directory,
database.original_extension)
data = read_original_data(
biofile, database.original_directory, database.original_extension
)
return data
if biofile_to_label is None:
......@@ -61,29 +53,22 @@ class BioGenerator(object):
return -1
self.database = database
self.biofiles = list(biofiles)
self.load_data = load_data
self.biofile_to_label = biofile_to_label
self.multiple_samples = multiple_samples
self.epoch = 0
# load one data to get its type and shape
data = load_data(database, biofiles[0])
if multiple_samples:
try:
data = data[0]
except TypeError:
# if the data is a generator
data = six.next(data)
data = tf.convert_to_tensor(data)
self._output_types = (data.dtype, tf.int64, tf.string)
self._output_shapes = (data.shape, tf.TensorShape([]),
tf.TensorShape([]))
logger.info(
"Initializing a dataset with %d files and %s types "
"and %s shapes", len(self.biofiles), self.output_types,
self.output_shapes)
def reader(f):
label = int(self.biofile_to_label(f))
data = self.load_data(self.database, f)
key = str(f.make_path("", "")).encode("utf-8")
if self.multiple_samples:
for d in data:
yield (d, label, key)
else:
yield (data, label, key)
super(BioGenerator, self).__init__(
biofiles, reader, multiple_samples=multiple_samples, **kwargs
)
@property
def labels(self):
......@@ -93,34 +78,11 @@ class BioGenerator(object):
@property
def keys(self):
for f in self.biofiles:
yield str(f.make_path("", "")).encode('utf-8')
@property
def output_types(self):
return self._output_types
yield str(f.make_path("", "")).encode("utf-8")
@property
def output_shapes(self):
return self._output_shapes
def biofiles(self):
return self.samples
def __len__(self):
return len(self.biofiles)
def __call__(self):
"""A generator function that when called will return the samples.
Yields
------
(data, label, key) : tuple
A tuple containing the data, label, and the key.
"""
for f, label, key in six.moves.zip(self.biofiles, self.labels,
self.keys):
data = self.load_data(self.database, f)
if self.multiple_samples:
for d in data:
yield (d, label, key)
else:
yield (data, label, key)
self.epoch += 1
logger.info("Elapsed %d epoch(s)", self.epoch)
import six
import tensorflow as tf
import random
import logging
......@@ -39,14 +38,21 @@ class Generator:
self.epoch = 0
self.shuffle_on_epoch_end = shuffle_on_epoch_end
# load one data to get its type and shape
dlk = self.reader(self.samples[0])
if self.multiple_samples:
# load samples until one of them is not empty
# this data is used to get the type and shape
for sample in self.samples:
try:
dlk = dlk[0]
except TypeError:
# if the data is a generator
dlk = six.next(dlk)
dlk = self.reader(sample)
if self.multiple_samples:
try:
dlk = dlk[0]
except TypeError:
# if the data is a generator
dlk = next(dlk)
except StopIteration:
continue
else:
break
# Creating a "fake" dataset just to get the types and shapes
dataset = tf.data.Dataset.from_tensors(dlk)
self._output_types = dataset.output_types
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment